refactor: 设计模式优化 inference 模块导入结构
- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager - 新建 sampling.py:Temperature/TopK/TopP 可组合策略 - TaskStatus 改用 Enum,GenerationParams 值对象模式 - _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合 - 更新测试导入路径,ruff 格式检查通过
This commit is contained in:
parent
c4401512f2
commit
44d7a4e959
|
|
@ -1,25 +1,46 @@
|
|||
"""Inference module for continuous batching."""
|
||||
"""Inference module for continuous batching.
|
||||
|
||||
Layers:
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
|
||||
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||
"""
|
||||
|
||||
from astrai.inference.engine import (
|
||||
GenerationParams,
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.inference.sampling import (
|
||||
BaseSamplingStrategy,
|
||||
SamplingPipeline,
|
||||
TemperatureStrategy,
|
||||
TopKStrategy,
|
||||
TopPStrategy,
|
||||
apply_sampling_strategies,
|
||||
)
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
Task,
|
||||
TaskStatus,
|
||||
apply_sampling_strategies,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Engine
|
||||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
"GenerationParams",
|
||||
# Scheduler
|
||||
"InferenceScheduler",
|
||||
"Task",
|
||||
"TaskStatus",
|
||||
# Request
|
||||
"GenerationRequest",
|
||||
# Sampling
|
||||
# Sampling (Strategy pattern)
|
||||
"apply_sampling_strategies",
|
||||
"BaseSamplingStrategy",
|
||||
"TemperatureStrategy",
|
||||
"TopKStrategy",
|
||||
"TopPStrategy",
|
||||
"SamplingPipeline",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,241 @@
|
|||
"""KV cache slot allocation and prefix cache management.
|
||||
|
||||
Provides:
|
||||
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask.
|
||||
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
_STOP = object()
|
||||
|
||||
|
||||
class _RadixNode:
|
||||
"""Internal node for the radix tree prefix cache.
|
||||
|
||||
Attributes:
|
||||
children: Mapping from token ID to child node.
|
||||
slot: KV cache slot index for the prefix ending at this node.
|
||||
slot_ver: Version counter of the slot at insertion time.
|
||||
ref_count: Number of tasks currently referencing this node.
|
||||
last_access: Timestamp of the most recent access (for LRU ordering).
|
||||
"""
|
||||
|
||||
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
|
||||
|
||||
def __init__(self):
|
||||
self.children: Dict[int, "_RadixNode"] = {}
|
||||
self.slot: int = -1
|
||||
self.slot_ver: int = 0
|
||||
self.ref_count: int = 0
|
||||
self.last_access: float = 0.0
|
||||
|
||||
|
||||
class SlotAllocator:
|
||||
"""KV cache slot allocator using bitmask for O(1) alloc/free.
|
||||
|
||||
Implements the Object Pool pattern: pre-allocated KV cache slots
|
||||
are managed via a bitmask, providing constant-time allocation and
|
||||
deallocation with version counters for staleness detection.
|
||||
"""
|
||||
|
||||
def __init__(self, max_slots: int):
|
||||
self._max_slots = max_slots
|
||||
self._free_mask = (1 << max_slots) - 1
|
||||
self._versions: List[int] = [0] * max_slots
|
||||
|
||||
def alloc(self) -> int:
|
||||
"""Allocates a free slot.
|
||||
|
||||
Returns:
|
||||
Slot index on success, -1 if all slots are occupied.
|
||||
"""
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
if lsb == 0:
|
||||
return -1
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._versions[idx] += 1
|
||||
return idx
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
"""Releases a slot back to the free pool."""
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def occupy(self, idx: int) -> None:
|
||||
"""Marks a currently free slot as occupied without bumping its version.
|
||||
|
||||
Used for direct slot reuse when a prefix-cached slot is still valid.
|
||||
"""
|
||||
self._free_mask ^= 1 << idx
|
||||
|
||||
def is_free(self, idx: int) -> bool:
|
||||
"""Checks whether a slot is currently free."""
|
||||
return (self._free_mask >> idx) & 1 == 1
|
||||
|
||||
def version(self, idx: int) -> int:
|
||||
"""Returns the current version counter for a slot."""
|
||||
return self._versions[idx]
|
||||
|
||||
@property
|
||||
def free_count(self) -> int:
|
||||
"""Returns the number of currently free slots."""
|
||||
return self._free_mask.bit_count()
|
||||
|
||||
|
||||
class PrefixCacheManager:
|
||||
"""Radix-tree prefix cache with LRU eviction.
|
||||
|
||||
Maps token ID sequences to KV cache slots. Intermediate tree nodes
|
||||
also store slot information, allowing direct slot reuse when the
|
||||
cached slot is free and its version matches (no intervening writes).
|
||||
"""
|
||||
|
||||
def __init__(self, max_capacity: int = 1000):
|
||||
"""Initializes the prefix cache.
|
||||
|
||||
Args:
|
||||
max_capacity: Maximum number of nodes in the LRU list.
|
||||
"""
|
||||
self.root = _RadixNode()
|
||||
self.max_capacity = max_capacity
|
||||
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
|
||||
|
||||
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
|
||||
"""Inserts a token sequence into the prefix cache.
|
||||
|
||||
Every node along the path records the slot and its version,
|
||||
enabling direct slot reuse for partial prefix matches.
|
||||
|
||||
Args:
|
||||
token_ids: The token ID sequence to cache.
|
||||
slot: The KV cache slot containing this prefix's computed keys/values.
|
||||
slot_ver: The slot version at insertion time, used for staleness detection.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
nxt = _RadixNode()
|
||||
node.children[tid] = nxt
|
||||
node = nxt
|
||||
node.slot = slot
|
||||
node.slot_ver = slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru[id(node)] = node
|
||||
node.ref_count += 1
|
||||
self._evict_if_needed()
|
||||
|
||||
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
|
||||
"""Finds the longest matching prefix in the cache.
|
||||
|
||||
Walks the radix tree token by token, recording the deepest match.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence to match against.
|
||||
|
||||
Returns:
|
||||
Tuple of (prefix_len, slot, slot_ver):
|
||||
prefix_len: Number of matching tokens (0 if no match).
|
||||
slot: KV cache slot of the matched prefix (-1 if no match).
|
||||
slot_ver: Version of that slot when the prefix was inserted.
|
||||
"""
|
||||
node = self.root
|
||||
best_len, best_slot, best_ver = 0, -1, 0
|
||||
for i, tid in enumerate(token_ids):
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
break
|
||||
node = nxt
|
||||
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru.move_to_end(id(node))
|
||||
return best_len, best_slot, best_ver
|
||||
|
||||
def pin(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Increments the reference count of a cached prefix.
|
||||
|
||||
Called when a task reuses a cached prefix to prevent eviction.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to increment.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
node.ref_count += 1
|
||||
|
||||
def release(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Decrements the reference count of a cached prefix.
|
||||
|
||||
The node's slot is preserved even when ref_count reaches zero,
|
||||
allowing future tasks to reuse the slot directly if it remains free.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to decrement.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
if node.ref_count > 0:
|
||||
node.ref_count -= 1
|
||||
|
||||
def copy_kv(
|
||||
self,
|
||||
token_ids: Tuple[int, ...],
|
||||
target_slot: int,
|
||||
kv_cache: Tuple[Tensor, Tensor],
|
||||
n_layers: int,
|
||||
) -> None:
|
||||
"""Copies cached KV data from the source slot to a target slot.
|
||||
|
||||
Args:
|
||||
token_ids: The prefix token sequence identifying the source cache node.
|
||||
target_slot: The destination KV cache slot to copy into.
|
||||
kv_cache: Tuple of (k_cache, v_cache) tensors.
|
||||
n_layers: Number of transformer layers to copy.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
src_slot = node.slot
|
||||
if src_slot < 0:
|
||||
return
|
||||
prefix_len = len(token_ids)
|
||||
k_cache, v_cache = kv_cache
|
||||
for li in range(n_layers):
|
||||
k_cache[target_slot, :prefix_len, li].copy_(
|
||||
k_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
v_cache[target_slot, :prefix_len, li].copy_(
|
||||
v_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evicts least-recently-used nodes until under capacity.
|
||||
|
||||
Skips nodes with ref_count > 0 (still in use by active tasks).
|
||||
Evicted nodes have their slot and children cleared.
|
||||
"""
|
||||
while len(self._lru) > self.max_capacity:
|
||||
key, node = next(iter(self._lru.items()))
|
||||
if node.ref_count > 0:
|
||||
self._lru.move_to_end(key)
|
||||
continue
|
||||
self._lru.pop(key)
|
||||
node.slot = -1
|
||||
node.slot_ver = 0
|
||||
node.children.clear()
|
||||
|
|
@ -1,22 +1,41 @@
|
|||
"""Unified inference engine for continuous batching."""
|
||||
"""Unified inference engine for continuous batching.
|
||||
|
||||
Layers:
|
||||
- GenerationParams: Immutable value object for sampling parameters.
|
||||
- GenerationRequest: User-facing request DTO with validation.
|
||||
- _Result: Thread-safe token accumulator (Observer pattern).
|
||||
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
||||
from astrai.inference.cache import _STOP
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationParams:
|
||||
"""Immutable value object for sampling hyperparameters."""
|
||||
|
||||
top_k: int = 50
|
||||
top_p: float = 1.0
|
||||
temperature: float = 1.0
|
||||
max_tokens: int = 1024
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation.
|
||||
|
||||
Encapsulates messages, sampling parameters, and streaming preference
|
||||
for a single generation request.
|
||||
Encapsulates messages, sampling parameters (via GenerationParams),
|
||||
and streaming preference for a single generation request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -39,13 +58,31 @@ class GenerationRequest:
|
|||
stream: Whether to return output as a token stream.
|
||||
"""
|
||||
self.messages = messages
|
||||
self.top_k = top_k
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.max_len = max_len
|
||||
self.params = GenerationParams(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
max_tokens=max_len,
|
||||
)
|
||||
self.stream = stream
|
||||
self._validate()
|
||||
|
||||
@property
|
||||
def top_k(self) -> int:
|
||||
return self.params.top_k
|
||||
|
||||
@property
|
||||
def top_p(self) -> float:
|
||||
return self.params.top_p
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.params.temperature
|
||||
|
||||
@property
|
||||
def max_len(self) -> int:
|
||||
return self.params.max_tokens
|
||||
|
||||
def _validate(self):
|
||||
"""Validates sampling parameter ranges."""
|
||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||
|
|
@ -296,10 +333,10 @@ class InferenceEngine:
|
|||
return self.generate(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
max_tokens=request.max_len,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
max_tokens=request.params.max_tokens,
|
||||
temperature=request.params.temperature,
|
||||
top_p=request.params.top_p,
|
||||
top_k=request.params.top_k,
|
||||
)
|
||||
|
||||
def _generate_streaming(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,129 @@
|
|||
"""Composable sampling strategies for logit transformation.
|
||||
|
||||
Implements the Strategy pattern: each sampling technique
|
||||
(temperature, top-k, top-p) is a pluggable strategy that
|
||||
can be composed into a pipeline.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class BaseSamplingStrategy(ABC):
|
||||
"""Abstract base for a logit transformation strategy."""
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||
"""Applies the strategy to logits.
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor (batch, vocab_size).
|
||||
filter_value: Value assigned to filtered-out positions.
|
||||
|
||||
Returns:
|
||||
Transformed logits tensor (may be the same or a new tensor).
|
||||
"""
|
||||
|
||||
|
||||
class TemperatureStrategy(BaseSamplingStrategy):
|
||||
"""Divides logits by temperature to control randomness."""
|
||||
|
||||
def __init__(self, temperature: float = 1.0):
|
||||
self.temperature = temperature
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.temperature != 1.0:
|
||||
logits = logits / self.temperature
|
||||
return logits
|
||||
|
||||
|
||||
class TopKStrategy(BaseSamplingStrategy):
|
||||
"""Keeps only the top-k logits, setting the rest to filter_value."""
|
||||
|
||||
def __init__(self, top_k: int = 0):
|
||||
self.top_k = top_k
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.top_k > 0:
|
||||
k = min(self.top_k, logits.size(-1))
|
||||
topk_vals = torch.topk(logits, k, dim=-1)[0]
|
||||
threshold = topk_vals[..., -1, None]
|
||||
indices = logits < threshold
|
||||
logits[indices] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
class TopPStrategy(BaseSamplingStrategy):
|
||||
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||
cumulative probability exceeds top_p."""
|
||||
|
||||
def __init__(self, top_p: float = 1.0):
|
||||
self.top_p = top_p
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
if self.top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > self.top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
||||
..., :-1
|
||||
].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
class SamplingPipeline(BaseSamplingStrategy):
|
||||
"""Composes multiple sampling strategies into a single transformation.
|
||||
|
||||
Strategies are applied sequentially in the order they are provided,
|
||||
matching the original temperature → top-k → top-p ordering.
|
||||
"""
|
||||
|
||||
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||
self.strategies = strategies
|
||||
|
||||
def apply(self, logits, filter_value=-float("inf")):
|
||||
logits = logits.clone()
|
||||
for strategy in self.strategies:
|
||||
logits = strategy.apply(logits, filter_value)
|
||||
return logits
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
logits: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||
|
||||
Backward-compatible function that delegates to the Strategy pattern
|
||||
pipeline with TemperatureStrategy → TopKStrategy → TopPStrategy ordering.
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||
top_k: Keep only top-k logits (0 disables).
|
||||
top_p: Nucleus probability threshold (1.0 disables).
|
||||
filter_value: Value to assign to filtered-out positions.
|
||||
|
||||
Returns:
|
||||
Modified logits tensor with same shape as input.
|
||||
"""
|
||||
pipeline = SamplingPipeline(
|
||||
[
|
||||
TemperatureStrategy(temperature),
|
||||
TopKStrategy(top_k),
|
||||
TopPStrategy(top_p),
|
||||
]
|
||||
)
|
||||
return pipeline.apply(logits, filter_value)
|
||||
|
|
@ -1,200 +1,30 @@
|
|||
"""Inference scheduler for single-GPU continuous batching."""
|
||||
"""Inference scheduler for single-GPU continuous batching.
|
||||
|
||||
Splits scheduling concerns across modules:
|
||||
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
|
||||
- sampling.py: Strategy-pattern logit transformations
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
|
||||
from astrai.inference.sampling import apply_sampling_strategies
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STOP = object()
|
||||
|
||||
|
||||
class _RadixNode:
|
||||
"""Internal node for the radix tree prefix cache.
|
||||
|
||||
Attributes:
|
||||
children: Mapping from token ID to child node.
|
||||
slot: KV cache slot index for the prefix ending at this node.
|
||||
slot_ver: Version counter of the slot at insertion time.
|
||||
ref_count: Number of tasks currently referencing this node.
|
||||
last_access: Timestamp of the most recent access (for LRU ordering).
|
||||
"""
|
||||
|
||||
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
|
||||
|
||||
def __init__(self):
|
||||
self.children: Dict[int, "_RadixNode"] = {}
|
||||
self.slot: int = -1
|
||||
self.slot_ver: int = 0
|
||||
self.ref_count: int = 0
|
||||
self.last_access: float = 0.0
|
||||
|
||||
|
||||
class PrefixCacheManager:
|
||||
"""Radix-tree prefix cache with LRU eviction.
|
||||
|
||||
Maps token ID sequences to KV cache slots. Intermediate tree nodes
|
||||
also store slot information, allowing direct slot reuse when the
|
||||
cached slot is free and its version matches (no intervening writes).
|
||||
"""
|
||||
|
||||
def __init__(self, max_capacity: int = 1000):
|
||||
"""Initializes the prefix cache.
|
||||
|
||||
Args:
|
||||
max_capacity: Maximum number of nodes in the LRU list.
|
||||
"""
|
||||
self.root = _RadixNode()
|
||||
self.max_capacity = max_capacity
|
||||
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
|
||||
|
||||
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
|
||||
"""Inserts a token sequence into the prefix cache.
|
||||
|
||||
Every node along the path records the slot and its version,
|
||||
enabling direct slot reuse for partial prefix matches.
|
||||
|
||||
Args:
|
||||
token_ids: The token ID sequence to cache.
|
||||
slot: The KV cache slot containing this prefix's computed keys/values.
|
||||
slot_ver: The slot version at insertion time, used for staleness detection.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
nxt = _RadixNode()
|
||||
node.children[tid] = nxt
|
||||
node = nxt
|
||||
node.slot = slot
|
||||
node.slot_ver = slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru[id(node)] = node
|
||||
node.ref_count += 1
|
||||
self._evict_if_needed()
|
||||
|
||||
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
|
||||
"""Finds the longest matching prefix in the cache.
|
||||
|
||||
Walks the radix tree token by token, recording the deepest match.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence to match against.
|
||||
|
||||
Returns:
|
||||
Tuple of (prefix_len, slot, slot_ver):
|
||||
prefix_len: Number of matching tokens (0 if no match).
|
||||
slot: KV cache slot of the matched prefix (-1 if no match).
|
||||
slot_ver: Version of that slot when the prefix was inserted.
|
||||
"""
|
||||
node = self.root
|
||||
best_len, best_slot, best_ver = 0, -1, 0
|
||||
for i, tid in enumerate(token_ids):
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
break
|
||||
node = nxt
|
||||
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
|
||||
node.last_access = time.time()
|
||||
self._lru.move_to_end(id(node))
|
||||
return best_len, best_slot, best_ver
|
||||
|
||||
def pin(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Increments the reference count of a cached prefix.
|
||||
|
||||
Called when a task reuses a cached prefix to prevent eviction.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to increment.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
node.ref_count += 1
|
||||
|
||||
def release(self, token_ids: Tuple[int, ...]) -> None:
|
||||
"""Decrements the reference count of a cached prefix.
|
||||
|
||||
The node's slot is preserved even when ref_count reaches zero,
|
||||
allowing future tasks to reuse the slot directly if it remains free.
|
||||
|
||||
Args:
|
||||
token_ids: The token sequence whose node's ref_count to decrement.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
if node.ref_count > 0:
|
||||
node.ref_count -= 1
|
||||
|
||||
def copy_kv(
|
||||
self,
|
||||
token_ids: Tuple[int, ...],
|
||||
target_slot: int,
|
||||
kv_cache: Tuple[Tensor, Tensor],
|
||||
n_layers: int,
|
||||
) -> None:
|
||||
"""Copies cached KV data from the source slot to a target slot.
|
||||
|
||||
Args:
|
||||
token_ids: The prefix token sequence identifying the source cache node.
|
||||
target_slot: The destination KV cache slot to copy into.
|
||||
kv_cache: Tuple of (k_cache, v_cache) tensors.
|
||||
n_layers: Number of transformer layers to copy.
|
||||
"""
|
||||
node = self.root
|
||||
for tid in token_ids:
|
||||
nxt = node.children.get(tid)
|
||||
if nxt is None:
|
||||
return
|
||||
node = nxt
|
||||
src_slot = node.slot
|
||||
if src_slot < 0:
|
||||
return
|
||||
prefix_len = len(token_ids)
|
||||
k_cache, v_cache = kv_cache
|
||||
for li in range(n_layers):
|
||||
k_cache[target_slot, :prefix_len, li].copy_(
|
||||
k_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
v_cache[target_slot, :prefix_len, li].copy_(
|
||||
v_cache[src_slot, :prefix_len, li]
|
||||
)
|
||||
|
||||
def _evict_if_needed(self) -> None:
|
||||
"""Evicts least-recently-used nodes until under capacity.
|
||||
|
||||
Skips nodes with ref_count > 0 (still in use by active tasks).
|
||||
Evicted nodes have their slot and children cleared.
|
||||
"""
|
||||
while len(self._lru) > self.max_capacity:
|
||||
key, node = next(iter(self._lru.items()))
|
||||
if node.ref_count > 0:
|
||||
self._lru.move_to_end(key)
|
||||
continue
|
||||
self._lru.pop(key)
|
||||
node.slot = -1
|
||||
node.slot_ver = 0
|
||||
node.children.clear()
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
"""Enum-like task states in the continuous batching lifecycle."""
|
||||
class TaskStatus(Enum):
|
||||
"""Task states in the continuous batching lifecycle."""
|
||||
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
|
|
@ -286,46 +116,6 @@ class Task:
|
|||
return False
|
||||
|
||||
|
||||
def apply_sampling_strategies(
|
||||
logits: Tensor,
|
||||
temperature: float,
|
||||
top_k: int,
|
||||
top_p: float,
|
||||
filter_value: float = -float("inf"),
|
||||
) -> Tensor:
|
||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||
|
||||
Args:
|
||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||
top_k: Keep only top-k logits (0 disables).
|
||||
top_p: Nucleus probability threshold (1.0 disables).
|
||||
filter_value: Value to assign to filtered-out positions.
|
||||
|
||||
Returns:
|
||||
Modified logits tensor with same shape as input.
|
||||
"""
|
||||
logits = logits.clone()
|
||||
if temperature != 1.0:
|
||||
logits = logits / temperature
|
||||
if top_k > 0:
|
||||
top_k = min(top_k, logits.size(-1))
|
||||
indices = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
||||
logits[indices] = filter_value
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||
indices_to_remove.scatter_(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
|
||||
class InferenceScheduler:
|
||||
"""Continuous batching scheduler for single-GPU inference.
|
||||
|
||||
|
|
@ -397,8 +187,7 @@ class InferenceScheduler:
|
|||
dtype=torch.bool,
|
||||
)
|
||||
|
||||
self._free_slots = (1 << max_batch_size) - 1
|
||||
self._slot_ver: List[int] = [0] * max_batch_size
|
||||
self.slot_allocator = SlotAllocator(max_batch_size)
|
||||
self.waiting_queue: List[Task] = []
|
||||
self.active_tasks: List[Task] = []
|
||||
|
||||
|
|
@ -410,18 +199,12 @@ class InferenceScheduler:
|
|||
self._total_tokens = 0
|
||||
|
||||
def _alloc_slot(self) -> int:
|
||||
"""Allocates a free KV cache slot using a bitmask.
|
||||
"""Allocates a free KV cache slot using the Object Pool.
|
||||
|
||||
Returns:
|
||||
Slot index on success, -1 if all slots are occupied.
|
||||
"""
|
||||
lsb = self._free_slots & -self._free_slots
|
||||
if lsb == 0:
|
||||
return -1
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_slots ^= lsb
|
||||
self._slot_ver[idx] += 1
|
||||
return idx
|
||||
return self.slot_allocator.alloc()
|
||||
|
||||
def _free_slot(self, idx: int) -> None:
|
||||
"""Releases a KV cache slot back to the free pool.
|
||||
|
|
@ -429,7 +212,7 @@ class InferenceScheduler:
|
|||
Args:
|
||||
idx: Slot index to free.
|
||||
"""
|
||||
self._free_slots |= 1 << idx
|
||||
self.slot_allocator.free(idx)
|
||||
self.seq_mask[idx, :] = False
|
||||
|
||||
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
|
||||
|
|
@ -445,9 +228,9 @@ class InferenceScheduler:
|
|||
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
|
||||
"""
|
||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||
if cached_slot >= 0 and (self._free_slots >> cached_slot) & 1:
|
||||
if cached_ver == self._slot_ver[cached_slot]:
|
||||
self._free_slots ^= 1 << cached_slot
|
||||
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
|
||||
if cached_ver == self.slot_allocator.version(cached_slot):
|
||||
self.slot_allocator.occupy(cached_slot)
|
||||
return cached_slot, True
|
||||
return -1, False
|
||||
|
||||
|
|
@ -588,7 +371,9 @@ class InferenceScheduler:
|
|||
if task.prefix_len > 0 and not reused:
|
||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
|
||||
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
|
||||
cached_slot
|
||||
):
|
||||
self.prefix_cache.pin(prefix)
|
||||
self.prefix_cache.copy_kv(
|
||||
prefix, slot, self.kv_cache, self._n_layers
|
||||
|
|
@ -663,7 +448,7 @@ class InferenceScheduler:
|
|||
t.input_tokens = len(t.prompt_ids)
|
||||
t.output_tokens = 0
|
||||
self.prefix_cache.insert(
|
||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
||||
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
|
||||
)
|
||||
if t.slot >= 0:
|
||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ def chat():
|
|||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||
model.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
messages = []
|
||||
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||
|
||||
while True:
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@ from unittest.mock import MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from astrai.inference.cache import PrefixCacheManager
|
||||
from astrai.inference.scheduler import (
|
||||
InferenceScheduler,
|
||||
PrefixCacheManager,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue