refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码

- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引
- 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑
- 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配
- 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误
- _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值
- 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView)
- 精简 docstring 为单行,删除冗余 section 注释和旧代码
- 修复 test_scheduler_concurrency.py 缺少 import pytest
This commit is contained in:
ViperEkura 2026-05-08 20:44:05 +08:00
parent 7ddebf2cd9
commit 30cc2d67a4
7 changed files with 244 additions and 777 deletions

View File

@ -1,241 +1,135 @@
"""KV cache slot allocation and prefix cache management. """Page-based KV cache with page-table-indirected read/write.
Provides: Provides:
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask. - PagedCache: paged KV cache combining page pool and tensor storage.
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
""" """
import time from typing import List, Tuple
from collections import OrderedDict
from typing import Dict, List, Tuple
import torch
from torch import Tensor from torch import Tensor
_STOP = object() STOP = object()
class _RadixNode: class PagedCache:
"""Internal node for the radix tree prefix cache. """Paged KV cache with page-table-indirected read/write.
Attributes: Combines:
children: Mapping from token ID to child node. - Page pool (ref-counted alloc/free via bitmask)
slot: KV cache slot index for the prefix ending at this node. - KV tensor storage (k_cache, v_cache)
slot_ver: Version counter of the slot at insertion time.
ref_count: Number of tasks currently referencing this node. Call :meth:`bind` to obtain a batch view for the attention layers.
last_access: Timestamp of the most recent access (for LRU ordering).
""" """
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access") def __init__(
self,
def __init__(self): n_layers: int,
self.children: Dict[int, "_RadixNode"] = {} n_pages: int,
self.slot: int = -1 page_size: int,
self.slot_ver: int = 0 n_kv_heads: int,
self.ref_count: int = 0 head_dim: int,
self.last_access: float = 0.0 device: torch.device,
dtype: torch.dtype,
):
class SlotAllocator: self.page_size = page_size
"""KV cache slot allocator using bitmask for O(1) alloc/free. self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
Implements the Object Pool pattern: pre-allocated KV cache slots self.k_cache = torch.empty(
are managed via a bitmask, providing constant-time allocation and (n_layers, n_pages, page_size, n_kv_heads, head_dim),
deallocation with version counters for staleness detection. device=device,
""" dtype=dtype,
)
def __init__(self, max_slots: int): self.v_cache = torch.empty(
self._max_slots = max_slots (n_layers, n_pages, page_size, n_kv_heads, head_dim),
self._free_mask = (1 << max_slots) - 1 device=device,
self._versions: List[int] = [0] * max_slots dtype=dtype,
)
def alloc(self) -> int: def alloc(self) -> int:
"""Allocates a free slot.
Returns:
Slot index on success, -1 if all slots are occupied.
"""
lsb = self._free_mask & -self._free_mask lsb = self._free_mask & -self._free_mask
if lsb == 0: if lsb == 0:
return -1 return -1
idx = lsb.bit_length() - 1 idx = lsb.bit_length() - 1
self._free_mask ^= lsb self._free_mask ^= lsb
self._versions[idx] += 1 self._refs[idx] = 1
return idx return idx
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
if any(p < 0 for p in pages):
for p in pages:
if p >= 0:
self.free(p)
return []
return pages
def free(self, idx: int) -> None: def free(self, idx: int) -> None:
"""Releases a slot back to the free pool.""" self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
def occupy(self, idx: int) -> None: def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
"""Marks a currently free slot as occupied without bumping its version. return CacheView(self, page_table, total_len)
Used for direct slot reuse when a prefix-cached slot is still valid. def write(
""" self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
self._free_mask ^= 1 << idx
def is_free(self, idx: int) -> bool:
"""Checks whether a slot is currently free."""
return (self._free_mask >> idx) & 1 == 1
def version(self, idx: int) -> int:
"""Returns the current version counter for a slot."""
return self._versions[idx]
@property
def free_count(self) -> int:
"""Returns the number of currently free slots."""
return self._free_mask.bit_count()
class PrefixCacheManager:
"""Radix-tree prefix cache with LRU eviction.
Maps token ID sequences to KV cache slots. Intermediate tree nodes
also store slot information, allowing direct slot reuse when the
cached slot is free and its version matches (no intervening writes).
"""
def __init__(self, max_capacity: int = 1000):
"""Initializes the prefix cache.
Args:
max_capacity: Maximum number of nodes in the LRU list.
"""
self.root = _RadixNode()
self.max_capacity = max_capacity
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
"""Inserts a token sequence into the prefix cache.
Every node along the path records the slot and its version,
enabling direct slot reuse for partial prefix matches.
Args:
token_ids: The token ID sequence to cache.
slot: The KV cache slot containing this prefix's computed keys/values.
slot_ver: The slot version at insertion time, used for staleness detection.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
nxt = _RadixNode()
node.children[tid] = nxt
node = nxt
node.slot = slot
node.slot_ver = slot_ver
node.last_access = time.time()
self._lru[id(node)] = node
node.ref_count += 1
self._evict_if_needed()
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
"""Finds the longest matching prefix in the cache.
Walks the radix tree token by token, recording the deepest match.
Args:
token_ids: The token sequence to match against.
Returns:
Tuple of (prefix_len, slot, slot_ver):
prefix_len: Number of matching tokens (0 if no match).
slot: KV cache slot of the matched prefix (-1 if no match).
slot_ver: Version of that slot when the prefix was inserted.
"""
node = self.root
best_len, best_slot, best_ver = 0, -1, 0
for i, tid in enumerate(token_ids):
nxt = node.children.get(tid)
if nxt is None:
break
node = nxt
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
node.last_access = time.time()
self._lru.move_to_end(id(node))
return best_len, best_slot, best_ver
def pin(self, token_ids: Tuple[int, ...]) -> None:
"""Increments the reference count of a cached prefix.
Called when a task reuses a cached prefix to prevent eviction.
Args:
token_ids: The token sequence whose node's ref_count to increment.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
node.ref_count += 1
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Decrements the reference count of a cached prefix.
The node's slot is preserved even when ref_count reaches zero,
allowing future tasks to reuse the slot directly if it remains free.
Args:
token_ids: The token sequence whose node's ref_count to decrement.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
if node.ref_count > 0:
node.ref_count -= 1
def copy_kv(
self,
token_ids: Tuple[int, ...],
target_slot: int,
kv_cache: Tuple[Tensor, Tensor],
n_layers: int,
) -> None: ) -> None:
"""Copies cached KV data from the source slot to a target slot. seq_len = k.size(1)
if seq_len == 0:
Args:
token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into.
kv_cache: Tuple of (k_cache, v_cache) tensors.
n_layers: Number of transformer layers to copy.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return return
node = nxt page_size = self.page_size
src_slot = node.slot written = 0
if src_slot < 0: first_page = start_pos // page_size
return last_page = (start_pos + seq_len - 1) // page_size
prefix_len = len(token_ids) for pi in range(first_page, last_page + 1):
k_cache, v_cache = kv_cache phys_pages = page_table[:, pi]
for li in range(n_layers): page_start = pi * page_size
k_cache[target_slot, :prefix_len, li].copy_( write_start = max(page_start, start_pos)
k_cache[src_slot, :prefix_len, li] write_end = min(page_start + page_size, start_pos + seq_len)
) offset = write_start - page_start
v_cache[target_slot, :prefix_len, li].copy_( chunk = write_end - write_start
v_cache[src_slot, :prefix_len, li] self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
) :, written : written + chunk
]
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
:, written : written + chunk
]
written += chunk
def _evict_if_needed(self) -> None: def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
"""Evicts least-recently-used nodes until under capacity. k_parts, v_parts = [], []
for pi in range(page_table.size(1)):
phys_pages = page_table[:, pi]
if not (phys_pages >= 0).any():
break
k_parts.append(self.k_cache[layer_id, phys_pages])
v_parts.append(self.v_cache[layer_id, phys_pages])
k = torch.cat(k_parts, dim=1)
v = torch.cat(v_parts, dim=1)
return k, v
Skips nodes with ref_count > 0 (still in use by active tasks).
Evicted nodes have their slot and children cleared. class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
""" """
while len(self._lru) > self.max_capacity:
key, node = next(iter(self._lru.items())) __slots__ = ("_cache", "_page_table", "_total_len")
if node.ref_count > 0:
self._lru.move_to_end(key) def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
continue self._cache = cache
self._lru.pop(key) self._page_table = page_table
node.slot = -1 self._total_len = total_len
node.slot_ver = 0
node.children.clear() def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -16,7 +16,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from astrai.inference.cache import _STOP from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
@ -118,15 +118,15 @@ class _Result:
"""Appends a token to the result buffer. """Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx]. In non-streaming mode, tokens are concatenated into results[idx].
The sentinel _STOP marks a task as complete. The sentinel STOP marks a task as complete.
Args: Args:
token: The decoded token string, or _STOP sentinel. token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to. idx: Index of the generation task this token belongs to.
""" """
with self._lock: with self._lock:
self.tokens.append(token) self.tokens.append(token)
if token is not _STOP: if token is not STOP:
self.results[idx] += token self.results[idx] += token
else: else:
if not self._done[idx]: if not self._done[idx]:
@ -186,38 +186,28 @@ class InferenceEngine:
max_batch_size: int = 1, max_batch_size: int = 1,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048, max_prompt_len: int = 2048,
cache_capacity: int = 1000, page_size: int = 128,
): ):
"""Initializes the engine and starts the scheduler background thread. """Initializes the inference engine.
Args: Args:
model: The language model (nn.Module, e.g. Transformer). model: The model instance.
tokenizer: Tokenizer for encoding/decoding. tokenizer: The tokenizer instance.
max_batch_size: Maximum concurrent tasks in the scheduler. max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length (defaults to model config). max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens (longer prompts truncated). max_prompt_len: Maximum prompt tokens.
cache_capacity: Maximum prefix cache nodes. compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
""" """
try:
first_param = next(model.parameters())
device = first_param.device
dtype = first_param.dtype
except StopIteration:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
cache_capacity=cache_capacity, page_size=page_size,
device=device,
dtype=dtype,
) )
self.scheduler.start() self.scheduler.start()
@ -383,7 +373,7 @@ class InferenceEngine:
while True: while True:
tokens = result.pop_all() tokens = result.pop_all()
for token in tokens: for token in tokens:
if token is _STOP: if token is STOP:
return return
yield token yield token
if not result.wait(timeout=0.05): if not result.wait(timeout=0.05):

