refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码

- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引
- 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑
- 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配
- 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误
- _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值
- 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView)
- 精简 docstring 为单行,删除冗余 section 注释和旧代码
- 修复 test_scheduler_concurrency.py 缺少 import pytest
This commit is contained in:
ViperEkura 2026-05-08 20:44:05 +08:00
parent 7ddebf2cd9
commit 30cc2d67a4
7 changed files with 244 additions and 777 deletions

View File

@ -1,241 +1,135 @@
"""KV cache slot allocation and prefix cache management.
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask.
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
import time
from collections import OrderedDict
from typing import Dict, List, Tuple
from typing import List, Tuple
import torch
from torch import Tensor
_STOP = object()
STOP = object()
class _RadixNode:
"""Internal node for the radix tree prefix cache.
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Attributes:
children: Mapping from token ID to child node.
slot: KV cache slot index for the prefix ending at this node.
slot_ver: Version counter of the slot at insertion time.
ref_count: Number of tasks currently referencing this node.
last_access: Timestamp of the most recent access (for LRU ordering).
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
def __init__(self):
self.children: Dict[int, "_RadixNode"] = {}
self.slot: int = -1
self.slot_ver: int = 0
self.ref_count: int = 0
self.last_access: float = 0.0
class SlotAllocator:
"""KV cache slot allocator using bitmask for O(1) alloc/free.
Implements the Object Pool pattern: pre-allocated KV cache slots
are managed via a bitmask, providing constant-time allocation and
deallocation with version counters for staleness detection.
"""
def __init__(self, max_slots: int):
self._max_slots = max_slots
self._free_mask = (1 << max_slots) - 1
self._versions: List[int] = [0] * max_slots
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
self.v_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
dtype=dtype,
)
def alloc(self) -> int:
"""Allocates a free slot.
Returns:
Slot index on success, -1 if all slots are occupied.
"""
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._versions[idx] += 1
self._refs[idx] = 1
return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None:
"""Releases a slot back to the free pool."""
self._free_mask |= 1 << idx
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
def occupy(self, idx: int) -> None:
"""Marks a currently free slot as occupied without bumping its version.
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
Used for direct slot reuse when a prefix-cached slot is still valid.
"""
self._free_mask ^= 1 << idx
def write(
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
) -> None:
seq_len = k.size(1)
if seq_len == 0:
return
page_size = self.page_size
written = 0
first_page = start_pos // page_size
last_page = (start_pos + seq_len - 1) // page_size
for pi in range(first_page, last_page + 1):
phys_pages = page_table[:, pi]
page_start = pi * page_size
write_start = max(page_start, start_pos)
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def is_free(self, idx: int) -> bool:
"""Checks whether a slot is currently free."""
return (self._free_mask >> idx) & 1 == 1
def version(self, idx: int) -> int:
"""Returns the current version counter for a slot."""
return self._versions[idx]
@property
def free_count(self) -> int:
"""Returns the number of currently free slots."""
return self._free_mask.bit_count()
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
class PrefixCacheManager:
"""Radix-tree prefix cache with LRU eviction.
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Maps token ID sequences to KV cache slots. Intermediate tree nodes
also store slot information, allowing direct slot reuse when the
cached slot is free and its version matches (no intervening writes).
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
def __init__(self, max_capacity: int = 1000):
"""Initializes the prefix cache.
__slots__ = ("_cache", "_page_table", "_total_len")
Args:
max_capacity: Maximum number of nodes in the LRU list.
"""
self.root = _RadixNode()
self.max_capacity = max_capacity
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
self._page_table = page_table
self._total_len = total_len
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
"""Inserts a token sequence into the prefix cache.
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
Every node along the path records the slot and its version,
enabling direct slot reuse for partial prefix matches.
Args:
token_ids: The token ID sequence to cache.
slot: The KV cache slot containing this prefix's computed keys/values.
slot_ver: The slot version at insertion time, used for staleness detection.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
nxt = _RadixNode()
node.children[tid] = nxt
node = nxt
node.slot = slot
node.slot_ver = slot_ver
node.last_access = time.time()
self._lru[id(node)] = node
node.ref_count += 1
self._evict_if_needed()
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
"""Finds the longest matching prefix in the cache.
Walks the radix tree token by token, recording the deepest match.
Args:
token_ids: The token sequence to match against.
Returns:
Tuple of (prefix_len, slot, slot_ver):
prefix_len: Number of matching tokens (0 if no match).
slot: KV cache slot of the matched prefix (-1 if no match).
slot_ver: Version of that slot when the prefix was inserted.
"""
node = self.root
best_len, best_slot, best_ver = 0, -1, 0
for i, tid in enumerate(token_ids):
nxt = node.children.get(tid)
if nxt is None:
break
node = nxt
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
node.last_access = time.time()
self._lru.move_to_end(id(node))
return best_len, best_slot, best_ver
def pin(self, token_ids: Tuple[int, ...]) -> None:
"""Increments the reference count of a cached prefix.
Called when a task reuses a cached prefix to prevent eviction.
Args:
token_ids: The token sequence whose node's ref_count to increment.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
node.ref_count += 1
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Decrements the reference count of a cached prefix.
The node's slot is preserved even when ref_count reaches zero,
allowing future tasks to reuse the slot directly if it remains free.
Args:
token_ids: The token sequence whose node's ref_count to decrement.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
if node.ref_count > 0:
node.ref_count -= 1
def copy_kv(
self,
token_ids: Tuple[int, ...],
target_slot: int,
kv_cache: Tuple[Tensor, Tensor],
n_layers: int,
) -> None:
"""Copies cached KV data from the source slot to a target slot.
Args:
token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into.
kv_cache: Tuple of (k_cache, v_cache) tensors.
n_layers: Number of transformer layers to copy.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
src_slot = node.slot
if src_slot < 0:
return
prefix_len = len(token_ids)
k_cache, v_cache = kv_cache
for li in range(n_layers):
k_cache[target_slot, :prefix_len, li].copy_(
k_cache[src_slot, :prefix_len, li]
)
v_cache[target_slot, :prefix_len, li].copy_(
v_cache[src_slot, :prefix_len, li]
)
def _evict_if_needed(self) -> None:
"""Evicts least-recently-used nodes until under capacity.
Skips nodes with ref_count > 0 (still in use by active tasks).
Evicted nodes have their slot and children cleared.
"""
while len(self._lru) > self.max_capacity:
key, node = next(iter(self._lru.items()))
if node.ref_count > 0:
self._lru.move_to_end(key)
continue
self._lru.pop(key)
node.slot = -1
node.slot_ver = 0
node.children.clear()
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -16,7 +16,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.inference.cache import _STOP
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer
@ -118,15 +118,15 @@ class _Result:
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel _STOP marks a task as complete.
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or _STOP sentinel.
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._lock:
self.tokens.append(token)
if token is not _STOP:
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
@ -186,38 +186,28 @@ class InferenceEngine:
max_batch_size: int = 1,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
cache_capacity: int = 1000,
page_size: int = 128,
):
"""Initializes the engine and starts the scheduler background thread.
"""Initializes the inference engine.
Args:
model: The language model (nn.Module, e.g. Transformer).
tokenizer: Tokenizer for encoding/decoding.
max_batch_size: Maximum concurrent tasks in the scheduler.
max_seq_len: Maximum sequence length (defaults to model config).
max_prompt_len: Maximum prompt tokens (longer prompts truncated).
cache_capacity: Maximum prefix cache nodes.
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler(
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
cache_capacity=cache_capacity,
device=device,
dtype=dtype,
page_size=page_size,
)
self.scheduler.start()
@ -383,7 +373,7 @@ class InferenceEngine:
while True:
tokens = result.pop_all()
for token in tokens:
if token is _STOP:
if token is STOP:
return
yield token
if not result.wait(timeout=0.05):

