refactor: 分页 KV cache 替换固定 slot,删除 PrefixCache 及相关死代码
- 用 PagedCache + CacheView 替换固定 slot 式 KV cache,attention 层只通过 page_table 间接索引 - 删除 PrefixCache(radix tree)及 scheduler 中所有 prefix cache 命中/插入/释放逻辑 - 删除无用函数:pin、version、free_count、_mark_seq_mask 及 seq_mask 分配 - 修复 write 在多页 prefill 时 offset 为负导致 chunk 计算错误 - _make_page_table_tensor 改用 list 拼接一次 tensor,去掉逐元素赋值 - 清理 model 接口参数:kv_cache, slot_indices → paged_cache(CacheView) - 精简 docstring 为单行,删除冗余 section 注释和旧代码 - 修复 test_scheduler_concurrency.py 缺少 import pytest
This commit is contained in:
parent
7ddebf2cd9
commit
30cc2d67a4
|
|
@ -1,241 +1,135 @@
|
|||
"""KV cache slot allocation and prefix cache management.
|
||||
"""Page-based KV cache with page-table-indirected read/write.
|
||||
|
||||
Provides:
|
||||
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask.
|
||||
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
|
||||
- PagedCache: paged KV cache combining page pool and tensor storage.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
_STOP = object()
|
||||
STOP = object()
|
||||
|
||||
|
||||
class _RadixNode:
|
||||
"""Internal node for the radix tree prefix cache.
|
||||
class PagedCache:
|
||||
"""Paged KV cache with page-table-indirected read/write.
|
||||
|
||||
Attributes:
|
||||
children: Mapping from token ID to child node.
|
||||
slot: KV cache slot index for the prefix ending at this node.
|
||||
slot_ver: Version counter of the slot at insertion time.
|
||||
ref_count: Number of tasks currently referencing this node.
|
||||
last_access: Timestamp of the most recent access (for LRU ordering).
|
||||
Combines:
|
||||
- Page pool (ref-counted alloc/free via bitmask)
|
||||
- KV tensor storage (k_cache, v_cache)
|
||||
|
||||
Call :meth:`bind` to obtain a batch view for the attention layers.
|
||||
"""
|
||||
|
||||
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
|
||||
|
||||
def __init__(self):
|
||||
self.children: Dict[int, "_RadixNode"] = {}
|
||||
self.slot: int = -1
|
||||
self.slot_ver: int = 0
|
||||
self.ref_count: int = 0
|
||||
self.last_access: float = 0.0
|
||||
|
||||
|
||||
class SlotAllocator:
|
||||
"""KV cache slot allocator using bitmask for O(1) alloc/free.
|
||||
|
||||
Implements the Object Pool pattern: pre-allocated KV cache slots
|
||||
are managed via a bitmask, providing constant-time allocation and
|
||||
deallocation with version counters for staleness detection.
|
||||
"""
|
||||
|
||||
def __init__(self, max_slots: int):
|
||||
self._max_slots = max_slots
|
||||
self._free_mask = (1 << max_slots) - 1
|
||||
self._versions: List[int] = [0] * max_slots
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self.v_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def alloc(self) -> int:
|
||||
"""Allocates a free slot.
|
||||
|
||||
Returns:
|
||||
Slot index on success, -1 if all slots are occupied.
|
||||
"""
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
if lsb == 0:
|
||||
return -1
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._versions[idx] += 1
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
|
||||
def alloc_n(self, n: int) -> List[int]:
|
||||
pages = [self.alloc() for _ in range(n)]
|
||||
if any(p < 0 for p in pages):
|
||||
for p in pages:
|
||||
if p >= 0:
|
||||
self.free(p)
|
||||
return []
|
||||
return pages
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
"""Releases a slot back to the free pool."""
|
||||
self._free_mask |= 1 << idx
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def occupy(self, idx: int) -> None:
|
||||
"""Marks a currently free slot as occupied without bumping its version.
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||
return CacheView(self, page_table, total_len)
|
||||
|
||||
Used for direct slot reuse when a prefix-cached slot is still valid.
|
||||
"""
|
||||
self._free_mask ^= 1 << idx
|
||||
def write(
|
||||
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
return
|
||||
page_size = self.page_size
|
||||
written = 0
|
||||
first_page = start_pos // page_size
|
||||
last_page = (start_pos + seq_len - 1) // page_size
|
||||
for pi in range(first_page, last_page + 1):
|
||||
phys_pages = page_table[:, pi]
|
||||
page_start = pi * page_size
|
||||
write_start = max(page_start, start_pos)
|
||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
:, written : written + chunk
|
||||
]
|
||||
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
||||
:, written : written + chunk
|
||||
]
|
||||
written += chunk
|
||||
|
||||
def is_free(self, idx: int) -> bool:
|
||||
"""Checks whether a slot is currently free."""
|
||||
return (self._free_mask >> idx) & 1 == 1
|
||||
|
||||
def version(self, idx: int) -> int:
|
||||
"""Returns the current version counter for a slot."""
|
||||
return self._versions[idx]
|
||||
|
||||
@property
|
||||
def free_count(self) -> int:
|
||||
"""Returns the number of currently free slots."""
|
||||
return self._free_mask.bit_count()
|
||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
k_parts, v_parts = [], []
|
||||
for pi in range(page_table.size(1)):
|
||||
phys_pages = page_table[:, pi]
|
||||
if not (phys_pages >= 0).any():
|
||||
break
|
||||
k_parts.append(self.k_cache[layer_id, phys_pages])
|
||||
v_parts.append(self.v_cache[layer_id, phys_pages])
|
||||
k = torch.cat(k_parts, dim=1)
|
||||
v = torch.cat(v_parts, dim=1)
|
||||
return k, v
|
||||
|
||||
|
||||
class PrefixCacheManager:
|
||||
"""Radix-tree prefix cache with LRU eviction.
|
||||
class CacheView:
|
||||
"""Per-batch view that bundles PagedCache + page_table + total_len.
|
||||
|
||||
Maps token ID sequences to KV cache slots. Intermediate tree nodes
|
||||
also store slot information, allowing direct slot reuse when the
|
||||
cached slot is free and its version matches (no intervening writes).
|
||||
Attention layers receive this as ``paged_cache`` and only see
|
||||
``write()`` / ``gather()``, never raw page tables or length params.
|
||||
"""
|
||||
|
||||
def __init__(self, max_capacity: int = 1000):
|
||||
"""Initializes the prefix cache.
|
||||
__slots__ = ("_cache", "_page_table", "_total_len")
|
||||
|
||||
Args:
|
||||
max_capacity: Maximum number of nodes in the LRU list.
|
||||
"""
|
||||
self.root = _RadixNode()
|
||||
self.max_capacity = max_capacity
|
||||
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
|
||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||
self._cache = cache
|
||||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
|
||||
"""Inserts a token sequence into the prefix cache.
|
||||
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
Every node along the path records the slot and its version,
|
||||
enabling direct slot reuse for partial prefix matches.
|
||||
|
||||
Args:
|
||||
token_ids: The token ID sequence to cache.
|
||||
slot: The KV cache slot containing this prefix's computed keys/values.
|
||||
slot_ver: The slot version at insertion time, used for staleness detection.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
nxt = _RadixNode()
|
||||
node.children[tid] = nxt
|
||||
node = nxt
|
||||
node.slot = slot
|
||||
node.slot_ver = slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru[id(node)] = node
|
||||
node.ref_count += 1
|
||||
self._evict_if_needed()
|
||||
|
||||
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
|
||||
"""Finds the longest matching prefix in the cache.
|
||||
|
||||
Walks the radix tree token by token, recording the deepest match.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence to match against.
|
||||
|
||||
Returns:
|
||||
Tuple of (prefix_len, slot, slot_ver):
|
||||
prefix_len: Number of matching tokens (0 if no match).
|
||||
slot: KV cache slot of the matched prefix (-1 if no match).
|
||||
slot_ver: Version of that slot when the prefix was inserted.
|
||||
"""
|
||||
node = self.root
|
||||
best_len, best_slot, best_ver = 0, -1, 0
|
||||
for i, tid in enumerate(token_ids):
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
break
|
||||
node = nxt
|
||||
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru.move_to_end(id(node))
|
||||
return best_len, best_slot, best_ver
|
||||
|
||||
def pin(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Increments the reference count of a cached prefix.
|
||||
|
||||
Called when a task reuses a cached prefix to prevent eviction.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to increment.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
node.ref_count += 1
|
||||
|
||||
def release(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Decrements the reference count of a cached prefix.
|
||||
|
||||
The node's slot is preserved even when ref_count reaches zero,
|
||||
allowing future tasks to reuse the slot directly if it remains free.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to decrement.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
if node.ref_count > 0:
|
||||
node.ref_count -= 1
|
||||
|
||||
def copy_kv(
|
||||
self,
|
||||
token_ids: Tuple[int, ...],
|
||||
target_slot: int,
|
||||
kv_cache: Tuple[Tensor, Tensor],
|
||||
n_layers: int,
|
||||
) -> None:
|
||||
"""Copies cached KV data from the source slot to a target slot.
|
||||
|
||||
Args:
|
||||
token_ids: The prefix token sequence identifying the source cache node.
|
||||
target_slot: The destination KV cache slot to copy into.
|
||||
kv_cache: Tuple of (k_cache, v_cache) tensors.
|
||||
n_layers: Number of transformer layers to copy.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
src_slot = node.slot
|
||||
if src_slot < 0:
|
||||
return
|
||||
prefix_len = len(token_ids)
|
||||
k_cache, v_cache = kv_cache
|
||||
for li in range(n_layers):
|
||||
k_cache[target_slot, :prefix_len, li].copy_(
|
||||
k_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
v_cache[target_slot, :prefix_len, li].copy_(
|
||||
v_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evicts least-recently-used nodes until under capacity.
|
||||
|
||||
Skips nodes with ref_count > 0 (still in use by active tasks).
|
||||
Evicted nodes have their slot and children cleared.
|
||||
"""
|
||||
while len(self._lru) > self.max_capacity:
|
||||
key, node = next(iter(self._lru.items()))
|
||||
if node.ref_count > 0:
|
||||
self._lru.move_to_end(key)
|
||||
continue
|
||||
self._lru.pop(key)
|
||||
node.slot = -1
|
||||
node.slot_ver = 0
|
||||
node.children.clear()
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
k, v = self._cache.gather(layer_id, self._page_table)
|
||||
if self._total_len:
|
||||
k = k[:, : self._total_len]
|
||||
v = v[:, : self._total_len]
|
||||
return k, v
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.inference.cache import _STOP
|
||||
from astrai.inference.cache import STOP
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
|
@ -118,15 +118,15 @@ class _Result:
|
|||
"""Appends a token to the result buffer.
|
||||
|
||||
In non-streaming mode, tokens are concatenated into results[idx].
|
||||
The sentinel _STOP marks a task as complete.
|
||||
The sentinel STOP marks a task as complete.
|
||||
|
||||
Args:
|
||||
token: The decoded token string, or _STOP sentinel.
|
||||
token: The decoded token string, or STOP sentinel.
|
||||
idx: Index of the generation task this token belongs to.
|
||||
"""
|
||||
with self._lock:
|
||||
self.tokens.append(token)
|
||||
if token is not _STOP:
|
||||
if token is not STOP:
|
||||
self.results[idx] += token
|
||||
else:
|
||||
if not self._done[idx]:
|
||||
|
|
@ -186,38 +186,28 @@ class InferenceEngine:
|
|||
max_batch_size: int = 1,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 2048,
|
||||
cache_capacity: int = 1000,
|
||||
page_size: int = 128,
|
||||
):
|
||||
"""Initializes the engine and starts the scheduler background thread.
|
||||
"""Initializes the inference engine.
|
||||
|
||||
Args:
|
||||
model: The language model (nn.Module, e.g. Transformer).
|
||||
tokenizer: Tokenizer for encoding/decoding.
|
||||
max_batch_size: Maximum concurrent tasks in the scheduler.
|
||||
max_seq_len: Maximum sequence length (defaults to model config).
|
||||
max_prompt_len: Maximum prompt tokens (longer prompts truncated).
|
||||
cache_capacity: Maximum prefix cache nodes.
|
||||
model: The model instance.
|
||||
tokenizer: The tokenizer instance.
|
||||
max_batch_size: Maximum number of concurrent tasks.
|
||||
max_seq_len: Maximum sequence length.
|
||||
max_prompt_len: Maximum prompt tokens.
|
||||
compile: Whether to compile the model with torch.compile.
|
||||
page_size: Number of tokens per KV cache page.
|
||||
"""
|
||||
try:
|
||||
first_param = next(model.parameters())
|
||||
device = first_param.device
|
||||
dtype = first_param.dtype
|
||||
except StopIteration:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.scheduler = InferenceScheduler(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=max_batch_size,
|
||||
max_seq_len=max_seq_len,
|
||||
max_prompt_len=max_prompt_len,
|
||||
cache_capacity=cache_capacity,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
page_size=page_size,
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
|
|
@ -383,7 +373,7 @@ class InferenceEngine:
|
|||
while True:
|
||||
tokens = result.pop_all()
|
||||
for token in tokens:
|
||||
if token is _STOP:
|
||||
if token is STOP:
|
||||
return
|
||||
yield token
|
||||
if not result.wait(timeout=0.05):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ parameters, so a single pipeline works for any batch size.
|
|||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
|
|||
|
|
@ -1,24 +1,19 @@
|
|||
"""Inference scheduler for single-GPU continuous batching.
|
||||
|
||||
Splits scheduling concerns across modules:
|
||||
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
|
||||
- sampling.py: Strategy-pattern logit transformations
|
||||
"""
|
||||
"""Inference scheduler for single-GPU continuous batching with paged KV cache."""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
|
||||
from astrai.inference.cache import STOP, PagedCache
|
||||
from astrai.inference.sampling import sample
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.tokenize.tokenizer import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -33,29 +28,7 @@ class TaskStatus(Enum):
|
|||
|
||||
|
||||
class Task:
|
||||
"""Represents a single generation request within the scheduler.
|
||||
|
||||
Tracks prompt tokens, generated output, sampling parameters,
|
||||
KV cache slot assignment, and prefix cache matching state.
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
"task_id",
|
||||
"prompt_ids",
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"status",
|
||||
"output_ids",
|
||||
"input_tokens",
|
||||
"output_tokens",
|
||||
"slot",
|
||||
"prefix_len",
|
||||
"arrival_time",
|
||||
"finish_time",
|
||||
"stream_callback",
|
||||
)
|
||||
"""Represents a single generation request with paged KV cache tracking."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -67,17 +40,6 @@ class Task:
|
|||
top_k: int = 50,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
"""Initializes a new task.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for this task.
|
||||
prompt_ids: Tokenized prompt sequence.
|
||||
max_tokens: Maximum number of tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling probability threshold.
|
||||
top_k: Top-k sampling count (0 disables).
|
||||
stream_callback: Optional callback invoked per decoded token.
|
||||
"""
|
||||
self.task_id = task_id
|
||||
self.prompt_ids = prompt_ids
|
||||
self.max_tokens = max_tokens
|
||||
|
|
@ -89,26 +51,17 @@ class Task:
|
|||
self.output_ids: List[int] = []
|
||||
self.input_tokens: int = 0
|
||||
self.output_tokens: int = 0
|
||||
self.slot: int = -1
|
||||
self.prefix_len: int = 0
|
||||
self.page_table: List[int] = []
|
||||
self.n_pages: int = 0
|
||||
self.arrival_time = time.time()
|
||||
self.finish_time: Optional[float] = None
|
||||
self.stream_callback = stream_callback
|
||||
|
||||
@property
|
||||
def next_pos(self) -> int:
|
||||
"""Returns the next KV cache position to write during decode."""
|
||||
return self.input_tokens + len(self.output_ids)
|
||||
|
||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||
"""Checks whether the task has reached a stopping condition.
|
||||
|
||||
Args:
|
||||
stop_ids: List of stop token IDs (e.g., EOS).
|
||||
|
||||
Returns:
|
||||
True if max_tokens reached or the last output token is a stop ID.
|
||||
"""
|
||||
if self.output_tokens >= self.max_tokens:
|
||||
return True
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
|
|
@ -117,16 +70,13 @@ class Task:
|
|||
|
||||
|
||||
class InferenceScheduler:
|
||||
"""Continuous batching scheduler for single-GPU inference.
|
||||
"""Continuous batching scheduler with paged KV cache.
|
||||
|
||||
Runs a background generation loop with four phases per iteration:
|
||||
1. Cleanup finished tasks and release resources.
|
||||
2. Refill active batch from the waiting queue.
|
||||
3. Prefill newly activated tasks (full, partial, or fully cached).
|
||||
3. Prefill newly activated tasks.
|
||||
4. Decode the largest same-position group of active tasks.
|
||||
|
||||
Tasks at different positions are never batched together in decode,
|
||||
avoiding RoPE corruption from misaligned KV cache writes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -136,22 +86,10 @@ class InferenceScheduler:
|
|||
max_batch_size: int = 16,
|
||||
max_seq_len: Optional[int] = None,
|
||||
max_prompt_len: int = 512,
|
||||
cache_capacity: int = 1000,
|
||||
page_size: int = 64,
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
):
|
||||
"""Initializes the scheduler and pre-allocates the KV cache.
|
||||
|
||||
Args:
|
||||
model: The language model (must have config with n_layers, n_kv_heads, etc.).
|
||||
tokenizer: Tokenizer for encoding prompts and decoding outputs.
|
||||
max_batch_size: Maximum number of concurrent tasks.
|
||||
max_seq_len: Maximum sequence length (defaults to config.max_len).
|
||||
max_prompt_len: Maximum prompt tokens (longer prompts are truncated).
|
||||
cache_capacity: Maximum prefix cache node count.
|
||||
device: Target device for tensors.
|
||||
dtype: Data type for KV cache tensors.
|
||||
"""
|
||||
config = model.config
|
||||
|
||||
self.model = model
|
||||
|
|
@ -159,35 +97,25 @@ class InferenceScheduler:
|
|||
self.max_batch_size = max_batch_size
|
||||
self.max_seq_len = max_seq_len or config.max_len
|
||||
self.max_prompt_len = max_prompt_len
|
||||
self.page_size = page_size
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity)
|
||||
|
||||
n_kv_heads = config.n_kv_heads
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_layers = config.n_layers
|
||||
self._n_layers = n_layers
|
||||
n_pages = (max_batch_size * self.max_seq_len + page_size - 1) // page_size
|
||||
|
||||
k_cache = torch.empty(
|
||||
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
v_cache = torch.empty(
|
||||
(max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim),
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
self.kv_cache = (k_cache, v_cache)
|
||||
|
||||
self.seq_mask = torch.zeros(
|
||||
(max_batch_size, self.max_seq_len),
|
||||
device=self.device,
|
||||
dtype=torch.bool,
|
||||
self.page_cache = PagedCache(
|
||||
n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
self.device,
|
||||
self.dtype,
|
||||
)
|
||||
|
||||
self.slot_allocator = SlotAllocator(max_batch_size)
|
||||
self.waiting_queue: List[Task] = []
|
||||
self.active_tasks: List[Task] = []
|
||||
|
||||
|
|
@ -198,41 +126,8 @@ class InferenceScheduler:
|
|||
self._total_tasks = 0
|
||||
self._total_tokens = 0
|
||||
|
||||
def _alloc_slot(self) -> int:
|
||||
"""Allocates a free KV cache slot using the Object Pool.
|
||||
|
||||
Returns:
|
||||
Slot index on success, -1 if all slots are occupied.
|
||||
"""
|
||||
return self.slot_allocator.alloc()
|
||||
|
||||
def _free_slot(self, idx: int) -> None:
|
||||
"""Releases a KV cache slot back to the free pool.
|
||||
|
||||
Args:
|
||||
idx: Slot index to free.
|
||||
"""
|
||||
self.slot_allocator.free(idx)
|
||||
self.seq_mask[idx, :] = False
|
||||
|
||||
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
|
||||
"""Attempts to reuse a prefix-cached slot directly without KV copy.
|
||||
|
||||
The slot is reusable only if it is free and its version matches
|
||||
the current slot version (no intervening allocation overwrote it).
|
||||
|
||||
Args:
|
||||
prefix: The matched prefix token sequence.
|
||||
|
||||
Returns:
|
||||
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
|
||||
"""
|
||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
|
||||
if cached_ver == self.slot_allocator.version(cached_slot):
|
||||
self.slot_allocator.occupy(cached_slot)
|
||||
return cached_slot, True
|
||||
return -1, False
|
||||
def _n_pages_for(self, n_tokens: int) -> int:
|
||||
return (n_tokens + self.page_size - 1) // self.page_size
|
||||
|
||||
def add_task(
|
||||
self,
|
||||
|
|
@ -243,25 +138,8 @@ class InferenceScheduler:
|
|||
top_k: int = 50,
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> str:
|
||||
"""Adds a generation task to the waiting queue.
|
||||
|
||||
Encodes the prompt, queries the prefix cache for a match,
|
||||
and enqueues the task for the background generation loop.
|
||||
|
||||
Args:
|
||||
prompt: Input text to generate from.
|
||||
max_tokens: Maximum tokens to generate.
|
||||
temperature: Sampling temperature.
|
||||
top_p: Nucleus sampling threshold.
|
||||
top_k: Top-k sampling count.
|
||||
stream_callback: Called per decoded token with the string representation.
|
||||
|
||||
Returns:
|
||||
Unique task ID string.
|
||||
"""
|
||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
prompt_ids = self.tokenizer.encode(prompt)
|
||||
|
||||
if len(prompt_ids) > self.max_prompt_len:
|
||||
prompt_ids = prompt_ids[-self.max_prompt_len :]
|
||||
|
||||
|
|
@ -275,9 +153,6 @@ class InferenceScheduler:
|
|||
stream_callback=stream_callback,
|
||||
)
|
||||
|
||||
prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(prompt_ids)
|
||||
task.prefix_len = prefix_len
|
||||
|
||||
with self._lock:
|
||||
self.waiting_queue.append(task)
|
||||
self._total_tasks += 1
|
||||
|
|
@ -286,32 +161,21 @@ class InferenceScheduler:
|
|||
return task_id
|
||||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
"""Removes a task from both the waiting queue and active tasks.
|
||||
|
||||
Args:
|
||||
task_id: The task to remove.
|
||||
"""
|
||||
with self._lock:
|
||||
removed_active = [t for t in self.active_tasks if t.task_id == task_id]
|
||||
self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id]
|
||||
self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id]
|
||||
|
||||
for task in removed_active:
|
||||
if task.prefix_len > 0:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
self.prefix_cache.release(prefix)
|
||||
if task.prefix_len < len(task.prompt_ids):
|
||||
self.prefix_cache.release(tuple(task.prompt_ids))
|
||||
if task.slot >= 0:
|
||||
self._free_slot(task.slot)
|
||||
task.slot = -1
|
||||
self._free_pages(task.page_table)
|
||||
task.page_table.clear()
|
||||
task.n_pages = 0
|
||||
|
||||
def _free_pages(self, indices: List[int]) -> None:
|
||||
for idx in indices:
|
||||
self.page_cache.free(idx)
|
||||
|
||||
def _remove_finished_tasks(self) -> None:
|
||||
"""Removes all finished tasks from the active batch.
|
||||
|
||||
Releases prefix cache references and frees the KV cache slot
|
||||
for each completed task.
|
||||
"""
|
||||
finished = []
|
||||
for task in self.active_tasks:
|
||||
if task.is_finished(self.tokenizer.stop_ids):
|
||||
|
|
@ -321,25 +185,15 @@ class InferenceScheduler:
|
|||
self._total_tokens += task.output_tokens
|
||||
|
||||
for task in finished:
|
||||
if task.prefix_len > 0:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
self.prefix_cache.release(prefix)
|
||||
if task.prefix_len < len(task.prompt_ids):
|
||||
self.prefix_cache.release(tuple(task.prompt_ids))
|
||||
if task.slot >= 0:
|
||||
self._free_slot(task.slot)
|
||||
task.slot = -1
|
||||
self._free_pages(task.page_table)
|
||||
task.page_table.clear()
|
||||
task.n_pages = 0
|
||||
|
||||
self.active_tasks = [
|
||||
t for t in self.active_tasks if t.status != TaskStatus.FINISHED
|
||||
]
|
||||
|
||||
def _refill_active_batch(self) -> None:
|
||||
"""Moves waiting tasks into the active batch, up to max_batch_size.
|
||||
|
||||
Attempts direct slot reuse for prefix-matched tasks; falls back
|
||||
to allocating a fresh slot with KV cache copy when reuse is not possible.
|
||||
"""
|
||||
available = self.max_batch_size - len(self.active_tasks)
|
||||
if available <= 0:
|
||||
return
|
||||
|
|
@ -350,122 +204,71 @@ class InferenceScheduler:
|
|||
for _ in range(n):
|
||||
to_add.append(self.waiting_queue.pop(0))
|
||||
|
||||
for i, task in enumerate(to_add):
|
||||
slot = -1
|
||||
reused = False
|
||||
if task.prefix_len > 0:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
cached_slot, reused = self._try_reuse_slot(prefix)
|
||||
if reused:
|
||||
slot = cached_slot
|
||||
if slot < 0:
|
||||
slot = self._alloc_slot()
|
||||
if slot < 0:
|
||||
with self._lock:
|
||||
self.waiting_queue[:0] = to_add[i:]
|
||||
break
|
||||
task.slot = slot
|
||||
for task in to_add:
|
||||
prompt_len = len(task.prompt_ids)
|
||||
n_pages = self._n_pages_for(prompt_len)
|
||||
task.page_table = self.page_cache.alloc_n(n_pages)
|
||||
if not task.page_table:
|
||||
with self._lock:
|
||||
self.waiting_queue.insert(0, task)
|
||||
break
|
||||
task.n_pages = len(task.page_table)
|
||||
task.status = TaskStatus.RUNNING
|
||||
self.active_tasks.append(task)
|
||||
|
||||
if task.prefix_len > 0 and not reused:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
|
||||
cached_slot
|
||||
):
|
||||
self.prefix_cache.pin(prefix)
|
||||
self.prefix_cache.copy_kv(
|
||||
prefix, slot, self.kv_cache, self._n_layers
|
||||
)
|
||||
else:
|
||||
task.prefix_len = 0
|
||||
|
||||
def _execute_prefill(self, tasks: List[Task]) -> None:
|
||||
"""Runs batched prefill for newly activated tasks.
|
||||
|
||||
Fully-cached tasks skip the model. Others are grouped by prefix_len
|
||||
so tasks sharing the same start_pos are batched together.
|
||||
"""
|
||||
if not tasks:
|
||||
def _execute_prefill(self) -> None:
|
||||
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
||||
if not to_prefill:
|
||||
return
|
||||
|
||||
for t in to_prefill:
|
||||
prompt_len = len(t.prompt_ids)
|
||||
t.input_tokens = prompt_len
|
||||
t.output_tokens = 0
|
||||
|
||||
groups: Dict[int, List[Task]] = {}
|
||||
for t in tasks:
|
||||
plen = len(t.prompt_ids)
|
||||
if t.prefix_len == plen:
|
||||
t.input_tokens = plen
|
||||
t.output_tokens = 0
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
else:
|
||||
groups.setdefault(t.prefix_len, []).append(t)
|
||||
for t in to_prefill:
|
||||
groups.setdefault(len(t.prompt_ids), []).append(t)
|
||||
|
||||
for prefix_len, group in groups.items():
|
||||
slot_indices = torch.tensor([t.slot for t in group], device=self.device)
|
||||
self._execute_prefill_batch(group, prefix_len, slot_indices)
|
||||
for prompt_len, group in groups.items():
|
||||
self._execute_prefill_batch(group, prompt_len)
|
||||
|
||||
def _execute_prefill_batch(
|
||||
self, tasks: List[Task], prefix_len: int, slot_indices: Tensor
|
||||
) -> None:
|
||||
"""Unified prefill for tasks sharing a common prefix_len.
|
||||
|
||||
Args:
|
||||
tasks: Tasks with the same prefix_len < len(prompt_ids).
|
||||
prefix_len: Number of cached prefix tokens (0 for full prefill).
|
||||
slot_indices: Tensor of slot indices for KV cache mapping.
|
||||
"""
|
||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
|
||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
|
||||
new_lens = [len(t.prompt_ids) - prefix_len for t in tasks]
|
||||
max_new_len = max(new_lens)
|
||||
|
||||
input_ids = torch.zeros(
|
||||
batch_sz, max_new_len, dtype=torch.long, device=self.device
|
||||
batch_sz,
|
||||
prompt_len,
|
||||
dtype=torch.long,
|
||||
device=self.device,
|
||||
)
|
||||
input_mask = torch.zeros(
|
||||
batch_sz, prefix_len + max_new_len, dtype=torch.bool, device=self.device
|
||||
input_mask = torch.ones(
|
||||
batch_sz,
|
||||
prompt_len,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
new_ids = t.prompt_ids[prefix_len:]
|
||||
nl = len(new_ids)
|
||||
if nl > 0:
|
||||
input_ids[i, :nl] = torch.tensor(new_ids, device=self.device)
|
||||
input_mask[i, : prefix_len + nl] = True
|
||||
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
|
||||
|
||||
page_tables = self._make_page_table_tensor(tasks)
|
||||
|
||||
with torch.inference_mode():
|
||||
self.model(
|
||||
input_ids,
|
||||
input_mask=input_mask,
|
||||
start_pos=prefix_len,
|
||||
persistent_key_values=self.kv_cache,
|
||||
slot_indices=slot_indices,
|
||||
start_pos=0,
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
for i, t in enumerate(tasks):
|
||||
t.input_tokens = len(t.prompt_ids)
|
||||
t.output_tokens = 0
|
||||
self.prefix_cache.insert(
|
||||
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
|
||||
)
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
|
||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||
"""Executes the decode phase for a group of tasks at the same position.
|
||||
|
||||
Args:
|
||||
tasks: Tasks sharing the same next_pos value.
|
||||
start_pos: Common KV cache write position for the batch.
|
||||
"""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
tasks = sorted(tasks, key=lambda t: t.slot)
|
||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||
batch_sz = len(tasks)
|
||||
slot_indices = torch.tensor([t.slot for t in tasks], device=self.device)
|
||||
|
||||
input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device)
|
||||
for i, t in enumerate(tasks):
|
||||
|
|
@ -473,13 +276,15 @@ class InferenceScheduler:
|
|||
|
||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||
|
||||
page_tables = self._make_page_table_tensor(tasks)
|
||||
total_len = start_pos + 1
|
||||
|
||||
with torch.inference_mode():
|
||||
outputs = self.model(
|
||||
input_ids.unsqueeze(1),
|
||||
input_mask=active_mask,
|
||||
persistent_key_values=self.kv_cache,
|
||||
paged_cache=self.page_cache.bind(page_tables, total_len=total_len),
|
||||
start_pos=start_pos,
|
||||
slot_indices=slot_indices,
|
||||
)
|
||||
logits = outputs["logits"][:, -1, :]
|
||||
|
||||
|
|
@ -496,23 +301,30 @@ class InferenceScheduler:
|
|||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
if t.slot >= 0 and pos < self.max_seq_len:
|
||||
self.seq_mask[t.slot, pos] = True
|
||||
self._maybe_alloc_page(t, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||
|
||||
for t in tasks:
|
||||
if t.is_finished(self.tokenizer.stop_ids):
|
||||
if t.stream_callback:
|
||||
t.stream_callback(_STOP)
|
||||
t.stream_callback(STOP)
|
||||
|
||||
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
|
||||
max_pages = max(t.n_pages for t in tasks)
|
||||
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
||||
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
||||
|
||||
def _maybe_alloc_page(self, task: Task, pos: int) -> None:
|
||||
needed = self._n_pages_for(pos + 1)
|
||||
while task.n_pages < needed:
|
||||
p = self.page_cache.alloc()
|
||||
if p < 0:
|
||||
break
|
||||
task.page_table.append(p)
|
||||
task.n_pages += 1
|
||||
|
||||
def _run_generation_loop(self) -> None:
|
||||
"""Main generation loop run in a daemon thread.
|
||||
|
||||
Continuously cycles through cleanup, refill, prefill, and decode.
|
||||
Decode processes only the largest position group to ensure all
|
||||
batched tasks share the same KV cache write position.
|
||||
"""
|
||||
try:
|
||||
while self._running:
|
||||
self._remove_finished_tasks()
|
||||
|
|
@ -523,11 +335,8 @@ class InferenceScheduler:
|
|||
self._task_event.clear()
|
||||
self._task_event.wait(timeout=0.01)
|
||||
continue
|
||||
tasks = self.active_tasks[:]
|
||||
|
||||
to_prefill = [t for t in tasks if t.output_tokens == 0]
|
||||
if to_prefill:
|
||||
self._execute_prefill(to_prefill)
|
||||
self._execute_prefill()
|
||||
|
||||
pos_groups: Dict[int, List[Task]] = {}
|
||||
for t in self.active_tasks:
|
||||
|
|
@ -544,21 +353,20 @@ class InferenceScheduler:
|
|||
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
|
||||
for task in self.active_tasks:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(_STOP)
|
||||
task.stream_callback(STOP)
|
||||
for task in self.waiting_queue:
|
||||
if task.stream_callback:
|
||||
task.stream_callback(_STOP)
|
||||
task.stream_callback(STOP)
|
||||
raise
|
||||
|
||||
def start(self) -> None:
|
||||
"""Starts the background generation loop thread."""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
t = threading.Thread(target=self._run_generation_loop, daemon=True)
|
||||
t.start()
|
||||
self._loop_thread = t
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stops the generation loop and releases all resources."""
|
||||
self._running = False
|
||||
self._task_event.set()
|
||||
if hasattr(self, "_loop_thread"):
|
||||
|
|
@ -569,11 +377,6 @@ class InferenceScheduler:
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Returns current scheduler statistics.
|
||||
|
||||
Returns:
|
||||
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
|
||||
"""
|
||||
return {
|
||||
"total_tasks": self._total_tasks,
|
||||
"total_tokens": self._total_tokens,
|
||||
|
|
|
|||
|
|
@ -5,17 +5,11 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.cache import CacheView
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
"""
|
||||
Repeat k times along the dimension for attention heads.
|
||||
Args:
|
||||
x (Tensor): The input tensor.
|
||||
n_rep (int): The number of repetitions.
|
||||
Returns:
|
||||
Tensor: The repeated tensor.
|
||||
"""
|
||||
|
||||
"""Repeat KV heads n_rep times for GQA."""
|
||||
bs, slen, n_heads, head_dim = x.shape
|
||||
if n_rep == 1:
|
||||
return x
|
||||
|
|
@ -32,49 +26,25 @@ def get_rotary_emb(
|
|||
base: float = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get the rotary embedding for the given dimension and maximum length.
|
||||
Args:
|
||||
dim (int): The dimension of the input.
|
||||
max_len (int): The maximum length of the input.
|
||||
base (float, optional): The base for the frequency. Defaults to 10000.
|
||||
device (optional): The device to create tensors on. Defaults to None.
|
||||
Returns:
|
||||
Tensor: The rotary embedding tensor.
|
||||
"""
|
||||
|
||||
"""Precompute cos/sin for RoPE."""
|
||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||
freqs = torch.outer(t, theta)
|
||||
|
||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
||||
|
||||
|
||||
def apply_rotary_emb(x: torch.Tensor, rotary_emb: Tuple[Tensor, Tensor]) -> Tensor:
|
||||
"""
|
||||
Apply rotary embedding to the input tensor using cos/sin form.
|
||||
Args:
|
||||
x (Tensor): The input tensor (shape [..., seq_len, dim]).
|
||||
rotary_emb (Tuple[Tensor, Tensor]): The rotary embedding (shape [seq_len, dim//2]).
|
||||
Returns:
|
||||
Tensor: The output tensor (rotated, same shape as input).
|
||||
"""
|
||||
|
||||
"""Apply rotary embedding via cos/sin (shape-preserving)."""
|
||||
dtype = x.dtype
|
||||
cos, sin = rotary_emb
|
||||
|
||||
cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||
sin = sin.unsqueeze(0).unsqueeze(2) # [1, seq_len, 1, dim//2]
|
||||
|
||||
x_real = x[..., 0::2] # [batch, seq_len, dim//2]
|
||||
x_imag = x[..., 1::2] # [batch, seq_len, dim//2]
|
||||
|
||||
cos = cos.unsqueeze(0).unsqueeze(2)
|
||||
sin = sin.unsqueeze(0).unsqueeze(2)
|
||||
x_real = x[..., 0::2]
|
||||
x_imag = x[..., 1::2]
|
||||
x_real_rot = x_real * cos - x_imag * sin
|
||||
x_imag_rot = x_real * sin + x_imag * cos
|
||||
|
||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1) # [batch, seq_len, dim//2, 2]
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1) # [batch, seq_len, dim]
|
||||
|
||||
x_out = torch.stack([x_real_rot, x_imag_rot], dim=-1)
|
||||
x_out = x_out.view(*x_out.shape[:-2], -1)
|
||||
return x_out.to(dtype)
|
||||
|
||||
|
||||
|
|
@ -95,13 +65,10 @@ class RotaryEmbedding(nn.Module):
|
|||
|
||||
def forward(self, x: Tensor, start_pos: int = 0) -> Tuple[Tensor, Tensor]:
|
||||
seq_len = x.size(1)
|
||||
|
||||
if self.max_len_cached < seq_len + start_pos:
|
||||
self._set_rotary_buffer(self.max_len_cached * 2, x.device)
|
||||
|
||||
cos = self.cos_cached[start_pos : start_pos + seq_len]
|
||||
sin = self.sin_cached[start_pos : start_pos + seq_len]
|
||||
|
||||
return (cos, sin)
|
||||
|
||||
|
||||
|
|
@ -185,14 +152,13 @@ class GQA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
slot_indices: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
|
||||
# x(bsz, seq_len, n_heads * head_dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
# (bsz, seq_len, dim) -> (bsz, seq_len, n_heads, head_dim)
|
||||
q = self._split_heads(self.q_proj(x), self.n_heads)
|
||||
k = self._split_heads(self.k_proj(x), self.n_kv_heads)
|
||||
v = self._split_heads(self.v_proj(x), self.n_kv_heads)
|
||||
|
|
@ -201,18 +167,14 @@ class GQA(nn.Module):
|
|||
if self.use_qk_norm:
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
k, v = repeat_kv(k, self.n_rep), repeat_kv(v, self.n_rep)
|
||||
|
||||
# (bsz, seq_len, n_heads, head_dim) -> (bsz, n_heads, seq_len, head_dim)
|
||||
q, k, v = q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
|
||||
# (bsz, n_heads, seq_len, head_dim) - > (bsz, seq_len, n_heads*head_dim)
|
||||
sdqa_out = (
|
||||
F.scaled_dot_product_attention(q, k, v, mask, is_causal=is_causal)
|
||||
.permute(0, 2, 1, 3)
|
||||
|
|
@ -224,7 +186,6 @@ class GQA(nn.Module):
|
|||
sdqa_out = sdqa_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(sdqa_out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -257,7 +218,7 @@ class MLA(nn.Module):
|
|||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||
|
||||
# KV (k_nope, k_rope, v)
|
||||
# fused KV: (k_nope, k_rope, v)
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
|
|
@ -273,9 +234,8 @@ class MLA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
mask: Tensor = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
slot_indices: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = mask is None
|
||||
|
|
@ -303,12 +263,9 @@ class MLA(nn.Module):
|
|||
q = torch.cat([q_nope, q_rope], dim=-1)
|
||||
k = torch.cat([k_nope, k_rope], dim=-1)
|
||||
|
||||
if kv_cache is not None:
|
||||
k_cache, v_cache = kv_cache
|
||||
k_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = k
|
||||
v_cache[slot_indices, start_pos : start_pos + seq_len, self.layer_id] = v
|
||||
k = k_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||
v = v_cache[slot_indices, : start_pos + seq_len, self.layer_id]
|
||||
if paged_cache is not None:
|
||||
paged_cache.write(self.layer_id, start_pos, k, v)
|
||||
k, v = paged_cache.gather(self.layer_id)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
|
|
@ -321,7 +278,6 @@ class MLA(nn.Module):
|
|||
attn_out = attn_out * F.sigmoid(self.gate(x))
|
||||
|
||||
out = self.o_proj(attn_out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
|
@ -356,24 +312,19 @@ class DecoderBlock(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tuple[Tensor, Tensor],
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
slot_indices: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
# attention
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
rotary_emb,
|
||||
attention_mask,
|
||||
kv_cache,
|
||||
paged_cache,
|
||||
start_pos,
|
||||
slot_indices,
|
||||
)
|
||||
x = attn_output + x
|
||||
|
||||
# feed forward
|
||||
x = self.mlp(self.post_attention_norm(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from typing import Any, Mapping, Optional, Tuple
|
||||
from typing import Any, Mapping, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.inference.cache import CacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
DecoderBlock,
|
||||
|
|
@ -21,39 +22,25 @@ def process_attention_mask(
|
|||
start_pos: int = 0,
|
||||
is_causal: bool = False,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Create attention mask for GQA
|
||||
Args:
|
||||
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
||||
input_tensor (Tensor): The input tensor.
|
||||
start_pos (int): The starting position of the sequence.
|
||||
is_causal (bool): Whether the attention is causal or not.
|
||||
Returns:
|
||||
Tensor: The attention mask tensor.
|
||||
"""
|
||||
"""Build 4D attention mask from 2D seq_mask, with optional causal masking."""
|
||||
device = input_tensor.device
|
||||
dtype = input_tensor.dtype
|
||||
seq_len = input_tensor.size(1)
|
||||
|
||||
if seq_mask is None:
|
||||
if start_pos != 0:
|
||||
# for single prompt chat
|
||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||
else:
|
||||
return None
|
||||
|
||||
if seq_mask.dim() > 2:
|
||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||
# if ndim > 2, it's 4D tensor
|
||||
return seq_mask
|
||||
|
||||
batch_size = seq_mask.size(0)
|
||||
seq_mask = seq_mask[:, : start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||
# (bsz, start_pos + seq_len)
|
||||
expanded_mask = seq_mask.unsqueeze(1).expand(
|
||||
batch_size, seq_len, start_pos + seq_len
|
||||
)
|
||||
# (bsz, seq_len, start_pos + seq_len)
|
||||
|
||||
if is_causal:
|
||||
expanded_mask = torch.tril(expanded_mask, diagonal=start_pos)
|
||||
|
|
@ -62,16 +49,13 @@ def process_attention_mask(
|
|||
attention_mask = attention_mask.masked_fill_(
|
||||
~expanded_mask, -torch.finfo(dtype).max / 2
|
||||
).unsqueeze(1)
|
||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
@AutoModel.register("transformer")
|
||||
class Transformer(AutoModel):
|
||||
"""
|
||||
Transformer language model.
|
||||
"""
|
||||
"""Transformer language model with paged KV cache."""
|
||||
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__(config)
|
||||
|
|
@ -114,18 +98,15 @@ class Transformer(AutoModel):
|
|||
lm_head_key = "lm_head.weight"
|
||||
embed_key = "embed_tokens.weight"
|
||||
|
||||
# Make a copy to avoid modifying the original state_dict
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if self.config.tie_weight:
|
||||
# same tensor
|
||||
# same tensor for embed and lm_head
|
||||
if embed_key in state_dict:
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
else:
|
||||
# If lm_head.weight exists in checkpoint, use it directly
|
||||
# If not, copy from embed_tokens.weight
|
||||
if lm_head_key not in state_dict and embed_key in state_dict:
|
||||
# use clone to avoid sharing the same tensor
|
||||
# clone to avoid sharing gradients
|
||||
state_dict[lm_head_key] = torch.clone(state_dict[embed_key])
|
||||
|
||||
return super().load_state_dict(state_dict, strict, assign)
|
||||
|
|
@ -146,9 +127,8 @@ class Transformer(AutoModel):
|
|||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
persistent_key_values: Optional[Tuple[Tensor, Tensor]] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
start_pos: int = 0,
|
||||
slot_indices: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
||||
|
|
@ -157,13 +137,8 @@ class Transformer(AutoModel):
|
|||
|
||||
attn_mask = process_attention_mask(input_mask, x, start_pos, is_causal=True)
|
||||
|
||||
if slot_indices is None:
|
||||
slot_indices = slice(input_ids.size(0))
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(
|
||||
x, rotary_emb, attn_mask, persistent_key_values, start_pos, slot_indices
|
||||
)
|
||||
x = layer(x, rotary_emb, attn_mask, paged_cache, start_pos)
|
||||
|
||||
hidden_states = self.norm(x)
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
|
|
|||
|
|
@ -6,102 +6,7 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.cache import PrefixCacheManager
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
)
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_find():
|
||||
"""Test concurrent insert and find operations."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
results = {"errors": [], "inserts": 0, "finds": 0}
|
||||
|
||||
def insert_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.insert((i,), slot=i % 10, slot_ver=0)
|
||||
results["inserts"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
def find_worker():
|
||||
try:
|
||||
for i in range(50):
|
||||
cache.find([i])
|
||||
results["finds"] += 1
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=insert_worker) for _ in range(3)]
|
||||
threads += [threading.Thread(target=find_worker) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
assert results["inserts"] == 150
|
||||
assert results["finds"] == 150
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_release():
|
||||
"""Test concurrent release operations."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
# Insert some prefixes
|
||||
for i in range(10):
|
||||
cache.insert((i,), slot=i, slot_ver=0)
|
||||
|
||||
results = {"errors": []}
|
||||
|
||||
def release_worker():
|
||||
try:
|
||||
for i in range(10):
|
||||
cache.release((i,))
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=release_worker) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
|
||||
|
||||
def test_prefix_cache_concurrent_insert_release_find():
|
||||
"""Test mixed concurrent operations."""
|
||||
cache = PrefixCacheManager(max_capacity=50)
|
||||
|
||||
results = {"errors": []}
|
||||
|
||||
def worker(worker_id):
|
||||
try:
|
||||
for i in range(20):
|
||||
token_ids = (worker_id * 100 + i,)
|
||||
cache.insert(token_ids, slot=worker_id, slot_ver=0)
|
||||
|
||||
# Find after insert
|
||||
cache.find(list(token_ids))
|
||||
|
||||
# Release
|
||||
cache.release(token_ids)
|
||||
except Exception as e:
|
||||
results["errors"].append(f"Worker {worker_id}: {str(e)}")
|
||||
|
||||
threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -266,54 +171,3 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
|||
for stats in results["stats"]:
|
||||
assert "total_tasks" in stats
|
||||
assert stats["total_tasks"] >= 0
|
||||
|
||||
|
||||
def test_prefix_cache_insert_same_prefix_concurrently():
|
||||
"""Test inserting the same prefix concurrently."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
results = {"slot_values": [], "errors": []}
|
||||
|
||||
def insert_worker():
|
||||
try:
|
||||
# All workers try to insert the same prefix
|
||||
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||
node = cache.root.children.get(1)
|
||||
if node:
|
||||
node = node.children.get(2)
|
||||
if node:
|
||||
node = node.children.get(3)
|
||||
if node:
|
||||
results["slot_values"].append(node.slot)
|
||||
except Exception as e:
|
||||
results["errors"].append(str(e))
|
||||
|
||||
threads = [threading.Thread(target=insert_worker) for _ in range(10)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# All inserts should succeed, final slot should be one of the values
|
||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||
# Check ref_count is correct (should be 10)
|
||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
||||
assert node.ref_count == 10, f"Expected ref_count=10, got {node.ref_count}"
|
||||
|
||||
|
||||
def test_prefix_cache_ref_count_underflow_prevention():
|
||||
"""Test that ref_count doesn't go negative."""
|
||||
cache = PrefixCacheManager(max_capacity=100)
|
||||
|
||||
cache.insert((1, 2, 3), slot=0, slot_ver=0)
|
||||
|
||||
# Release multiple times
|
||||
for _ in range(5):
|
||||
cache.release((1, 2, 3))
|
||||
|
||||
# Try to find it - should return None since ref_count would be negative
|
||||
# or handle it gracefully
|
||||
node = cache.root.children.get(1).children.get(2).children.get(3)
|
||||
# The ref_count should be 0, not negative
|
||||
assert node.ref_count >= 0, f"ref_count went negative: {node.ref_count}"
|
||||
|
|
|
|||
Loading…
Reference in New Issue