View File

@ -9,7 +9,7 @@ parameters, so a single pipeline works for any batch size.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Optional, Union from typing import List, Union
import torch import torch
from torch import Tensor from torch import Tensor

View File

@ -1,24 +1,19 @@
"""Inference scheduler for single-GPU continuous batching. """Inference scheduler for single-GPU continuous batching with paged KV cache."""
Splits scheduling concerns across modules:
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
- sampling.py: Strategy-pattern logit transformations
"""
import logging import logging
import threading import threading
import time import time
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional
import torch import torch
from torch import Tensor from torch import Tensor
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator from astrai.inference.cache import STOP, PagedCache
from astrai.inference.sampling import sample from astrai.inference.sampling import sample
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -33,29 +28,7 @@ class TaskStatus(Enum):
class Task: class Task:
"""Represents a single generation request within the scheduler. """Represents a single generation request with paged KV cache tracking."""
Tracks prompt tokens, generated output, sampling parameters,
KV cache slot assignment, and prefix cache matching state.
"""
__slots__ = (
"task_id",
"prompt_ids",
"max_tokens",
"temperature",
"top_p",
"top_k",
"status",
"output_ids",
"input_tokens",
"output_tokens",
"slot",
"prefix_len",
"arrival_time",
"finish_time",
"stream_callback",
)
def __init__( def __init__(
self, self,
@ -67,17 +40,6 @@ class Task:
top_k: int = 50, top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
): ):
"""Initializes a new task.
Args:
task_id: Unique identifier for this task.
prompt_ids: Tokenized prompt sequence.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
stream_callback: Optional callback invoked per decoded token.
"""
self.task_id = task_id self.task_id = task_id
self.prompt_ids = prompt_ids self.prompt_ids = prompt_ids
self.max_tokens = max_tokens self.max_tokens = max_tokens
@ -89,26 +51,17 @@ class Task:
self.output_ids: List[int] = [] self.output_ids: List[int] = []
self.input_tokens: int = 0 self.input_tokens: int = 0
self.output_tokens: int = 0 self.output_tokens: int = 0
self.slot: int = -1 self.page_table: List[int] = []
self.prefix_len: int = 0 self.n_pages: int = 0
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
@property @property
def next_pos(self) -> int: def next_pos(self) -> int:
"""Returns the next KV cache position to write during decode."""
return self.input_tokens + len(self.output_ids) return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
"""Checks whether the task has reached a stopping condition.
Args:
stop_ids: List of stop token IDs (e.g., EOS).
Returns:
True if max_tokens reached or the last output token is a stop ID.
"""
if self.output_tokens >= self.max_tokens: if self.output_tokens >= self.max_tokens:
return True return True
if self.output_ids and self.output_ids[-1] in stop_ids: if self.output_ids and self.output_ids[-1] in stop_ids:
@ -117,16 +70,13 @@ class Task:
class InferenceScheduler: class InferenceScheduler:
"""Continuous batching scheduler for single-GPU inference. """Continuous batching scheduler with paged KV cache.
Runs a background generation loop with four phases per iteration: Runs a background generation loop with four phases per iteration:
1. Cleanup finished tasks and release resources. 1. Cleanup finished tasks and release resources.
2. Refill active batch from the waiting queue. 2. Refill active batch from the waiting queue.
3. Prefill newly activated tasks (full, partial, or fully cached). 3. Prefill newly activated tasks.
4. Decode the largest same-position group of active tasks. 4. Decode the largest same-position group of active tasks.
Tasks at different positions are never batched together in decode,
avoiding RoPE corruption from misaligned KV cache writes.
""" """
def __init__( def __init__(
@ -136,22 +86,10 @@ class InferenceScheduler:
max_batch_size: int = 16, max_batch_size: int = 16,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 512, max_prompt_len: int = 512,
cache_capacity: int = 1000, page_size: int = 64,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
): ):
"""Initializes the scheduler and pre-allocates the KV cache.
Args:
model: The language model (must have config with n_layers, n_kv_heads, etc.).
tokenizer: Tokenizer for encoding prompts and decoding outputs.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length (defaults to config.max_len).
max_prompt_len: Maximum prompt tokens (longer prompts are truncated).
cache_capacity: Maximum prefix cache node count.
device: Target device for tensors.
dtype: Data type for KV cache tensors.
"""
config = model.config config = model.config
self.model = model self.model = model
@ -159,35 +97,25 @@ class InferenceScheduler:
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len self.max_prompt_len = max_prompt_len
self.page_size = page_size
self.device = device or next(model.parameters()).device self.device = device or next(model.parameters()).device
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
n_kv_heads = config.n_kv_heads n_kv_heads = config.n_kv_heads
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_layers = config.n_layers n_layers = config.n_layers
self._n_layers = n_layers n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
k_cache = torch.empty( self.page_cache = PagedCache(
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), n_layers,
device=self.device, n_pages,
dtype=self.dtype, page_size,
) n_kv_heads,
v_cache = torch.empty( head_dim,
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), self.device,
device=self.device, self.dtype,
dtype=self.dtype,
)
self.kv_cache = (k_cache, v_cache)
self.seq_mask = torch.zeros(
(max_batch_size, self.max_seq_len),
device=self.device,
dtype=torch.bool,
) )
self.slot_allocator = SlotAllocator(max_batch_size)
self.waiting_queue: List[Task] = [] self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = [] self.active_tasks: List[Task] = []
@ -198,41 +126,8 @@ class InferenceScheduler:
self._total_tasks = 0 self._total_tasks = 0
self._total_tokens = 0 self._total_tokens = 0
def _alloc_slot(self) -> int: def _n_pages_for(self, n_tokens: int) -> int:
"""Allocates a free KV cache slot using the Object Pool. return (n_tokens + self.page_size - 1) // self.page_size
Returns:
Slot index on success, -1 if all slots are occupied.
"""
return self.slot_allocator.alloc()
def _free_slot(self, idx: int) -> None:
"""Releases a KV cache slot back to the free pool.
Args:
idx: Slot index to free.
"""
self.slot_allocator.free(idx)
self.seq_mask[idx, :] = False
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
"""Attempts to reuse a prefix-cached slot directly without KV copy.
The slot is reusable only if it is free and its version matches
the current slot version (no intervening allocation overwrote it).
Args:
prefix: The matched prefix token sequence.
Returns:
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
"""
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
if cached_ver == self.slot_allocator.version(cached_slot):
self.slot_allocator.occupy(cached_slot)
return cached_slot, True
return -1, False
def add_task( def add_task(
self, self,
@ -243,25 +138,8 @@ class InferenceScheduler:
top_k: int = 50, top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
) -> str: ) -> str:
"""Adds a generation task to the waiting queue.
Encodes the prompt, queries the prefix cache for a match,
and enqueues the task for the background generation loop.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
stream_callback: Called per decoded token with the string representation.
Returns:
Unique task ID string.
"""
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt) prompt_ids = self.tokenizer.encode(prompt)
if len(prompt_ids) > self.max_prompt_len: if len(prompt_ids) > self.max_prompt_len:
prompt_ids = prompt_ids[-self.max_prompt_len :] prompt_ids = prompt_ids[-self.max_prompt_len :]
@ -275,9 +153,6 @@ class InferenceScheduler:
stream_callback=stream_callback, stream_callback=stream_callback,
) )
prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(prompt_ids)
task.prefix_len = prefix_len
with self._lock: with self._lock:
self.waiting_queue.append(task) self.waiting_queue.append(task)
self._total_tasks += 1 self._total_tasks += 1
@ -286,32 +161,21 @@ class InferenceScheduler:
return task_id return task_id
def remove_task(self, task_id: str) -> None: def remove_task(self, task_id: str) -> None:
"""Removes a task from both the waiting queue and active tasks.
Args:
task_id: The task to remove.
"""
with self._lock: with self._lock:
removed_active = [t for t in self.active_tasks if t.task_id == task_id] removed_active = [t for t in self.active_tasks if t.task_id == task_id]
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
for task in removed_active: for task in removed_active:
if task.prefix_len > 0: self._free_pages(task.page_table)
prefix = tuple(task.prompt_ids[: task.prefix_len]) task.page_table.clear()
self.prefix_cache.release(prefix) task.n_pages = 0
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids)) def _free_pages(self, indices: List[int]) -> None:
if task.slot >= 0: for idx in indices:
self._free_slot(task.slot) self.page_cache.free(idx)
task.slot = -1
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
"""Removes all finished tasks from the active batch.
Releases prefix cache references and frees the KV cache slot
for each completed task.
"""
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.is_finished(self.tokenizer.stop_ids): if task.is_finished(self.tokenizer.stop_ids):
@ -321,25 +185,15 @@ class InferenceScheduler:
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
for task in finished: for task in finished:
if task.prefix_len > 0: self._free_pages(task.page_table)
prefix = tuple(task.prompt_ids[: task.prefix_len]) task.page_table.clear()
self.prefix_cache.release(prefix) task.n_pages = 0
if task.prefix_len < len(task.prompt_ids):
self.prefix_cache.release(tuple(task.prompt_ids))
if task.slot >= 0:
self._free_slot(task.slot)
task.slot = -1
self.active_tasks = [ self.active_tasks = [
t for t in self.active_tasks if t.status != TaskStatus.FINISHED t for t in self.active_tasks if t.status != TaskStatus.FINISHED
] ]
def _refill_active_batch(self) -> None: def _refill_active_batch(self) -> None:
"""Moves waiting tasks into the active batch, up to max_batch_size.
Attempts direct slot reuse for prefix-matched tasks; falls back
to allocating a fresh slot with KV cache copy when reuse is not possible.
"""
available = self.max_batch_size - len(self.active_tasks) available = self.max_batch_size - len(self.active_tasks)
if available <= 0: if available <= 0:
return return
@ -350,122 +204,71 @@ class InferenceScheduler:
for _ in range(n): for _ in range(n):
to_add.append(self.waiting_queue.pop(0)) to_add.append(self.waiting_queue.pop(0))
for i, task in enumerate(to_add): for task in to_add:
slot = -1 prompt_len = len(task.prompt_ids)
reused = False n_pages = self._n_pages_for(prompt_len)
if task.prefix_len > 0: task.page_table = self.page_cache.alloc_n(n_pages)
prefix = tuple(task.prompt_ids[: task.prefix_len]) if not task.page_table:
cached_slot, reused = self._try_reuse_slot(prefix)
if reused:
slot = cached_slot
if slot < 0:
slot = self._alloc_slot()
if slot < 0:
with self._lock: with self._lock:
self.waiting_queue[:0] = to_add[i:] self.waiting_queue.insert(0, task)
break break
task.slot = slot task.n_pages = len(task.page_table)
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
self.active_tasks.append(task) self.active_tasks.append(task)
if task.prefix_len > 0 and not reused: def _execute_prefill(self) -> None:
prefix = tuple(task.prompt_ids[: task.prefix_len]) to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) if not to_prefill:
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
cached_slot
):
self.prefix_cache.pin(prefix)
self.prefix_cache.copy_kv(
prefix, slot, self.kv_cache, self._n_layers
)
else:
task.prefix_len = 0
def _execute_prefill(self, tasks: List[Task]) -> None:
"""Runs batched prefill for newly activated tasks.
Fully-cached tasks skip the model. Others are grouped by prefix_len
so tasks sharing the same start_pos are batched together.
"""
if not tasks:
return return
groups: Dict[int, List[Task]] = {} for t in to_prefill:
for t in tasks: prompt_len = len(t.prompt_ids)
plen = len(t.prompt_ids) t.input_tokens = prompt_len
if t.prefix_len == plen:
t.input_tokens = plen
t.output_tokens = 0 t.output_tokens = 0
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
else:
groups.setdefault(t.prefix_len, []).append(t)
for prefix_len, group in groups.items(): groups: Dict[int, List[Task]] = {}
slot_indices = torch.tensor([t.slot for t in group], device=self.device) for t in to_prefill:
self._execute_prefill_batch(group, prefix_len, slot_indices) groups.setdefault(len(t.prompt_ids), []).append(t)
def _execute_prefill_batch( for prompt_len, group in groups.items():
self, tasks: List[Task], prefix_len: int, slot_indices: Tensor self._execute_prefill_batch(group, prompt_len)
) -> None:
"""Unified prefill for tasks sharing a common prefix_len.
Args: def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
tasks: Tasks with the same prefix_len < len(prompt_ids). tasks = sorted(tasks, key=lambda t: t.task_id)
prefix_len: Number of cached prefix tokens (0 for full prefill).
slot_indices: Tensor of slot indices for KV cache mapping.
"""
tasks = sorted(tasks, key=lambda t: t.slot)
batch_sz = len(tasks) batch_sz = len(tasks)
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
max_new_len = max(new_lens)
input_ids = torch.zeros( input_ids = torch.zeros(
batch_sz, max_new_len, dtype=torch.long, device=self.device batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
) )
input_mask = torch.zeros( input_mask = torch.ones(
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
) )
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
new_ids = t.prompt_ids[prefix_len:] input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
nl = len(new_ids)
if nl > 0: page_tables = self._make_page_table_tensor(tasks)
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
input_mask[i, : prefix_len + nl] = True
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=prefix_len, start_pos=0,
persistent_key_values=self.kv_cache, paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
slot_indices=slot_indices,
) )
for i, t in enumerate(tasks):
t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
)
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
"""Executes the decode phase for a group of tasks at the same position.
Args:
tasks: Tasks sharing the same next_pos value.
start_pos: Common KV cache write position for the batch.
"""
if not tasks: if not tasks:
return return
tasks = sorted(tasks, key=lambda t: t.slot) tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks) batch_sz = len(tasks)
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
@ -473,13 +276,15 @@ class InferenceScheduler:
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
total_len = start_pos + 1
with torch.inference_mode(): with torch.inference_mode():
outputs = self.model( outputs = self.model(
input_ids.unsqueeze(1), input_ids.unsqueeze(1),
input_mask=active_mask, input_mask=active_mask,
persistent_key_values=self.kv_cache, paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
start_pos=start_pos, start_pos=start_pos,
slot_indices=slot_indices,
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
@ -496,23 +301,30 @@ class InferenceScheduler:
t.output_ids.append(ntok) t.output_ids.append(ntok)
t.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens pos = t.input_tokens + t.output_tokens
if t.slot >= 0 and pos < self.max_seq_len: self._maybe_alloc_page(t, pos)
self.seq_mask[t.slot, pos] = True
if t.stream_callback: if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok])) t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks: for t in tasks:
if t.is_finished(self.tokenizer.stop_ids): if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback: if t.stream_callback:
t.stream_callback(_STOP) t.stream_callback(STOP)
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
break
task.page_table.append(p)
task.n_pages += 1
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
"""Main generation loop run in a daemon thread.
Continuously cycles through cleanup, refill, prefill, and decode.
Decode processes only the largest position group to ensure all
batched tasks share the same KV cache write position.
"""
try: try:
while self._running: while self._running:
self._remove_finished_tasks() self._remove_finished_tasks()
@ -523,11 +335,8 @@ class InferenceScheduler:
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=0.01) self._task_event.wait(timeout=0.01)
continue continue
tasks = self.active_tasks[:]
to_prefill = [t for t in tasks if t.output_tokens == 0] self._execute_prefill()
if to_prefill:
self._execute_prefill(to_prefill)
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks: for t in self.active_tasks:
@ -544,21 +353,20 @@ class InferenceScheduler:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self.active_tasks: for task in self.active_tasks:
if task.stream_callback: if task.stream_callback:
task.stream_callback(_STOP) task.stream_callback(STOP)
for task in self.waiting_queue: for task in self.waiting_queue:
if task.stream_callback: if task.stream_callback:
task.stream_callback(_STOP) task.stream_callback(STOP)
raise raise
def start(self) -> None: def start(self) -> None:
"""Starts the background generation loop thread."""
if not self._running: if not self._running:
self._running = True self._running = True
t = threading.Thread(target=self._run_generation_loop, daemon=True) t = threading.Thread(target=self._run_generation_loop, daemon=True)
t.start() t.start()
self._loop_thread = t
def stop(self) -> None: def stop(self) -> None:
"""Stops the generation loop and releases all resources."""
self._running = False self._running = False
self._task_event.set() self._task_event.set()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
@ -569,11 +377,6 @@ class InferenceScheduler:
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Returns current scheduler statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return { return {
"total_tasks": self._total_tasks, "total_tasks": self._total_tasks,
"total_tokens": self._total_tokens, "total_tokens": self._total_tokens,

View File

@ -5,17 +5,11 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.cache import CacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
""" """Repeat KV heads n_rep times for GQA."""
Repeat k times along the dimension for attention heads.
Args:
x (Tensor): The input tensor.
n_rep (int): The number of repetitions.
Returns:
Tensor: The repeated tensor.
"""
bs, slen, n_heads, head_dim = x.shape bs, slen, n_heads, head_dim = x.shape
if n_rep == 1: if n_rep == 1:
return x return x
@ -32,49 +26,25 @@ def get_rotary_emb(
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
""" """Precompute cos/sin for RoPE."""
Get the rotary embedding for the given dimension and maximum length.
Args:
dim (int): The dimension of the input.
max_len (int): The maximum length of the input.
base (float, optional): The base for the frequency. Defaults to 10000.
device (optional): The device to create tensors on. Defaults to None.
Returns:
Tensor: The rotary embedding tensor.
"""
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta) freqs = torch.outer(t, theta)
return torch.cos(freqs).float(), torch.sin(freqs).float() return torch.cos(freqs).float(), torch.sin(freqs).float()
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor: def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
""" """Apply rotary embedding via cos/sin (shape-preserving)."""
Apply rotary embedding to the input tensor using cos/sin form.
Args:
x (Tensor): The input tensor (shape [..., seq_len, dim]).
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
Returns:
Tensor: The output tensor (rotated, same shape as input).
"""
dtype = x.dtype dtype = x.dtype
cos, sin = rotary_emb cos, sin = rotary_emb
cos = cos.unsqueeze(0).unsqueeze(2)
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] sin = sin.unsqueeze(0).unsqueeze(2)
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2] x_real = x[..., 0::2]
x_imag = x[..., 1::2]
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
x_real_rot = x_real * cos - x_imag * sin x_real_rot = x_real * cos - x_imag * sin
x_imag_rot = x_real * sin + x_imag * cos x_imag_rot = x_real * sin + x_imag * cos
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2] x_out = x_out.view(*x_out.shape[:-2], -1)
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
return x_out.to(dtype) return x_out.to(dtype)
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]: def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
seq_len = x.size(1) seq_len = x.size(1)
if self.max_len_cached < seq_len + start_pos: if self.max_len_cached < seq_len + start_pos:
self._set_rotary_buffer(self.max_len_cached * 2, x.device) self._set_rotary_buffer(self.max_len_cached * 2, x.device)
cos = self.cos_cached[start_pos : start_pos + seq_len] cos = self.cos_cached[start_pos : start_pos + seq_len]
sin = self.sin_cached[start_pos : start_pos + seq_len] sin = self.sin_cached[start_pos : start_pos + seq_len]
return (cos, sin) return (cos, sin)
@ -185,14 +152,13 @@ class GQA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim) # (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
q = self._split_heads(self.q_proj(x), self.n_heads) q = self._split_heads(self.q_proj(x), self.n_heads)
k = self._split_heads(self.k_proj(x), self.n_kv_heads) k = self._split_heads(self.k_proj(x), self.n_kv_heads)
v = self._split_heads(self.v_proj(x), self.n_kv_heads) v = self._split_heads(self.v_proj(x), self.n_kv_heads)
@ -201,18 +167,14 @@ class GQA(nn.Module):
if self.use_qk_norm: if self.use_qk_norm:
q, k = self.q_norm(q), self.k_norm(k) q, k = self.q_norm(q), self.k_norm(k)
if kv_cache is not None: if paged_cache is not None:
k_cache, v_cache = kv_cache paged_cache.write(self.layer_id, start_pos, k, v)
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k k, v = paged_cache.gather(self.layer_id)
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep) k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim) # (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
sdqa_out = ( sdqa_out = (
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal) F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
.permute(0, 2, 1, 3) .permute(0, 2, 1, 3)
@ -224,7 +186,6 @@ class GQA(nn.Module):
sdqa_out = sdqa_out * F.sigmoid(self.gate(x)) sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
out = self.o_proj(sdqa_out) out = self.o_proj(sdqa_out)
return out return out
@ -257,7 +218,7 @@ class MLA(nn.Module):
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v) # fused KV: (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(
kv_lora_rank, kv_lora_rank,
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
@ -273,9 +234,8 @@ class MLA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
mask: Tensor = None, mask: Tensor = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = mask is None is_causal = mask is None
@ -303,12 +263,9 @@ class MLA(nn.Module):
q = torch.cat([q_nope, q_rope], dim=-1) q = torch.cat([q_nope, q_rope], dim=-1)
k = torch.cat([k_nope, k_rope], dim=-1) k = torch.cat([k_nope, k_rope], dim=-1)
if kv_cache is not None: if paged_cache is not None:
k_cache, v_cache = kv_cache paged_cache.write(self.layer_id, start_pos, k, v)
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k k, v = paged_cache.gather(self.layer_id)
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
q = q.permute(0, 2, 1, 3) q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3) k = k.permute(0, 2, 1, 3)
@ -321,7 +278,6 @@ class MLA(nn.Module):
attn_out = attn_out * F.sigmoid(self.gate(x)) attn_out = attn_out * F.sigmoid(self.gate(x))
out = self.o_proj(attn_out) out = self.o_proj(attn_out)
return out return out
@ -356,24 +312,19 @@ class DecoderBlock(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tuple[Tensor, Tensor], rotary_emb: Tuple[Tensor, Tensor],
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
kv_cache: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
# attention
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x),
rotary_emb, rotary_emb,
attention_mask, attention_mask,
kv_cache, paged_cache,
start_pos, start_pos,
slot_indices,
) )
x = attn_output + x x = attn_output + x
# feed forward
x = self.mlp(self.post_attention_norm(x)) + x x = self.mlp(self.post_attention_norm(x)) + x
return x return x