View File

@ -9,7 +9,7 @@ parameters, so a single pipeline works for any batch size.
"""
from abc import ABC, abstractmethod
from typing import List, Optional, Union
from typing import List, Union
import torch
from torch import Tensor

View File

@ -1,24 +1,19 @@
"""Inference scheduler for single-GPU continuous batching.
Splits scheduling concerns across modules:
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
- sampling.py: Strategy-pattern logit transformations
"""
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
import logging
import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__)
@ -33,29 +28,7 @@ class TaskStatus(Enum):
class Task:
"""Represents a single generation request within the scheduler.
Tracks prompt tokens, generated output, sampling parameters,
KV cache slot assignment, and prefix cache matching state.
"""
__slots__ = (
"task_id",
"prompt_ids",
"max_tokens",
"temperature",
"top_p",
"top_k",
"status",
"output_ids",
"input_tokens",
"output_tokens",
"slot",
"prefix_len",
"arrival_time",
"finish_time",
"stream_callback",
)
"""Represents a single generation request with paged KV cache tracking."""
def __init__(
self,
@ -67,17 +40,6 @@ class Task:
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
):
"""Initializes a new task.
Args:
task_id: Unique identifier for this task.
prompt_ids: Tokenized prompt sequence.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
stream_callback: Optional callback invoked per decoded token.
"""
self.task_id = task_id
self.prompt_ids = prompt_ids
self.max_tokens = max_tokens
@ -89,26 +51,17 @@ class Task:
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.slot: int = -1
self.prefix_len: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@property
def next_pos(self) -> int:
"""Returns the next KV cache position to write during decode."""
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
"""Checks whether the task has reached a stopping condition.
Args:
stop_ids: List of stop token IDs (e.g., EOS).
Returns:
True if max_tokens reached or the last output token is a stop ID.
"""
if self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
@ -117,16 +70,13 @@ class Task:
class InferenceScheduler:
"""Continuous batching scheduler for single-GPU inference.
"""Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks (full, partial, or fully cached).
3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks.
Tasks at different positions are never batched together in decode,
avoiding RoPE corruption from misaligned KV cache writes.
"""
def __init__(
@ -136,22 +86,10 @@ class InferenceScheduler:
max_batch_size: int = 16,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
cache_capacity: int = 1000,
page_size: int = 64,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
):
"""Initializes the scheduler and pre-allocates the KV cache.
Args:
model: The language model (must have config with n_layers, n_kv_heads, etc.).
tokenizer: Tokenizer for encoding prompts and decoding outputs.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length (defaults to config.max_len).
max_prompt_len: Maximum prompt tokens (longer prompts are truncated).
cache_capacity: Maximum prefix cache node count.
device: Target device for tensors.
dtype: Data type for KV cache tensors.
"""
config = model.config
self.model = model
@ -159,35 +97,25 @@ class InferenceScheduler:
self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
n_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads
n_layers = config.n_layers
self._n_layers = n_layers
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
k_cache = torch.empty(
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim),
device=self.device,
dtype=self.dtype,
)
v_cache = torch.empty(
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim),
device=self.device,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.zeros(
(max_batch_size, self.max_seq_len),
device=self.device,
dtype=torch.bool,
self.page_cache = PagedCache(
n_layers,
n_pages,
page_size,
n_kv_heads,
head_dim,
self.device,
self.dtype,
)
self.slot_allocator = SlotAllocator(max_batch_size)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
@ -198,41 +126,8 @@ class InferenceScheduler:
self._total_tasks = 0
self._total_tokens = 0
def _alloc_slot(self) -> int:
"""Allocates a free KV cache slot using the Object Pool.
Returns:
Slot index on success, -1 if all slots are occupied.
"""
return self.slot_allocator.alloc()
def _free_slot(self, idx: int) -> None:
"""Releases a KV cache slot back to the free pool.
Args:
idx: Slot index to free.
"""
self.slot_allocator.free(idx)
self.seq_mask[idx, :] = False
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
"""Attempts to reuse a prefix-cached slot directly without KV copy.
The slot is reusable only if it is free and its version matches
the current slot version (no intervening allocation overwrote it).
Args:
prefix: The matched prefix token sequence.
Returns:
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
"""
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
if cached_ver == self.slot_allocator.version(cached_slot):
self.slot_allocator.occupy(cached_slot)
return cached_slot, True
return -1, False
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def add_task(
self,
@ -243,25 +138,8 @@ class InferenceScheduler:
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
) -> str:
"""Adds a generation task to the waiting queue.
Encodes the prompt, queries the prefix cache for a match,
and enqueues the task for the background generation loop.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
stream_callback: Called per decoded token with the string representation.
Returns:
Unique task ID string.
"""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :]
@ -275,9 +153,6 @@ class InferenceScheduler:
stream_callback=stream_callback,
)
prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(prompt_ids)
task.prefix_len = prefix_len
with self._lock:
self.waiting_queue.append(task)
self._total_tasks += 1
@ -286,32 +161,21 @@ class InferenceScheduler:
return task_id
def remove_task(self, task_id: str) -> None:
"""Removes a task from both the waiting queue and active tasks.
Args:
task_id: The task to remove.
"""
with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active:
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
self.prefix_cache.release(prefix)
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids))
if task.slot >= 0:
self._free_slot(task.slot)
task.slot = -1
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
def _free_pages(self, indices: List[int]) -> None:
for idx in indices:
self.page_cache.free(idx)
def _remove_finished_tasks(self) -> None:
"""Removes all finished tasks from the active batch.
Releases prefix cache references and frees the KV cache slot
for each completed task.
"""
finished = []
for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids):
@ -321,25 +185,15 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens
for task in finished:
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
self.prefix_cache.release(prefix)
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids))
if task.slot >= 0:
self._free_slot(task.slot)
task.slot = -1
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
]
def _refill_active_batch(self) -> None:
"""Moves waiting tasks into the active batch, up to max_batch_size.
Attempts direct slot reuse for prefix-matched tasks; falls back
to allocating a fresh slot with KV cache copy when reuse is not possible.
"""
available = self.max_batch_size - len(self.active_tasks)
if available <= 0:
return
@ -350,122 +204,71 @@ class InferenceScheduler:
for _ in range(n):
to_add.append(self.waiting_queue.pop(0))
for i, task in enumerate(to_add):
slot = -1
reused = False
if task.prefix_len > 0:
prefix = tuple(task.prompt_ids[: task.prefix_len])
cached_slot, reused = self._try_reuse_slot(prefix)
if reused:
slot = cached_slot
if slot < 0:
slot = self._alloc_slot()
if slot < 0:
with self._lock:
self.waiting_queue[:0] = to_add[i:]
break
task.slot = slot
for task in to_add:
prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table:
with self._lock:
self.waiting_queue.insert(0, task)
break
task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
if task.prefix_len > 0 and not reused:
prefix = tuple(task.prompt_ids[: task.prefix_len])
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
cached_slot
):
self.prefix_cache.pin(prefix)
self.prefix_cache.copy_kv(
prefix, slot, self.kv_cache, self._n_layers
)
else:
task.prefix_len = 0
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Runs batched prefill for newly activated tasks.
Fully-cached tasks skip the model. Others are grouped by prefix_len
so tasks sharing the same start_pos are batched together.
"""
if not tasks:
def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill:
return
for t in to_prefill:
prompt_len = len(t.prompt_ids)
t.input_tokens = prompt_len
t.output_tokens = 0
groups: Dict[int, List[Task]] = {}
for t in tasks:
plen = len(t.prompt_ids)
if t.prefix_len == plen:
t.input_tokens = plen
t.output_tokens = 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
else:
groups.setdefault(t.prefix_len, []).append(t)
for t in to_prefill:
groups.setdefault(len(t.prompt_ids), []).append(t)
for prefix_len, group in groups.items():
slot_indices = torch.tensor([t.slot for t in group], device=self.device)
self._execute_prefill_batch(group, prefix_len, slot_indices)
for prompt_len, group in groups.items():
self._execute_prefill_batch(group, prompt_len)
def _execute_prefill_batch(
self, tasks: List[Task], prefix_len: int, slot_indices: Tensor
) -> None:
"""Unified prefill for tasks sharing a common prefix_len.
Args:
tasks: Tasks with the same prefix_len < len(prompt_ids).
prefix_len: Number of cached prefix tokens (0 for full prefill).
slot_indices: Tensor of slot indices for KV cache mapping.
"""
tasks = sorted(tasks, key=lambda t: t.slot)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
max_new_len = max(new_lens)
input_ids = torch.zeros(
batch_sz, max_new_len, dtype=torch.long, device=self.device
batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
)
input_mask = torch.zeros(
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
)
for i, t in enumerate(tasks):
new_ids = t.prompt_ids[prefix_len:]
nl = len(new_ids)
if nl > 0:
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
input_mask[i, : prefix_len + nl] = True
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
with torch.inference_mode():
self.model(
input_ids,
input_mask=input_mask,
start_pos=prefix_len,
persistent_key_values=self.kv_cache,
slot_indices=slot_indices,
start_pos=0,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
for i, t in enumerate(tasks):
t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
)
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Executes the decode phase for a group of tasks at the same position.
Args:
tasks: Tasks sharing the same next_pos value.
start_pos: Common KV cache write position for the batch.
"""
if not tasks:
return
tasks = sorted(tasks, key=lambda t: t.slot)
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks):
@ -473,13 +276,15 @@ class InferenceScheduler:
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
with torch.inference_mode():
outputs = self.model(
input_ids.unsqueeze(1),
input_mask=active_mask,
persistent_key_values=self.kv_cache,
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos,
slot_indices=slot_indices,
)
logits = outputs["logits"][:, -1, :]
@ -496,23 +301,30 @@ class InferenceScheduler:
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
if t.slot >= 0 and pos < self.max_seq_len:
self.seq_mask[t.slot, pos] = True
self._maybe_alloc_page(t, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(_STOP)
t.stream_callback(STOP)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
break
task.page_table.append(p)
task.n_pages += 1
def _run_generation_loop(self) -> None:
"""Main generation loop run in a daemon thread.
Continuously cycles through cleanup, refill, prefill, and decode.
Decode processes only the largest position group to ensure all
batched tasks share the same KV cache write position.
"""
try:
while self._running:
self._remove_finished_tasks()
@ -523,11 +335,8 @@ class InferenceScheduler:
self._task_event.clear()
self._task_event.wait(timeout=0.01)
continue
tasks = self.active_tasks[:]
to_prefill = [t for t in tasks if t.output_tokens == 0]
if to_prefill:
self._execute_prefill(to_prefill)
self._execute_prefill()
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:
@ -544,21 +353,20 @@ class InferenceScheduler:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks:
if task.stream_callback:
task.stream_callback(_STOP)
task.stream_callback(STOP)
for task in self.waiting_queue:
if task.stream_callback:
task.stream_callback(_STOP)
task.stream_callback(STOP)
raise
def start(self) -> None:
"""Starts the background generation loop thread."""
if not self._running:
self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start()
self._loop_thread = t
def stop(self) -> None:
"""Stops the generation loop and releases all resources."""
self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"):
@ -569,11 +377,6 @@ class InferenceScheduler:
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
"""Returns current scheduler statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return {
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,

View File

@ -5,17 +5,11 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
"""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
"""Repeat KV heads n_rep times for GQA."""
bs, slen, n_heads, head_dim = x.shape
if n_rep == 1:
return x
@ -32,49 +26,25 @@ def get_rotary_emb(
base: float = 10000,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
"""Precompute cos/sin for RoPE."""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
"""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
"""Apply rotary embedding via cos/sin (shape-preserving)."""
dtype = x.dtype
cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
cos = cos.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2)
x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = x_out.view(*x_out.shape[:-2], -1)
return x_out.to(dtype)
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin)
@ -185,14 +152,13 @@ class GQA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -201,18 +167,14 @@ class GQA(nn.Module):
if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3)
@ -224,7 +186,6 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out)
return out
@ -257,7 +218,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v)
# fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear(
kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -273,9 +234,8 @@ class MLA(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = mask is None
@ -303,12 +263,9 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1)
if kv_cache is not None:
k_cache, v_cache = kv_cache
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
if paged_cache is not None:
paged_cache.write(self.layer_id, start_pos, k, v)
k, v = paged_cache.gather(self.layer_id)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
@ -321,7 +278,6 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out)
return out
@ -356,24 +312,19 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor:
# attention
attn_output = self.attention(
self.input_norm(x),
rotary_emb,
attention_mask,
kv_cache,
paged_cache,
start_pos,
slot_indices,
)
x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x
return x

View File

@ -1,10 +1,11 @@
from typing import Any, Mapping, Optional, Tuple
from typing import Any, Mapping, Optional
import torch
import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
@ -21,39 +22,25 @@ def process_attention_mask(
start_pos: int = 0,
is_causal: bool = False,
) -> Tensor:
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
device = input_tensor.device
dtype = input_tensor.dtype
seq_len = input_tensor.size(1)
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len
)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -62,16 +49,13 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@AutoModel.register("transformer")
class Transformer(AutoModel):
"""
Transformer language model.
"""
"""Transformer language model with paged KV cache."""
def __init__(self, config: ModelConfig):
super().__init__(config)
@ -114,18 +98,15 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict)
if self.config.tie_weight:
# same tensor
# same tensor for embed and lm_head
if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key]
else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor
# clone to avoid sharing gradients
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign)
@ -146,9 +127,8 @@ class Transformer(AutoModel):
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
paged_cache: Optional[CacheView] = None,
start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2
@ -157,13 +137,8 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
if slot_indices is None:
slot_indices = slice(input_ids.size(0))
for layer in self.layers:
x = layer(
x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices
)
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
hidden_states = self.norm(x)
logits = self.lm_head(hidden_states)

View File

@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.cache import PrefixCacheManager
from astrai.inference.scheduler import (
InferenceScheduler,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10, slot_ver=0)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i, slot_ver=0)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id, slot_ver=0)
# Find after insert
cache.find(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
from astrai.inference.scheduler import InferenceScheduler
@pytest.fixture
@ -266,54 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=0, slot_ver=0)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
cache.insert((1, 2, 3), slot=0, slot_ver=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"