From 30cc2d67a417fd9a2e022b40845280bded9e22f6 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 8 May 2026 20:44:05 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=88=86=E9=A1=B5=20KV=20cache=20?= =?UTF-8?q?=E6=9B=BF=E6=8D=A2=E5=9B=BA=E5=AE=9A=20slot=EF=BC=8C=E5=88=A0?= =?UTF-8?q?=E9=99=A4=20PrefixCache=20=E5=8F=8A=E7=9B=B8=E5=85=B3=E6=AD=BB?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引 - 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑 - 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配 - 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误 - _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值 - 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView) - 精简 docstring 为单行,删除冗余 section 注释和旧代码 - 修复 test_scheduler_concurrency.py 缺少 import pytest --- astrai/inference/cache.py | 310 +++++--------- astrai/inference/engine.py | 40 +- astrai/inference/sampling.py | 2 +- astrai/inference/scheduler.py | 385 +++++------------- astrai/model/module.py | 95 ++--- astrai/model/transformer.py | 41 +- tests/inference/test_scheduler_concurrency.py | 148 +------ 7 files changed, 244 insertions(+), 777 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 28c96cd..08812c4 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -1,241 +1,135 @@ -"""KV cache slot allocation and prefix cache management. +"""Page-based KV cache with page-table-indirected read/write. Provides: - - SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask. - - PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse. + - PagedCache: paged KV cache combining page pool and tensor storage. """ -import time -from collections import OrderedDict -from typing import Dict, List, Tuple +from typing import List, Tuple +import torch from torch import Tensor -_STOP = object() +STOP = object() -class _RadixNode: - """Internal node for the radix tree prefix cache. +class PagedCache: + """Paged KV cache with page-table-indirected read/write. - Attributes: - children: Mapping from token ID to child node. - slot: KV cache slot index for the prefix ending at this node. - slot_ver: Version counter of the slot at insertion time. - ref_count: Number of tasks currently referencing this node. - last_access: Timestamp of the most recent access (for LRU ordering). + Combines: + - Page pool (ref-counted alloc/free via bitmask) + - KV tensor storage (k_cache, v_cache) + + Call :meth:`bind` to obtain a batch view for the attention layers. """ - __slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access") - - def __init__(self): - self.children: Dict[int, "_RadixNode"] = {} - self.slot: int = -1 - self.slot_ver: int = 0 - self.ref_count: int = 0 - self.last_access: float = 0.0 - - -class SlotAllocator: - """KV cache slot allocator using bitmask for O(1) alloc/free. - - Implements the Object Pool pattern: pre-allocated KV cache slots - are managed via a bitmask, providing constant-time allocation and - deallocation with version counters for staleness detection. - """ - - def __init__(self, max_slots: int): - self._max_slots = max_slots - self._free_mask = (1 << max_slots) - 1 - self._versions: List[int] = [0] * max_slots + def __init__( + self, + n_layers: int, + n_pages: int, + page_size: int, + n_kv_heads: int, + head_dim: int, + device: torch.device, + dtype: torch.dtype, + ): + self.page_size = page_size + self._free_mask = (1 << n_pages) - 1 + self._refs: List[int] = [0] * n_pages + self.k_cache = torch.empty( + (n_layers, n_pages, page_size, n_kv_heads, head_dim), + device=device, + dtype=dtype, + ) + self.v_cache = torch.empty( + (n_layers, n_pages, page_size, n_kv_heads, head_dim), + device=device, + dtype=dtype, + ) def alloc(self) -> int: - """Allocates a free slot. - - Returns: - Slot index on success, -1 if all slots are occupied. - """ lsb = self._free_mask & -self._free_mask if lsb == 0: return -1 idx = lsb.bit_length() - 1 self._free_mask ^= lsb - self._versions[idx] += 1 + self._refs[idx] = 1 return idx + def alloc_n(self, n: int) -> List[int]: + pages = [self.alloc() for _ in range(n)] + if any(p < 0 for p in pages): + for p in pages: + if p >= 0: + self.free(p) + return [] + return pages + def free(self, idx: int) -> None: - """Releases a slot back to the free pool.""" - self._free_mask |= 1 << idx + self._refs[idx] -= 1 + if self._refs[idx] == 0: + self._free_mask |= 1 << idx - def occupy(self, idx: int) -> None: - """Marks a currently free slot as occupied without bumping its version. + def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": + return CacheView(self, page_table, total_len) - Used for direct slot reuse when a prefix-cached slot is still valid. - """ - self._free_mask ^= 1 << idx + def write( + self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor + ) -> None: + seq_len = k.size(1) + if seq_len == 0: + return + page_size = self.page_size + written = 0 + first_page = start_pos // page_size + last_page = (start_pos + seq_len - 1) // page_size + for pi in range(first_page, last_page + 1): + phys_pages = page_table[:, pi] + page_start = pi * page_size + write_start = max(page_start, start_pos) + write_end = min(page_start + page_size, start_pos + seq_len) + offset = write_start - page_start + chunk = write_end - write_start + self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[ + :, written : written + chunk + ] + self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[ + :, written : written + chunk + ] + written += chunk - def is_free(self, idx: int) -> bool: - """Checks whether a slot is currently free.""" - return (self._free_mask >> idx) & 1 == 1 - - def version(self, idx: int) -> int: - """Returns the current version counter for a slot.""" - return self._versions[idx] - - @property - def free_count(self) -> int: - """Returns the number of currently free slots.""" - return self._free_mask.bit_count() + def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: + k_parts, v_parts = [], [] + for pi in range(page_table.size(1)): + phys_pages = page_table[:, pi] + if not (phys_pages >= 0).any(): + break + k_parts.append(self.k_cache[layer_id, phys_pages]) + v_parts.append(self.v_cache[layer_id, phys_pages]) + k = torch.cat(k_parts, dim=1) + v = torch.cat(v_parts, dim=1) + return k, v -class PrefixCacheManager: - """Radix-tree prefix cache with LRU eviction. +class CacheView: + """Per-batch view that bundles PagedCache + page_table + total_len. - Maps token ID sequences to KV cache slots. Intermediate tree nodes - also store slot information, allowing direct slot reuse when the - cached slot is free and its version matches (no intervening writes). + Attention layers receive this as ``paged_cache`` and only see + ``write()`` / ``gather()``, never raw page tables or length params. """ - def __init__(self, max_capacity: int = 1000): - """Initializes the prefix cache. + __slots__ = ("_cache", "_page_table", "_total_len") - Args: - max_capacity: Maximum number of nodes in the LRU list. - """ - self.root = _RadixNode() - self.max_capacity = max_capacity - self._lru: OrderedDict[int, _RadixNode] = OrderedDict() + def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0): + self._cache = cache + self._page_table = page_table + self._total_len = total_len - def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None: - """Inserts a token sequence into the prefix cache. + def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None: + self._cache.write(layer_id, self._page_table, start_pos, k, v) - Every node along the path records the slot and its version, - enabling direct slot reuse for partial prefix matches. - - Args: - token_ids: The token ID sequence to cache. - slot: The KV cache slot containing this prefix's computed keys/values. - slot_ver: The slot version at insertion time, used for staleness detection. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - nxt = _RadixNode() - node.children[tid] = nxt - node = nxt - node.slot = slot - node.slot_ver = slot_ver - node.last_access = time.time() - self._lru[id(node)] = node - node.ref_count += 1 - self._evict_if_needed() - - def find(self, token_ids: List[int]) -> Tuple[int, int, int]: - """Finds the longest matching prefix in the cache. - - Walks the radix tree token by token, recording the deepest match. - - Args: - token_ids: The token sequence to match against. - - Returns: - Tuple of (prefix_len, slot, slot_ver): - prefix_len: Number of matching tokens (0 if no match). - slot: KV cache slot of the matched prefix (-1 if no match). - slot_ver: Version of that slot when the prefix was inserted. - """ - node = self.root - best_len, best_slot, best_ver = 0, -1, 0 - for i, tid in enumerate(token_ids): - nxt = node.children.get(tid) - if nxt is None: - break - node = nxt - best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver - node.last_access = time.time() - self._lru.move_to_end(id(node)) - return best_len, best_slot, best_ver - - def pin(self, token_ids: Tuple[int, ...]) -> None: - """Increments the reference count of a cached prefix. - - Called when a task reuses a cached prefix to prevent eviction. - - Args: - token_ids: The token sequence whose node's ref_count to increment. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - node.ref_count += 1 - - def release(self, token_ids: Tuple[int, ...]) -> None: - """Decrements the reference count of a cached prefix. - - The node's slot is preserved even when ref_count reaches zero, - allowing future tasks to reuse the slot directly if it remains free. - - Args: - token_ids: The token sequence whose node's ref_count to decrement. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - if node.ref_count > 0: - node.ref_count -= 1 - - def copy_kv( - self, - token_ids: Tuple[int, ...], - target_slot: int, - kv_cache: Tuple[Tensor, Tensor], - n_layers: int, - ) -> None: - """Copies cached KV data from the source slot to a target slot. - - Args: - token_ids: The prefix token sequence identifying the source cache node. - target_slot: The destination KV cache slot to copy into. - kv_cache: Tuple of (k_cache, v_cache) tensors. - n_layers: Number of transformer layers to copy. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - src_slot = node.slot - if src_slot < 0: - return - prefix_len = len(token_ids) - k_cache, v_cache = kv_cache - for li in range(n_layers): - k_cache[target_slot, :prefix_len, li].copy_( - k_cache[src_slot, :prefix_len, li] - ) - v_cache[target_slot, :prefix_len, li].copy_( - v_cache[src_slot, :prefix_len, li] - ) - - def _evict_if_needed(self) -> None: - """Evicts least-recently-used nodes until under capacity. - - Skips nodes with ref_count > 0 (still in use by active tasks). - Evicted nodes have their slot and children cleared. - """ - while len(self._lru) > self.max_capacity: - key, node = next(iter(self._lru.items())) - if node.ref_count > 0: - self._lru.move_to_end(key) - continue - self._lru.pop(key) - node.slot = -1 - node.slot_ver = 0 - node.children.clear() + def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: + k, v = self._cache.gather(layer_id, self._page_table) + if self._total_len: + k = k[:, : self._total_len] + v = v[:, : self._total_len] + return k, v diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index d3a50ce..440c958 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -16,7 +16,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union import torch import torch.nn as nn -from astrai.inference.cache import _STOP +from astrai.inference.cache import STOP from astrai.inference.scheduler import InferenceScheduler from astrai.tokenize import AutoTokenizer @@ -118,15 +118,15 @@ class _Result: """Appends a token to the result buffer. In non-streaming mode, tokens are concatenated into results[idx]. - The sentinel _STOP marks a task as complete. + The sentinel STOP marks a task as complete. Args: - token: The decoded token string, or _STOP sentinel. + token: The decoded token string, or STOP sentinel. idx: Index of the generation task this token belongs to. """ with self._lock: self.tokens.append(token) - if token is not _STOP: + if token is not STOP: self.results[idx] += token else: if not self._done[idx]: @@ -186,38 +186,28 @@ class InferenceEngine: max_batch_size: int = 1, max_seq_len: Optional[int] = None, max_prompt_len: int = 2048, - cache_capacity: int = 1000, + page_size: int = 128, ): - """Initializes the engine and starts the scheduler background thread. + """Initializes the inference engine. Args: - model: The language model (nn.Module, e.g. Transformer). - tokenizer: Tokenizer for encoding/decoding. - max_batch_size: Maximum concurrent tasks in the scheduler. - max_seq_len: Maximum sequence length (defaults to model config). - max_prompt_len: Maximum prompt tokens (longer prompts truncated). - cache_capacity: Maximum prefix cache nodes. + model: The model instance. + tokenizer: The tokenizer instance. + max_batch_size: Maximum number of concurrent tasks. + max_seq_len: Maximum sequence length. + max_prompt_len: Maximum prompt tokens. + compile: Whether to compile the model with torch.compile. + page_size: Number of tokens per KV cache page. """ - try: - first_param = next(model.parameters()) - device = first_param.device - dtype = first_param.dtype - except StopIteration: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - dtype = torch.float32 - self.model = model self.tokenizer = tokenizer - self.scheduler = InferenceScheduler( model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_seq_len=max_seq_len, max_prompt_len=max_prompt_len, - cache_capacity=cache_capacity, - device=device, - dtype=dtype, + page_size=page_size, ) self.scheduler.start() @@ -383,7 +373,7 @@ class InferenceEngine: while True: tokens = result.pop_all() for token in tokens: - if token is _STOP: + if token is STOP: return yield token if not result.wait(timeout=0.05): diff --git a/astrai/inference/sampling.py b/astrai/inference/sampling.py index cb5315c..300b5b3 100644 --- a/astrai/inference/sampling.py +++ b/astrai/inference/sampling.py @@ -9,7 +9,7 @@ parameters, so a single pipeline works for any batch size. """ from abc import ABC, abstractmethod -from typing import List, Optional, Union +from typing import List, Union import torch from torch import Tensor diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 2699886..2d09590 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -1,24 +1,19 @@ -"""Inference scheduler for single-GPU continuous batching. - -Splits scheduling concerns across modules: - - cache.py: SlotAllocator (Object Pool), PrefixCacheManager - - sampling.py: Strategy-pattern logit transformations -""" +"""Inference scheduler for single-GPU continuous batching with paged KV cache.""" import logging import threading import time import uuid from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional import torch from torch import Tensor -from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator +from astrai.inference.cache import STOP, PagedCache from astrai.inference.sampling import sample from astrai.model.automodel import AutoModel -from astrai.tokenize import AutoTokenizer +from astrai.tokenize.tokenizer import AutoTokenizer logger = logging.getLogger(__name__) @@ -33,29 +28,7 @@ class TaskStatus(Enum): class Task: - """Represents a single generation request within the scheduler. - - Tracks prompt tokens, generated output, sampling parameters, - KV cache slot assignment, and prefix cache matching state. - """ - - __slots__ = ( - "task_id", - "prompt_ids", - "max_tokens", - "temperature", - "top_p", - "top_k", - "status", - "output_ids", - "input_tokens", - "output_tokens", - "slot", - "prefix_len", - "arrival_time", - "finish_time", - "stream_callback", - ) + """Represents a single generation request with paged KV cache tracking.""" def __init__( self, @@ -67,17 +40,6 @@ class Task: top_k: int = 50, stream_callback: Optional[Callable[[str], None]] = None, ): - """Initializes a new task. - - Args: - task_id: Unique identifier for this task. - prompt_ids: Tokenized prompt sequence. - max_tokens: Maximum number of tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling probability threshold. - top_k: Top-k sampling count (0 disables). - stream_callback: Optional callback invoked per decoded token. - """ self.task_id = task_id self.prompt_ids = prompt_ids self.max_tokens = max_tokens @@ -89,26 +51,17 @@ class Task: self.output_ids: List[int] = [] self.input_tokens: int = 0 self.output_tokens: int = 0 - self.slot: int = -1 - self.prefix_len: int = 0 + self.page_table: List[int] = [] + self.n_pages: int = 0 self.arrival_time = time.time() self.finish_time: Optional[float] = None self.stream_callback = stream_callback @property def next_pos(self) -> int: - """Returns the next KV cache position to write during decode.""" return self.input_tokens + len(self.output_ids) def is_finished(self, stop_ids: List[int]) -> bool: - """Checks whether the task has reached a stopping condition. - - Args: - stop_ids: List of stop token IDs (e.g., EOS). - - Returns: - True if max_tokens reached or the last output token is a stop ID. - """ if self.output_tokens >= self.max_tokens: return True if self.output_ids and self.output_ids[-1] in stop_ids: @@ -117,16 +70,13 @@ class Task: class InferenceScheduler: - """Continuous batching scheduler for single-GPU inference. + """Continuous batching scheduler with paged KV cache. Runs a background generation loop with four phases per iteration: 1. Cleanup finished tasks and release resources. 2. Refill active batch from the waiting queue. - 3. Prefill newly activated tasks (full, partial, or fully cached). + 3. Prefill newly activated tasks. 4. Decode the largest same-position group of active tasks. - - Tasks at different positions are never batched together in decode, - avoiding RoPE corruption from misaligned KV cache writes. """ def __init__( @@ -136,22 +86,10 @@ class InferenceScheduler: max_batch_size: int = 16, max_seq_len: Optional[int] = None, max_prompt_len: int = 512, - cache_capacity: int = 1000, + page_size: int = 64, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): - """Initializes the scheduler and pre-allocates the KV cache. - - Args: - model: The language model (must have config with n_layers, n_kv_heads, etc.). - tokenizer: Tokenizer for encoding prompts and decoding outputs. - max_batch_size: Maximum number of concurrent tasks. - max_seq_len: Maximum sequence length (defaults to config.max_len). - max_prompt_len: Maximum prompt tokens (longer prompts are truncated). - cache_capacity: Maximum prefix cache node count. - device: Target device for tensors. - dtype: Data type for KV cache tensors. - """ config = model.config self.model = model @@ -159,35 +97,25 @@ class InferenceScheduler: self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len or config.max_len self.max_prompt_len = max_prompt_len + self.page_size = page_size self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype - self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity) - n_kv_heads = config.n_kv_heads head_dim = config.dim // config.n_heads n_layers = config.n_layers - self._n_layers = n_layers + n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size - k_cache = torch.empty( - (max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), - device=self.device, - dtype=self.dtype, - ) - v_cache = torch.empty( - (max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), - device=self.device, - dtype=self.dtype, - ) - self.kv_cache = (k_cache, v_cache) - - self.seq_mask = torch.zeros( - (max_batch_size, self.max_seq_len), - device=self.device, - dtype=torch.bool, + self.page_cache = PagedCache( + n_layers, + n_pages, + page_size, + n_kv_heads, + head_dim, + self.device, + self.dtype, ) - self.slot_allocator = SlotAllocator(max_batch_size) self.waiting_queue: List[Task] = [] self.active_tasks: List[Task] = [] @@ -198,41 +126,8 @@ class InferenceScheduler: self._total_tasks = 0 self._total_tokens = 0 - def _alloc_slot(self) -> int: - """Allocates a free KV cache slot using the Object Pool. - - Returns: - Slot index on success, -1 if all slots are occupied. - """ - return self.slot_allocator.alloc() - - def _free_slot(self, idx: int) -> None: - """Releases a KV cache slot back to the free pool. - - Args: - idx: Slot index to free. - """ - self.slot_allocator.free(idx) - self.seq_mask[idx, :] = False - - def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]: - """Attempts to reuse a prefix-cached slot directly without KV copy. - - The slot is reusable only if it is free and its version matches - the current slot version (no intervening allocation overwrote it). - - Args: - prefix: The matched prefix token sequence. - - Returns: - Tuple of (slot, True) on success, or (-1, False) if reuse is not possible. - """ - _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) - if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot): - if cached_ver == self.slot_allocator.version(cached_slot): - self.slot_allocator.occupy(cached_slot) - return cached_slot, True - return -1, False + def _n_pages_for(self, n_tokens: int) -> int: + return (n_tokens + self.page_size - 1) // self.page_size def add_task( self, @@ -243,25 +138,8 @@ class InferenceScheduler: top_k: int = 50, stream_callback: Optional[Callable[[str], None]] = None, ) -> str: - """Adds a generation task to the waiting queue. - - Encodes the prompt, queries the prefix cache for a match, - and enqueues the task for the background generation loop. - - Args: - prompt: Input text to generate from. - max_tokens: Maximum tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling threshold. - top_k: Top-k sampling count. - stream_callback: Called per decoded token with the string representation. - - Returns: - Unique task ID string. - """ task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" prompt_ids = self.tokenizer.encode(prompt) - if len(prompt_ids) > self.max_prompt_len: prompt_ids = prompt_ids[-self.max_prompt_len :] @@ -275,9 +153,6 @@ class InferenceScheduler: stream_callback=stream_callback, ) - prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(prompt_ids) - task.prefix_len = prefix_len - with self._lock: self.waiting_queue.append(task) self._total_tasks += 1 @@ -286,32 +161,21 @@ class InferenceScheduler: return task_id def remove_task(self, task_id: str) -> None: - """Removes a task from both the waiting queue and active tasks. - - Args: - task_id: The task to remove. - """ with self._lock: removed_active = [t for t in self.active_tasks if t.task_id == task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] for task in removed_active: - if task.prefix_len > 0: - prefix = tuple(task.prompt_ids[: task.prefix_len]) - self.prefix_cache.release(prefix) - if task.prefix_len < len(task.prompt_ids): - self.prefix_cache.release(tuple(task.prompt_ids)) - if task.slot >= 0: - self._free_slot(task.slot) - task.slot = -1 + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + + def _free_pages(self, indices: List[int]) -> None: + for idx in indices: + self.page_cache.free(idx) def _remove_finished_tasks(self) -> None: - """Removes all finished tasks from the active batch. - - Releases prefix cache references and frees the KV cache slot - for each completed task. - """ finished = [] for task in self.active_tasks: if task.is_finished(self.tokenizer.stop_ids): @@ -321,25 +185,15 @@ class InferenceScheduler: self._total_tokens += task.output_tokens for task in finished: - if task.prefix_len > 0: - prefix = tuple(task.prompt_ids[: task.prefix_len]) - self.prefix_cache.release(prefix) - if task.prefix_len < len(task.prompt_ids): - self.prefix_cache.release(tuple(task.prompt_ids)) - if task.slot >= 0: - self._free_slot(task.slot) - task.slot = -1 + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 self.active_tasks = [ t for t in self.active_tasks if t.status != TaskStatus.FINISHED ] def _refill_active_batch(self) -> None: - """Moves waiting tasks into the active batch, up to max_batch_size. - - Attempts direct slot reuse for prefix-matched tasks; falls back - to allocating a fresh slot with KV cache copy when reuse is not possible. - """ available = self.max_batch_size - len(self.active_tasks) if available <= 0: return @@ -350,122 +204,71 @@ class InferenceScheduler: for _ in range(n): to_add.append(self.waiting_queue.pop(0)) - for i, task in enumerate(to_add): - slot = -1 - reused = False - if task.prefix_len > 0: - prefix = tuple(task.prompt_ids[: task.prefix_len]) - cached_slot, reused = self._try_reuse_slot(prefix) - if reused: - slot = cached_slot - if slot < 0: - slot = self._alloc_slot() - if slot < 0: - with self._lock: - self.waiting_queue[:0] = to_add[i:] - break - task.slot = slot + for task in to_add: + prompt_len = len(task.prompt_ids) + n_pages = self._n_pages_for(prompt_len) + task.page_table = self.page_cache.alloc_n(n_pages) + if not task.page_table: + with self._lock: + self.waiting_queue.insert(0, task) + break + task.n_pages = len(task.page_table) task.status = TaskStatus.RUNNING self.active_tasks.append(task) - if task.prefix_len > 0 and not reused: - prefix = tuple(task.prompt_ids[: task.prefix_len]) - _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) - if cached_slot >= 0 and cached_ver == self.slot_allocator.version( - cached_slot - ): - self.prefix_cache.pin(prefix) - self.prefix_cache.copy_kv( - prefix, slot, self.kv_cache, self._n_layers - ) - else: - task.prefix_len = 0 - - def _execute_prefill(self, tasks: List[Task]) -> None: - """Runs batched prefill for newly activated tasks. - - Fully-cached tasks skip the model. Others are grouped by prefix_len - so tasks sharing the same start_pos are batched together. - """ - if not tasks: + def _execute_prefill(self) -> None: + to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] + if not to_prefill: return + for t in to_prefill: + prompt_len = len(t.prompt_ids) + t.input_tokens = prompt_len + t.output_tokens = 0 + groups: Dict[int, List[Task]] = {} - for t in tasks: - plen = len(t.prompt_ids) - if t.prefix_len == plen: - t.input_tokens = plen - t.output_tokens = 0 - if t.slot >= 0: - self.seq_mask[t.slot, : t.input_tokens] = True - else: - groups.setdefault(t.prefix_len, []).append(t) + for t in to_prefill: + groups.setdefault(len(t.prompt_ids), []).append(t) - for prefix_len, group in groups.items(): - slot_indices = torch.tensor([t.slot for t in group], device=self.device) - self._execute_prefill_batch(group, prefix_len, slot_indices) + for prompt_len, group in groups.items(): + self._execute_prefill_batch(group, prompt_len) - def _execute_prefill_batch( - self, tasks: List[Task], prefix_len: int, slot_indices: Tensor - ) -> None: - """Unified prefill for tasks sharing a common prefix_len. - - Args: - tasks: Tasks with the same prefix_len < len(prompt_ids). - prefix_len: Number of cached prefix tokens (0 for full prefill). - slot_indices: Tensor of slot indices for KV cache mapping. - """ - tasks = sorted(tasks, key=lambda t: t.slot) + def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None: + tasks = sorted(tasks, key=lambda t: t.task_id) batch_sz = len(tasks) - new_lens = [len(t.prompt_ids) - prefix_len for t in tasks] - max_new_len = max(new_lens) - input_ids = torch.zeros( - batch_sz, max_new_len, dtype=torch.long, device=self.device + batch_sz, + prompt_len, + dtype=torch.long, + device=self.device, ) - input_mask = torch.zeros( - batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device + input_mask = torch.ones( + batch_sz, + prompt_len, + dtype=torch.bool, + device=self.device, ) for i, t in enumerate(tasks): - new_ids = t.prompt_ids[prefix_len:] - nl = len(new_ids) - if nl > 0: - input_ids[i, :nl] = torch.tensor(new_ids, device=self.device) - input_mask[i, : prefix_len + nl] = True + input_ids[i] = torch.tensor(t.prompt_ids, device=self.device) + + page_tables = self._make_page_table_tensor(tasks) with torch.inference_mode(): self.model( input_ids, input_mask=input_mask, - start_pos=prefix_len, - persistent_key_values=self.kv_cache, - slot_indices=slot_indices, + start_pos=0, + paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) - for i, t in enumerate(tasks): - t.input_tokens = len(t.prompt_ids) - t.output_tokens = 0 - self.prefix_cache.insert( - tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot) - ) - if t.slot >= 0: - self.seq_mask[t.slot, : t.input_tokens] = True - def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: - """Executes the decode phase for a group of tasks at the same position. - - Args: - tasks: Tasks sharing the same next_pos value. - start_pos: Common KV cache write position for the batch. - """ if not tasks: return - tasks = sorted(tasks, key=lambda t: t.slot) + tasks = sorted(tasks, key=lambda t: t.task_id) batch_sz = len(tasks) - slot_indices = torch.tensor([t.slot for t in tasks], device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) for i, t in enumerate(tasks): @@ -473,13 +276,15 @@ class InferenceScheduler: active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) + page_tables = self._make_page_table_tensor(tasks) + total_len = start_pos + 1 + with torch.inference_mode(): outputs = self.model( input_ids.unsqueeze(1), input_mask=active_mask, - persistent_key_values=self.kv_cache, + paged_cache=self.page_cache.bind(page_tables, total_len=total_len), start_pos=start_pos, - slot_indices=slot_indices, ) logits = outputs["logits"][:, -1, :] @@ -496,23 +301,30 @@ class InferenceScheduler: t.output_ids.append(ntok) t.output_tokens += 1 pos = t.input_tokens + t.output_tokens - if t.slot >= 0 and pos < self.max_seq_len: - self.seq_mask[t.slot, pos] = True + self._maybe_alloc_page(t, pos) if t.stream_callback: t.stream_callback(self.tokenizer.decode([ntok])) for t in tasks: if t.is_finished(self.tokenizer.stop_ids): if t.stream_callback: - t.stream_callback(_STOP) + t.stream_callback(STOP) + + def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor: + max_pages = max(t.n_pages for t in tasks) + rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] + return torch.tensor(rows, dtype=torch.long, device=self.device) + + def _maybe_alloc_page(self, task: Task, pos: int) -> None: + needed = self._n_pages_for(pos + 1) + while task.n_pages < needed: + p = self.page_cache.alloc() + if p < 0: + break + task.page_table.append(p) + task.n_pages += 1 def _run_generation_loop(self) -> None: - """Main generation loop run in a daemon thread. - - Continuously cycles through cleanup, refill, prefill, and decode. - Decode processes only the largest position group to ensure all - batched tasks share the same KV cache write position. - """ try: while self._running: self._remove_finished_tasks() @@ -523,11 +335,8 @@ class InferenceScheduler: self._task_event.clear() self._task_event.wait(timeout=0.01) continue - tasks = self.active_tasks[:] - to_prefill = [t for t in tasks if t.output_tokens == 0] - if to_prefill: - self._execute_prefill(to_prefill) + self._execute_prefill() pos_groups: Dict[int, List[Task]] = {} for t in self.active_tasks: @@ -544,21 +353,20 @@ class InferenceScheduler: logger.error(f"Scheduler loop crashed: {e}", exc_info=True) for task in self.active_tasks: if task.stream_callback: - task.stream_callback(_STOP) + task.stream_callback(STOP) for task in self.waiting_queue: if task.stream_callback: - task.stream_callback(_STOP) + task.stream_callback(STOP) raise def start(self) -> None: - """Starts the background generation loop thread.""" if not self._running: self._running = True t = threading.Thread(target=self._run_generation_loop, daemon=True) t.start() + self._loop_thread = t def stop(self) -> None: - """Stops the generation loop and releases all resources.""" self._running = False self._task_event.set() if hasattr(self, "_loop_thread"): @@ -569,11 +377,6 @@ class InferenceScheduler: torch.cuda.empty_cache() def get_stats(self) -> Dict[str, Any]: - """Returns current scheduler statistics. - - Returns: - Dict with total_tasks, total_tokens, active_tasks, waiting_queue. - """ return { "total_tasks": self._total_tasks, "total_tokens": self._total_tokens, diff --git a/astrai/model/module.py b/astrai/model/module.py index 10f50a2..9162c17 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -5,17 +5,11 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from astrai.inference.cache import CacheView + def repeat_kv(x: Tensor, n_rep: int) -> Tensor: - """ - Repeat k times along the dimension for attention heads. - Args: - x (Tensor): The input tensor. - n_rep (int): The number of repetitions. - Returns: - Tensor: The repeated tensor. - """ - + """Repeat KV heads n_rep times for GQA.""" bs, slen, n_heads, head_dim = x.shape if n_rep == 1: return x @@ -32,49 +26,25 @@ def get_rotary_emb( base: float = 10000, device: Optional[torch.device] = None, ) -> Tuple[Tensor, Tensor]: - """ - Get the rotary embedding for the given dimension and maximum length. - Args: - dim (int): The dimension of the input. - max_len (int): The maximum length of the input. - base (float, optional): The base for the frequency. Defaults to 10000. - device (optional): The device to create tensors on. Defaults to None. - Returns: - Tensor: The rotary embedding tensor. - """ - + """Precompute cos/sin for RoPE.""" theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) t = torch.arange(0, max_len, dtype=torch.float64, device=device) freqs = torch.outer(t, theta) - return torch.cos(freqs).float(), torch.sin(freqs).float() def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: - """ - Apply rotary embedding to the input tensor using cos/sin form. - Args: - x (Tensor): The input tensor (shape [..., seq_len, dim]). - rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]). - Returns: - Tensor: The output tensor (rotated, same shape as input). - """ - + """Apply rotary embedding via cos/sin (shape-preserving).""" dtype = x.dtype cos, sin = rotary_emb - - cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] - sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] - - x_real = x[..., 0::2] # [batch, seq_len, dim//2] - x_imag = x[..., 1::2] # [batch, seq_len, dim//2] - + cos = cos.unsqueeze(0).unsqueeze(2) + sin = sin.unsqueeze(0).unsqueeze(2) + x_real = x[..., 0::2] + x_imag = x[..., 1::2] x_real_rot = x_real * cos - x_imag * sin x_imag_rot = x_real * sin + x_imag * cos - - x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] - x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim] - + x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) + x_out = x_out.view(*x_out.shape[:-2], -1) return x_out.to(dtype) @@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module): def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: seq_len = x.size(1) - if self.max_len_cached < seq_len + start_pos: self._set_rotary_buffer(self.max_len_cached * 2, x.device) - cos = self.cos_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len] - return (cos, sin) @@ -185,14 +152,13 @@ class GQA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], mask: Tensor = None, - kv_cache: Optional[Tuple[Tensor, Tensor]] = None, + paged_cache: Optional[CacheView] = None, start_pos: int = 0, - slot_indices: Optional[Tensor] = None, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None - # x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) + # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim) q = self._split_heads(self.q_proj(x), self.n_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads) @@ -201,18 +167,14 @@ class GQA(nn.Module): if self.use_qk_norm: q, k = self.q_norm(q), self.k_norm(k) - if kv_cache is not None: - k_cache, v_cache = kv_cache - k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k - v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v - k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id] - v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id] + if paged_cache is not None: + paged_cache.write(self.layer_id, start_pos, k, v) + k, v = paged_cache.gather(self.layer_id) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) - # (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim) sdqa_out = ( F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) .permute(0, 2, 1, 3) @@ -224,7 +186,6 @@ class GQA(nn.Module): sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) out = self.o_proj(sdqa_out) - return out @@ -257,7 +218,7 @@ class MLA(nn.Module): self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) - # KV (k_nope, k_rope, v) + # fused KV: (k_nope, k_rope, v) self.kv_b_proj = Linear( kv_lora_rank, n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), @@ -273,9 +234,8 @@ class MLA(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], mask: Tensor = None, - kv_cache: Optional[Tuple[Tensor, Tensor]] = None, + paged_cache: Optional[CacheView] = None, start_pos: int = 0, - slot_indices: Optional[Tensor] = None, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = mask is None @@ -303,12 +263,9 @@ class MLA(nn.Module): q = torch.cat([q_nope, q_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1) - if kv_cache is not None: - k_cache, v_cache = kv_cache - k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k - v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v - k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id] - v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id] + if paged_cache is not None: + paged_cache.write(self.layer_id, start_pos, k, v) + k, v = paged_cache.gather(self.layer_id) q = q.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3) @@ -321,7 +278,6 @@ class MLA(nn.Module): attn_out = attn_out * F.sigmoid(self.gate(x)) out = self.o_proj(attn_out) - return out @@ -356,24 +312,19 @@ class DecoderBlock(nn.Module): x: Tensor, rotary_emb: Tuple[Tensor, Tensor], attention_mask: Optional[Tensor] = None, - kv_cache: Optional[Tuple[Tensor, Tensor]] = None, + paged_cache: Optional[CacheView] = None, start_pos: int = 0, - slot_indices: Optional[Tensor] = None, ) -> Tensor: - # attention attn_output = self.attention( self.input_norm(x), rotary_emb, attention_mask, - kv_cache, + paged_cache, start_pos, - slot_indices, ) x = attn_output + x - # feed forward x = self.mlp(self.post_attention_norm(x)) + x - return x diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 454adc7..6b243f0 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -1,10 +1,11 @@ -from typing import Any, Mapping, Optional, Tuple +from typing import Any, Mapping, Optional import torch import torch.nn as nn from torch import Tensor from astrai.config.model_config import ModelConfig +from astrai.inference.cache import CacheView from astrai.model.automodel import AutoModel from astrai.model.module import ( DecoderBlock, @@ -21,39 +22,25 @@ def process_attention_mask( start_pos: int = 0, is_causal: bool = False, ) -> Tensor: - """ - Create attention mask for GQA - Args: - seq_mask (Tensor): A tensor indicating whether each position is valid or not. - input_tensor (Tensor): The input tensor. - start_pos (int): The starting position of the sequence. - is_causal (bool): Whether the attention is causal or not. - Returns: - Tensor: The attention mask tensor. - """ + """Build 4D attention mask from 2D seq_mask, with optional causal masking.""" device = input_tensor.device dtype = input_tensor.dtype seq_len = input_tensor.size(1) if seq_mask is None: if start_pos != 0: - # for single prompt chat seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) else: return None if seq_mask.dim() > 2: - # shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos) - # if ndim > 2, it's 4D tensor return seq_mask batch_size = seq_mask.size(0) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) - # (bsz, start_pos + seq_len) expanded_mask = seq_mask.unsqueeze(1).expand( batch_size, seq_len, start_pos + seq_len ) - # (bsz, seq_len, start_pos + seq_len) if is_causal: expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) @@ -62,16 +49,13 @@ def process_attention_mask( attention_mask = attention_mask.masked_fill_( ~expanded_mask, -torch.finfo(dtype).max / 2 ).unsqueeze(1) - # (bsz, 1, seq_len, seq_len + start_pos) return attention_mask @AutoModel.register("transformer") class Transformer(AutoModel): - """ - Transformer language model. - """ + """Transformer language model with paged KV cache.""" def __init__(self, config: ModelConfig): super().__init__(config) @@ -114,18 +98,15 @@ class Transformer(AutoModel): lm_head_key = "lm_head.weight" embed_key = "embed_tokens.weight" - # Make a copy to avoid modifying the original state_dict state_dict = dict(state_dict) if self.config.tie_weight: - # same tensor + # same tensor for embed and lm_head if embed_key in state_dict: state_dict[lm_head_key] = state_dict[embed_key] else: - # If lm_head.weight exists in checkpoint, use it directly - # If not, copy from embed_tokens.weight if lm_head_key not in state_dict and embed_key in state_dict: - # use clone to avoid sharing the same tensor + # clone to avoid sharing gradients state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) return super().load_state_dict(state_dict, strict, assign) @@ -146,9 +127,8 @@ class Transformer(AutoModel): self, input_ids: Tensor, input_mask: Optional[Tensor] = None, - persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, + paged_cache: Optional[CacheView] = None, start_pos: int = 0, - slot_indices: Optional[Tensor] = None, ) -> Tensor: assert input_ids.ndim == 2 @@ -157,13 +137,8 @@ class Transformer(AutoModel): attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) - if slot_indices is None: - slot_indices = slice(input_ids.size(0)) - for layer in self.layers: - x = layer( - x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices - ) + x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos) hidden_states = self.norm(x) logits = self.lm_head(hidden_states) diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index d006bbd..3d771c3 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch import pytest -from astrai.inference.cache import PrefixCacheManager -from astrai.inference.scheduler import ( - InferenceScheduler, -) - - -def test_prefix_cache_concurrent_insert_find(): - """Test concurrent insert and find operations.""" - cache = PrefixCacheManager(max_capacity=100) - - results = {"errors": [], "inserts": 0, "finds": 0} - - def insert_worker(): - try: - for i in range(50): - cache.insert((i,), slot=i % 10, slot_ver=0) - results["inserts"] += 1 - except Exception as e: - results["errors"].append(str(e)) - - def find_worker(): - try: - for i in range(50): - cache.find([i]) - results["finds"] += 1 - except Exception as e: - results["errors"].append(str(e)) - - threads = [threading.Thread(target=insert_worker) for _ in range(3)] - threads += [threading.Thread(target=find_worker) for _ in range(3)] - - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(results["errors"]) == 0, f"Errors: {results['errors']}" - assert results["inserts"] == 150 - assert results["finds"] == 150 - - -def test_prefix_cache_concurrent_release(): - """Test concurrent release operations.""" - cache = PrefixCacheManager(max_capacity=100) - - # Insert some prefixes - for i in range(10): - cache.insert((i,), slot=i, slot_ver=0) - - results = {"errors": []} - - def release_worker(): - try: - for i in range(10): - cache.release((i,)) - except Exception as e: - results["errors"].append(str(e)) - - threads = [threading.Thread(target=release_worker) for _ in range(3)] - - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(results["errors"]) == 0, f"Errors: {results['errors']}" - - -def test_prefix_cache_concurrent_insert_release_find(): - """Test mixed concurrent operations.""" - cache = PrefixCacheManager(max_capacity=50) - - results = {"errors": []} - - def worker(worker_id): - try: - for i in range(20): - token_ids = (worker_id * 100 + i,) - cache.insert(token_ids, slot=worker_id, slot_ver=0) - - # Find after insert - cache.find(list(token_ids)) - - # Release - cache.release(token_ids) - except Exception as e: - results["errors"].append(f"Worker {worker_id}: {str(e)}") - - threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)] - - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(results["errors"]) == 0, f"Errors: {results['errors']}" +from astrai.inference.scheduler import InferenceScheduler @pytest.fixture @@ -266,54 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): for stats in results["stats"]: assert "total_tasks" in stats assert stats["total_tasks"] >= 0 - - -def test_prefix_cache_insert_same_prefix_concurrently(): - """Test inserting the same prefix concurrently.""" - cache = PrefixCacheManager(max_capacity=100) - - results = {"slot_values": [], "errors": []} - - def insert_worker(): - try: - # All workers try to insert the same prefix - cache.insert((1, 2, 3), slot=0, slot_ver=0) - node = cache.root.children.get(1) - if node: - node = node.children.get(2) - if node: - node = node.children.get(3) - if node: - results["slot_values"].append(node.slot) - except Exception as e: - results["errors"].append(str(e)) - - threads = [threading.Thread(target=insert_worker) for _ in range(10)] - - for t in threads: - t.start() - for t in threads: - t.join() - - # All inserts should succeed, final slot should be one of the values - assert len(results["errors"]) == 0, f"Errors: {results['errors']}" - # Check ref_count is correct (should be 10) - node = cache.root.children.get(1).children.get(2).children.get(3) - assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}" - - -def test_prefix_cache_ref_count_underflow_prevention(): - """Test that ref_count doesn't go negative.""" - cache = PrefixCacheManager(max_capacity=100) - - cache.insert((1, 2, 3), slot=0, slot_ver=0) - - # Release multiple times - for _ in range(5): - cache.release((1, 2, 3)) - - # Try to find it - should return None since ref_count would be negative - # or handle it gracefully - node = cache.root.children.get(1).children.get(2).children.get(3) - # The ref_count should be 0, not negative - assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"