View File

@ -1,10 +1,11 @@
from typing import Any, Mapping, Optional, Tuple from typing import Any, Mapping, Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.cache import CacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,
@ -21,39 +22,25 @@ def process_attention_mask(
start_pos: int = 0, start_pos: int = 0,
is_causal: bool = False, is_causal: bool = False,
) -> Tensor: ) -> Tensor:
""" """Build 4D attention mask from 2D seq_mask, with optional causal masking."""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
input_tensor (Tensor): The input tensor.
start_pos (int): The starting position of the sequence.
is_causal (bool): Whether the attention is causal or not.
Returns:
Tensor: The attention mask tensor.
"""
device = input_tensor.device device = input_tensor.device
dtype = input_tensor.dtype dtype = input_tensor.dtype
seq_len = input_tensor.size(1) seq_len = input_tensor.size(1)
if seq_mask is None: if seq_mask is None:
if start_pos != 0: if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device) seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else: else:
return None return None
if seq_mask.dim() > 2: if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask return seq_mask
batch_size = seq_mask.size(0) batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool) seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand( expanded_mask = seq_mask.unsqueeze(1).expand(
batch_size, seq_len, start_pos + seq_len batch_size, seq_len, start_pos + seq_len
) )
# (bsz, seq_len, start_pos + seq_len)
if is_causal: if is_causal:
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos) expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
@ -62,16 +49,13 @@ def process_attention_mask(
attention_mask = attention_mask.masked_fill_( attention_mask = attention_mask.masked_fill_(
~expanded_mask, -torch.finfo(dtype).max / 2 ~expanded_mask, -torch.finfo(dtype).max / 2
).unsqueeze(1) ).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask return attention_mask
@AutoModel.register("transformer") @AutoModel.register("transformer")
class Transformer(AutoModel): class Transformer(AutoModel):
""" """Transformer language model with paged KV cache."""
Transformer language model.
"""
def __init__(self, config: ModelConfig): def __init__(self, config: ModelConfig):
super().__init__(config) super().__init__(config)
@ -114,18 +98,15 @@ class Transformer(AutoModel):
lm_head_key = "lm_head.weight" lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight" embed_key = "embed_tokens.weight"
# Make a copy to avoid modifying the original state_dict
state_dict = dict(state_dict) state_dict = dict(state_dict)
if self.config.tie_weight: if self.config.tie_weight:
# same tensor # same tensor for embed and lm_head
if embed_key in state_dict: if embed_key in state_dict:
state_dict[lm_head_key] = state_dict[embed_key] state_dict[lm_head_key] = state_dict[embed_key]
else: else:
# If lm_head.weight exists in checkpoint, use it directly
# If not, copy from embed_tokens.weight
if lm_head_key not in state_dict and embed_key in state_dict: if lm_head_key not in state_dict and embed_key in state_dict:
# use clone to avoid sharing the same tensor # clone to avoid sharing gradients
state_dict[lm_head_key] = torch.clone(state_dict[embed_key]) state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
return super().load_state_dict(state_dict, strict, assign) return super().load_state_dict(state_dict, strict, assign)
@ -146,9 +127,8 @@ class Transformer(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None, paged_cache: Optional[CacheView] = None,
start_pos: int = 0, start_pos: int = 0,
slot_indices: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2
@ -157,13 +137,8 @@ class Transformer(AutoModel):
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True) attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
if slot_indices is None:
slot_indices = slice(input_ids.size(0))
for layer in self.layers: for layer in self.layers:
x = layer( x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices
)
hidden_states = self.norm(x) hidden_states = self.norm(x)
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states)

View File

@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
import pytest import pytest
from astrai.inference.cache import PrefixCacheManager from astrai.inference.scheduler import InferenceScheduler
from astrai.inference.scheduler import (
InferenceScheduler,
)
def test_prefix_cache_concurrent_insert_find():
"""Test concurrent insert and find operations."""
cache = PrefixCacheManager(max_capacity=100)
results = {"errors": [], "inserts": 0, "finds": 0}
def insert_worker():
try:
for i in range(50):
cache.insert((i,), slot=i % 10, slot_ver=0)
results["inserts"] += 1
except Exception as e:
results["errors"].append(str(e))
def find_worker():
try:
for i in range(50):
cache.find([i])
results["finds"] += 1
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
threads += [threading.Thread(target=find_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert results["inserts"] == 150
assert results["finds"] == 150
def test_prefix_cache_concurrent_release():
"""Test concurrent release operations."""
cache = PrefixCacheManager(max_capacity=100)
# Insert some prefixes
for i in range(10):
cache.insert((i,), slot=i, slot_ver=0)
results = {"errors": []}
def release_worker():
try:
for i in range(10):
cache.release((i,))
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=release_worker) for _ in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
def test_prefix_cache_concurrent_insert_release_find():
"""Test mixed concurrent operations."""
cache = PrefixCacheManager(max_capacity=50)
results = {"errors": []}
def worker(worker_id):
try:
for i in range(20):
token_ids = (worker_id * 100 + i,)
cache.insert(token_ids, slot=worker_id, slot_ver=0)
# Find after insert
cache.find(list(token_ids))
# Release
cache.release(token_ids)
except Exception as e:
results["errors"].append(f"Worker {worker_id}: {str(e)}")
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
@pytest.fixture @pytest.fixture
@ -266,54 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_tasks" in stats
assert stats["total_tasks"] >= 0 assert stats["total_tasks"] >= 0
def test_prefix_cache_insert_same_prefix_concurrently():
"""Test inserting the same prefix concurrently."""
cache = PrefixCacheManager(max_capacity=100)
results = {"slot_values": [], "errors": []}
def insert_worker():
try:
# All workers try to insert the same prefix
cache.insert((1, 2, 3), slot=0, slot_ver=0)
node = cache.root.children.get(1)
if node:
node = node.children.get(2)
if node:
node = node.children.get(3)
if node:
results["slot_values"].append(node.slot)
except Exception as e:
results["errors"].append(str(e))
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
# All inserts should succeed, final slot should be one of the values
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
# Check ref_count is correct (should be 10)
node = cache.root.children.get(1).children.get(2).children.get(3)
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
def test_prefix_cache_ref_count_underflow_prevention():
"""Test that ref_count doesn't go negative."""
cache = PrefixCacheManager(max_capacity=100)
cache.insert((1, 2, 3), slot=0, slot_ver=0)
# Release multiple times
for _ in range(5):
cache.release((1, 2, 3))
# Try to find it - should return None since ref_count would be negative
# or handle it gracefully
node = cache.root.children.get(1).children.get(2).children.get(3)
# The ref_count should be 0, not negative
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"