refactor: 设计模式优化 inference 模块导入结构
- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager - 新建 sampling.py:Temperature/TopK/TopP 可组合策略 - TaskStatus 改用 Enum,GenerationParams 值对象模式 - _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合 - 更新测试导入路径,ruff 格式检查通过
This commit is contained in:
parent
c4401512f2
commit
44d7a4e959
|
|
@ -1,25 +1,46 @@
|
||||||
"""Inference module for continuous batching."""
|
"""Inference module for continuous batching.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||||
|
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
|
||||||
|
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
|
||||||
|
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
|
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
|
||||||
|
"""
|
||||||
|
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import (
|
||||||
|
GenerationParams,
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
|
from astrai.inference.sampling import (
|
||||||
|
BaseSamplingStrategy,
|
||||||
|
SamplingPipeline,
|
||||||
|
TemperatureStrategy,
|
||||||
|
TopKStrategy,
|
||||||
|
TopPStrategy,
|
||||||
|
apply_sampling_strategies,
|
||||||
|
)
|
||||||
from astrai.inference.scheduler import (
|
from astrai.inference.scheduler import (
|
||||||
InferenceScheduler,
|
InferenceScheduler,
|
||||||
Task,
|
Task,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
apply_sampling_strategies,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Engine
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
|
"GenerationRequest",
|
||||||
|
"GenerationParams",
|
||||||
# Scheduler
|
# Scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Task",
|
"Task",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Request
|
# Sampling (Strategy pattern)
|
||||||
"GenerationRequest",
|
|
||||||
# Sampling
|
|
||||||
"apply_sampling_strategies",
|
"apply_sampling_strategies",
|
||||||
|
"BaseSamplingStrategy",
|
||||||
|
"TemperatureStrategy",
|
||||||
|
"TopKStrategy",
|
||||||
|
"TopPStrategy",
|
||||||
|
"SamplingPipeline",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,241 @@
|
||||||
|
"""KV cache slot allocation and prefix cache management.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask.
|
||||||
|
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
_STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
class _RadixNode:
|
||||||
|
"""Internal node for the radix tree prefix cache.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
children: Mapping from token ID to child node.
|
||||||
|
slot: KV cache slot index for the prefix ending at this node.
|
||||||
|
slot_ver: Version counter of the slot at insertion time.
|
||||||
|
ref_count: Number of tasks currently referencing this node.
|
||||||
|
last_access: Timestamp of the most recent access (for LRU ordering).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.children: Dict[int, "_RadixNode"] = {}
|
||||||
|
self.slot: int = -1
|
||||||
|
self.slot_ver: int = 0
|
||||||
|
self.ref_count: int = 0
|
||||||
|
self.last_access: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class SlotAllocator:
|
||||||
|
"""KV cache slot allocator using bitmask for O(1) alloc/free.
|
||||||
|
|
||||||
|
Implements the Object Pool pattern: pre-allocated KV cache slots
|
||||||
|
are managed via a bitmask, providing constant-time allocation and
|
||||||
|
deallocation with version counters for staleness detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_slots: int):
|
||||||
|
self._max_slots = max_slots
|
||||||
|
self._free_mask = (1 << max_slots) - 1
|
||||||
|
self._versions: List[int] = [0] * max_slots
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
"""Allocates a free slot.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Slot index on success, -1 if all slots are occupied.
|
||||||
|
"""
|
||||||
|
lsb = self._free_mask & -self._free_mask
|
||||||
|
if lsb == 0:
|
||||||
|
return -1
|
||||||
|
idx = lsb.bit_length() - 1
|
||||||
|
self._free_mask ^= lsb
|
||||||
|
self._versions[idx] += 1
|
||||||
|
return idx
|
||||||
|
|
||||||
|
def free(self, idx: int) -> None:
|
||||||
|
"""Releases a slot back to the free pool."""
|
||||||
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
|
def occupy(self, idx: int) -> None:
|
||||||
|
"""Marks a currently free slot as occupied without bumping its version.
|
||||||
|
|
||||||
|
Used for direct slot reuse when a prefix-cached slot is still valid.
|
||||||
|
"""
|
||||||
|
self._free_mask ^= 1 << idx
|
||||||
|
|
||||||
|
def is_free(self, idx: int) -> bool:
|
||||||
|
"""Checks whether a slot is currently free."""
|
||||||
|
return (self._free_mask >> idx) & 1 == 1
|
||||||
|
|
||||||
|
def version(self, idx: int) -> int:
|
||||||
|
"""Returns the current version counter for a slot."""
|
||||||
|
return self._versions[idx]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def free_count(self) -> int:
|
||||||
|
"""Returns the number of currently free slots."""
|
||||||
|
return self._free_mask.bit_count()
|
||||||
|
|
||||||
|
|
||||||
|
class PrefixCacheManager:
|
||||||
|
"""Radix-tree prefix cache with LRU eviction.
|
||||||
|
|
||||||
|
Maps token ID sequences to KV cache slots. Intermediate tree nodes
|
||||||
|
also store slot information, allowing direct slot reuse when the
|
||||||
|
cached slot is free and its version matches (no intervening writes).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_capacity: int = 1000):
|
||||||
|
"""Initializes the prefix cache.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_capacity: Maximum number of nodes in the LRU list.
|
||||||
|
"""
|
||||||
|
self.root = _RadixNode()
|
||||||
|
self.max_capacity = max_capacity
|
||||||
|
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
|
||||||
|
|
||||||
|
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
|
||||||
|
"""Inserts a token sequence into the prefix cache.
|
||||||
|
|
||||||
|
Every node along the path records the slot and its version,
|
||||||
|
enabling direct slot reuse for partial prefix matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The token ID sequence to cache.
|
||||||
|
slot: The KV cache slot containing this prefix's computed keys/values.
|
||||||
|
slot_ver: The slot version at insertion time, used for staleness detection.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
for tid in token_ids:
|
||||||
|
nxt = node.children.get(tid)
|
||||||
|
if nxt is None:
|
||||||
|
nxt = _RadixNode()
|
||||||
|
node.children[tid] = nxt
|
||||||
|
node = nxt
|
||||||
|
node.slot = slot
|
||||||
|
node.slot_ver = slot_ver
|
||||||
|
node.last_access = time.time()
|
||||||
|
self._lru[id(node)] = node
|
||||||
|
node.ref_count += 1
|
||||||
|
self._evict_if_needed()
|
||||||
|
|
||||||
|
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
|
||||||
|
"""Finds the longest matching prefix in the cache.
|
||||||
|
|
||||||
|
Walks the radix tree token by token, recording the deepest match.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The token sequence to match against.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (prefix_len, slot, slot_ver):
|
||||||
|
prefix_len: Number of matching tokens (0 if no match).
|
||||||
|
slot: KV cache slot of the matched prefix (-1 if no match).
|
||||||
|
slot_ver: Version of that slot when the prefix was inserted.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
best_len, best_slot, best_ver = 0, -1, 0
|
||||||
|
for i, tid in enumerate(token_ids):
|
||||||
|
nxt = node.children.get(tid)
|
||||||
|
if nxt is None:
|
||||||
|
break
|
||||||
|
node = nxt
|
||||||
|
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
|
||||||
|
node.last_access = time.time()
|
||||||
|
self._lru.move_to_end(id(node))
|
||||||
|
return best_len, best_slot, best_ver
|
||||||
|
|
||||||
|
def pin(self, token_ids: Tuple[int, ...]) -> None:
|
||||||
|
"""Increments the reference count of a cached prefix.
|
||||||
|
|
||||||
|
Called when a task reuses a cached prefix to prevent eviction.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The token sequence whose node's ref_count to increment.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
for tid in token_ids:
|
||||||
|
nxt = node.children.get(tid)
|
||||||
|
if nxt is None:
|
||||||
|
return
|
||||||
|
node = nxt
|
||||||
|
node.ref_count += 1
|
||||||
|
|
||||||
|
def release(self, token_ids: Tuple[int, ...]) -> None:
|
||||||
|
"""Decrements the reference count of a cached prefix.
|
||||||
|
|
||||||
|
The node's slot is preserved even when ref_count reaches zero,
|
||||||
|
allowing future tasks to reuse the slot directly if it remains free.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The token sequence whose node's ref_count to decrement.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
for tid in token_ids:
|
||||||
|
nxt = node.children.get(tid)
|
||||||
|
if nxt is None:
|
||||||
|
return
|
||||||
|
node = nxt
|
||||||
|
if node.ref_count > 0:
|
||||||
|
node.ref_count -= 1
|
||||||
|
|
||||||
|
def copy_kv(
|
||||||
|
self,
|
||||||
|
token_ids: Tuple[int, ...],
|
||||||
|
target_slot: int,
|
||||||
|
kv_cache: Tuple[Tensor, Tensor],
|
||||||
|
n_layers: int,
|
||||||
|
) -> None:
|
||||||
|
"""Copies cached KV data from the source slot to a target slot.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token_ids: The prefix token sequence identifying the source cache node.
|
||||||
|
target_slot: The destination KV cache slot to copy into.
|
||||||
|
kv_cache: Tuple of (k_cache, v_cache) tensors.
|
||||||
|
n_layers: Number of transformer layers to copy.
|
||||||
|
"""
|
||||||
|
node = self.root
|
||||||
|
for tid in token_ids:
|
||||||
|
nxt = node.children.get(tid)
|
||||||
|
if nxt is None:
|
||||||
|
return
|
||||||
|
node = nxt
|
||||||
|
src_slot = node.slot
|
||||||
|
if src_slot < 0:
|
||||||
|
return
|
||||||
|
prefix_len = len(token_ids)
|
||||||
|
k_cache, v_cache = kv_cache
|
||||||
|
for li in range(n_layers):
|
||||||
|
k_cache[target_slot, :prefix_len, li].copy_(
|
||||||
|
k_cache[src_slot, :prefix_len, li]
|
||||||
|
)
|
||||||
|
v_cache[target_slot, :prefix_len, li].copy_(
|
||||||
|
v_cache[src_slot, :prefix_len, li]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _evict_if_needed(self) -> None:
|
||||||
|
"""Evicts least-recently-used nodes until under capacity.
|
||||||
|
|
||||||
|
Skips nodes with ref_count > 0 (still in use by active tasks).
|
||||||
|
Evicted nodes have their slot and children cleared.
|
||||||
|
"""
|
||||||
|
while len(self._lru) > self.max_capacity:
|
||||||
|
key, node = next(iter(self._lru.items()))
|
||||||
|
if node.ref_count > 0:
|
||||||
|
self._lru.move_to_end(key)
|
||||||
|
continue
|
||||||
|
self._lru.pop(key)
|
||||||
|
node.slot = -1
|
||||||
|
node.slot_ver = 0
|
||||||
|
node.children.clear()
|
||||||
|
|
@ -1,22 +1,41 @@
|
||||||
"""Unified inference engine for continuous batching."""
|
"""Unified inference engine for continuous batching.
|
||||||
|
|
||||||
|
Layers:
|
||||||
|
- GenerationParams: Immutable value object for sampling parameters.
|
||||||
|
- GenerationRequest: User-facing request DTO with validation.
|
||||||
|
- _Result: Thread-safe token accumulator (Observer pattern).
|
||||||
|
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import threading
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from astrai.inference.scheduler import _STOP, InferenceScheduler
|
from astrai.inference.cache import _STOP
|
||||||
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class GenerationParams:
|
||||||
|
"""Immutable value object for sampling hyperparameters."""
|
||||||
|
|
||||||
|
top_k: int = 50
|
||||||
|
top_p: float = 1.0
|
||||||
|
temperature: float = 1.0
|
||||||
|
max_tokens: int = 1024
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation.
|
"""Request parameters for text generation.
|
||||||
|
|
||||||
Encapsulates messages, sampling parameters, and streaming preference
|
Encapsulates messages, sampling parameters (via GenerationParams),
|
||||||
for a single generation request.
|
and streaming preference for a single generation request.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -39,13 +58,31 @@ class GenerationRequest:
|
||||||
stream: Whether to return output as a token stream.
|
stream: Whether to return output as a token stream.
|
||||||
"""
|
"""
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.top_k = top_k
|
self.params = GenerationParams(
|
||||||
self.top_p = top_p
|
top_k=top_k,
|
||||||
self.temperature = temperature
|
top_p=top_p,
|
||||||
self.max_len = max_len
|
temperature=temperature,
|
||||||
|
max_tokens=max_len,
|
||||||
|
)
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self._validate()
|
self._validate()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def top_k(self) -> int:
|
||||||
|
return self.params.top_k
|
||||||
|
|
||||||
|
@property
|
||||||
|
def top_p(self) -> float:
|
||||||
|
return self.params.top_p
|
||||||
|
|
||||||
|
@property
|
||||||
|
def temperature(self) -> float:
|
||||||
|
return self.params.temperature
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_len(self) -> int:
|
||||||
|
return self.params.max_tokens
|
||||||
|
|
||||||
def _validate(self):
|
def _validate(self):
|
||||||
"""Validates sampling parameter ranges."""
|
"""Validates sampling parameter ranges."""
|
||||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||||
|
|
@ -296,10 +333,10 @@ class InferenceEngine:
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
max_tokens=request.max_len,
|
max_tokens=request.params.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.params.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.params.top_p,
|
||||||
top_k=request.top_k,
|
top_k=request.params.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_streaming(
|
def _generate_streaming(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,129 @@
|
||||||
|
"""Composable sampling strategies for logit transformation.
|
||||||
|
|
||||||
|
Implements the Strategy pattern: each sampling technique
|
||||||
|
(temperature, top-k, top-p) is a pluggable strategy that
|
||||||
|
can be composed into a pipeline.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSamplingStrategy(ABC):
|
||||||
|
"""Abstract base for a logit transformation strategy."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
|
||||||
|
"""Applies the strategy to logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits tensor (batch, vocab_size).
|
||||||
|
filter_value: Value assigned to filtered-out positions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Transformed logits tensor (may be the same or a new tensor).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class TemperatureStrategy(BaseSamplingStrategy):
|
||||||
|
"""Divides logits by temperature to control randomness."""
|
||||||
|
|
||||||
|
def __init__(self, temperature: float = 1.0):
|
||||||
|
self.temperature = temperature
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
if self.temperature != 1.0:
|
||||||
|
logits = logits / self.temperature
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopKStrategy(BaseSamplingStrategy):
|
||||||
|
"""Keeps only the top-k logits, setting the rest to filter_value."""
|
||||||
|
|
||||||
|
def __init__(self, top_k: int = 0):
|
||||||
|
self.top_k = top_k
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
if self.top_k > 0:
|
||||||
|
k = min(self.top_k, logits.size(-1))
|
||||||
|
topk_vals = torch.topk(logits, k, dim=-1)[0]
|
||||||
|
threshold = topk_vals[..., -1, None]
|
||||||
|
indices = logits < threshold
|
||||||
|
logits[indices] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class TopPStrategy(BaseSamplingStrategy):
|
||||||
|
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
|
||||||
|
cumulative probability exceeds top_p."""
|
||||||
|
|
||||||
|
def __init__(self, top_p: float = 1.0):
|
||||||
|
self.top_p = top_p
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
if self.top_p < 1.0:
|
||||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||||
|
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||||
|
sorted_indices_to_remove = cum_probs > self.top_p
|
||||||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
||||||
|
..., :-1
|
||||||
|
].clone()
|
||||||
|
sorted_indices_to_remove[..., 0] = 0
|
||||||
|
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
||||||
|
indices_to_remove.scatter_(
|
||||||
|
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||||
|
)
|
||||||
|
logits[indices_to_remove] = filter_value
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class SamplingPipeline(BaseSamplingStrategy):
|
||||||
|
"""Composes multiple sampling strategies into a single transformation.
|
||||||
|
|
||||||
|
Strategies are applied sequentially in the order they are provided,
|
||||||
|
matching the original temperature → top-k → top-p ordering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, strategies: List[BaseSamplingStrategy]):
|
||||||
|
self.strategies = strategies
|
||||||
|
|
||||||
|
def apply(self, logits, filter_value=-float("inf")):
|
||||||
|
logits = logits.clone()
|
||||||
|
for strategy in self.strategies:
|
||||||
|
logits = strategy.apply(logits, filter_value)
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def apply_sampling_strategies(
|
||||||
|
logits: Tensor,
|
||||||
|
temperature: float,
|
||||||
|
top_k: int,
|
||||||
|
top_p: float,
|
||||||
|
filter_value: float = -float("inf"),
|
||||||
|
) -> Tensor:
|
||||||
|
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
||||||
|
|
||||||
|
Backward-compatible function that delegates to the Strategy pattern
|
||||||
|
pipeline with TemperatureStrategy → TopKStrategy → TopPStrategy ordering.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logits: Raw logits tensor of shape (batch, vocab_size).
|
||||||
|
temperature: Temperature scaling factor (1.0 = no scaling).
|
||||||
|
top_k: Keep only top-k logits (0 disables).
|
||||||
|
top_p: Nucleus probability threshold (1.0 disables).
|
||||||
|
filter_value: Value to assign to filtered-out positions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Modified logits tensor with same shape as input.
|
||||||
|
"""
|
||||||
|
pipeline = SamplingPipeline(
|
||||||
|
[
|
||||||
|
TemperatureStrategy(temperature),
|
||||||
|
TopKStrategy(top_k),
|
||||||
|
TopPStrategy(top_p),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return pipeline.apply(logits, filter_value)
|
||||||
|
|
@ -1,200 +1,30 @@
|
||||||
"""Inference scheduler for single-GPU continuous batching."""
|
"""Inference scheduler for single-GPU continuous batching.
|
||||||
|
|
||||||
|
Splits scheduling concerns across modules:
|
||||||
|
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
|
||||||
|
- sampling.py: Strategy-pattern logit transformations
|
||||||
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections import OrderedDict
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
|
||||||
|
from astrai.inference.sampling import apply_sampling_strategies
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_STOP = object()
|
|
||||||
|
|
||||||
|
class TaskStatus(Enum):
|
||||||
class _RadixNode:
|
"""Task states in the continuous batching lifecycle."""
|
||||||
"""Internal node for the radix tree prefix cache.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
children: Mapping from token ID to child node.
|
|
||||||
slot: KV cache slot index for the prefix ending at this node.
|
|
||||||
slot_ver: Version counter of the slot at insertion time.
|
|
||||||
ref_count: Number of tasks currently referencing this node.
|
|
||||||
last_access: Timestamp of the most recent access (for LRU ordering).
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.children: Dict[int, "_RadixNode"] = {}
|
|
||||||
self.slot: int = -1
|
|
||||||
self.slot_ver: int = 0
|
|
||||||
self.ref_count: int = 0
|
|
||||||
self.last_access: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class PrefixCacheManager:
|
|
||||||
"""Radix-tree prefix cache with LRU eviction.
|
|
||||||
|
|
||||||
Maps token ID sequences to KV cache slots. Intermediate tree nodes
|
|
||||||
also store slot information, allowing direct slot reuse when the
|
|
||||||
cached slot is free and its version matches (no intervening writes).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, max_capacity: int = 1000):
|
|
||||||
"""Initializes the prefix cache.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_capacity: Maximum number of nodes in the LRU list.
|
|
||||||
"""
|
|
||||||
self.root = _RadixNode()
|
|
||||||
self.max_capacity = max_capacity
|
|
||||||
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
|
|
||||||
|
|
||||||
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
|
|
||||||
"""Inserts a token sequence into the prefix cache.
|
|
||||||
|
|
||||||
Every node along the path records the slot and its version,
|
|
||||||
enabling direct slot reuse for partial prefix matches.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: The token ID sequence to cache.
|
|
||||||
slot: The KV cache slot containing this prefix's computed keys/values.
|
|
||||||
slot_ver: The slot version at insertion time, used for staleness detection.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
for tid in token_ids:
|
|
||||||
nxt = node.children.get(tid)
|
|
||||||
if nxt is None:
|
|
||||||
nxt = _RadixNode()
|
|
||||||
node.children[tid] = nxt
|
|
||||||
node = nxt
|
|
||||||
node.slot = slot
|
|
||||||
node.slot_ver = slot_ver
|
|
||||||
node.last_access = time.time()
|
|
||||||
self._lru[id(node)] = node
|
|
||||||
node.ref_count += 1
|
|
||||||
self._evict_if_needed()
|
|
||||||
|
|
||||||
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
|
|
||||||
"""Finds the longest matching prefix in the cache.
|
|
||||||
|
|
||||||
Walks the radix tree token by token, recording the deepest match.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: The token sequence to match against.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (prefix_len, slot, slot_ver):
|
|
||||||
prefix_len: Number of matching tokens (0 if no match).
|
|
||||||
slot: KV cache slot of the matched prefix (-1 if no match).
|
|
||||||
slot_ver: Version of that slot when the prefix was inserted.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
best_len, best_slot, best_ver = 0, -1, 0
|
|
||||||
for i, tid in enumerate(token_ids):
|
|
||||||
nxt = node.children.get(tid)
|
|
||||||
if nxt is None:
|
|
||||||
break
|
|
||||||
node = nxt
|
|
||||||
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
|
|
||||||
node.last_access = time.time()
|
|
||||||
self._lru.move_to_end(id(node))
|
|
||||||
return best_len, best_slot, best_ver
|
|
||||||
|
|
||||||
def pin(self, token_ids: Tuple[int, ...]) -> None:
|
|
||||||
"""Increments the reference count of a cached prefix.
|
|
||||||
|
|
||||||
Called when a task reuses a cached prefix to prevent eviction.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: The token sequence whose node's ref_count to increment.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
for tid in token_ids:
|
|
||||||
nxt = node.children.get(tid)
|
|
||||||
if nxt is None:
|
|
||||||
return
|
|
||||||
node = nxt
|
|
||||||
node.ref_count += 1
|
|
||||||
|
|
||||||
def release(self, token_ids: Tuple[int, ...]) -> None:
|
|
||||||
"""Decrements the reference count of a cached prefix.
|
|
||||||
|
|
||||||
The node's slot is preserved even when ref_count reaches zero,
|
|
||||||
allowing future tasks to reuse the slot directly if it remains free.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: The token sequence whose node's ref_count to decrement.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
for tid in token_ids:
|
|
||||||
nxt = node.children.get(tid)
|
|
||||||
if nxt is None:
|
|
||||||
return
|
|
||||||
node = nxt
|
|
||||||
if node.ref_count > 0:
|
|
||||||
node.ref_count -= 1
|
|
||||||
|
|
||||||
def copy_kv(
|
|
||||||
self,
|
|
||||||
token_ids: Tuple[int, ...],
|
|
||||||
target_slot: int,
|
|
||||||
kv_cache: Tuple[Tensor, Tensor],
|
|
||||||
n_layers: int,
|
|
||||||
) -> None:
|
|
||||||
"""Copies cached KV data from the source slot to a target slot.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token_ids: The prefix token sequence identifying the source cache node.
|
|
||||||
target_slot: The destination KV cache slot to copy into.
|
|
||||||
kv_cache: Tuple of (k_cache, v_cache) tensors.
|
|
||||||
n_layers: Number of transformer layers to copy.
|
|
||||||
"""
|
|
||||||
node = self.root
|
|
||||||
for tid in token_ids:
|
|
||||||
nxt = node.children.get(tid)
|
|
||||||
if nxt is None:
|
|
||||||
return
|
|
||||||
node = nxt
|
|
||||||
src_slot = node.slot
|
|
||||||
if src_slot < 0:
|
|
||||||
return
|
|
||||||
prefix_len = len(token_ids)
|
|
||||||
k_cache, v_cache = kv_cache
|
|
||||||
for li in range(n_layers):
|
|
||||||
k_cache[target_slot, :prefix_len, li].copy_(
|
|
||||||
k_cache[src_slot, :prefix_len, li]
|
|
||||||
)
|
|
||||||
v_cache[target_slot, :prefix_len, li].copy_(
|
|
||||||
v_cache[src_slot, :prefix_len, li]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _evict_if_needed(self) -> None:
|
|
||||||
"""Evicts least-recently-used nodes until under capacity.
|
|
||||||
|
|
||||||
Skips nodes with ref_count > 0 (still in use by active tasks).
|
|
||||||
Evicted nodes have their slot and children cleared.
|
|
||||||
"""
|
|
||||||
while len(self._lru) > self.max_capacity:
|
|
||||||
key, node = next(iter(self._lru.items()))
|
|
||||||
if node.ref_count > 0:
|
|
||||||
self._lru.move_to_end(key)
|
|
||||||
continue
|
|
||||||
self._lru.pop(key)
|
|
||||||
node.slot = -1
|
|
||||||
node.slot_ver = 0
|
|
||||||
node.children.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus:
|
|
||||||
"""Enum-like task states in the continuous batching lifecycle."""
|
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
|
|
@ -286,46 +116,6 @@ class Task:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
|
||||||
logits: Tensor,
|
|
||||||
temperature: float,
|
|
||||||
top_k: int,
|
|
||||||
top_p: float,
|
|
||||||
filter_value: float = -float("inf"),
|
|
||||||
) -> Tensor:
|
|
||||||
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
logits: Raw logits tensor of shape (batch, vocab_size).
|
|
||||||
temperature: Temperature scaling factor (1.0 = no scaling).
|
|
||||||
top_k: Keep only top-k logits (0 disables).
|
|
||||||
top_p: Nucleus probability threshold (1.0 disables).
|
|
||||||
filter_value: Value to assign to filtered-out positions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Modified logits tensor with same shape as input.
|
|
||||||
"""
|
|
||||||
logits = logits.clone()
|
|
||||||
if temperature != 1.0:
|
|
||||||
logits = logits / temperature
|
|
||||||
if top_k > 0:
|
|
||||||
top_k = min(top_k, logits.size(-1))
|
|
||||||
indices = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
|
|
||||||
logits[indices] = filter_value
|
|
||||||
if top_p < 1.0:
|
|
||||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
|
||||||
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
|
||||||
sorted_indices_to_remove = cum_probs > top_p
|
|
||||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
|
||||||
sorted_indices_to_remove[..., 0] = 0
|
|
||||||
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
|
|
||||||
indices_to_remove.scatter_(
|
|
||||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
|
||||||
)
|
|
||||||
logits[indices_to_remove] = filter_value
|
|
||||||
return logits
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
|
class InferenceScheduler:
|
||||||
"""Continuous batching scheduler for single-GPU inference.
|
"""Continuous batching scheduler for single-GPU inference.
|
||||||
|
|
||||||
|
|
@ -397,8 +187,7 @@ class InferenceScheduler:
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._free_slots = (1 << max_batch_size) - 1
|
self.slot_allocator = SlotAllocator(max_batch_size)
|
||||||
self._slot_ver: List[int] = [0] * max_batch_size
|
|
||||||
self.waiting_queue: List[Task] = []
|
self.waiting_queue: List[Task] = []
|
||||||
self.active_tasks: List[Task] = []
|
self.active_tasks: List[Task] = []
|
||||||
|
|
||||||
|
|
@ -410,18 +199,12 @@ class InferenceScheduler:
|
||||||
self._total_tokens = 0
|
self._total_tokens = 0
|
||||||
|
|
||||||
def _alloc_slot(self) -> int:
|
def _alloc_slot(self) -> int:
|
||||||
"""Allocates a free KV cache slot using a bitmask.
|
"""Allocates a free KV cache slot using the Object Pool.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Slot index on success, -1 if all slots are occupied.
|
Slot index on success, -1 if all slots are occupied.
|
||||||
"""
|
"""
|
||||||
lsb = self._free_slots & -self._free_slots
|
return self.slot_allocator.alloc()
|
||||||
if lsb == 0:
|
|
||||||
return -1
|
|
||||||
idx = lsb.bit_length() - 1
|
|
||||||
self._free_slots ^= lsb
|
|
||||||
self._slot_ver[idx] += 1
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def _free_slot(self, idx: int) -> None:
|
def _free_slot(self, idx: int) -> None:
|
||||||
"""Releases a KV cache slot back to the free pool.
|
"""Releases a KV cache slot back to the free pool.
|
||||||
|
|
@ -429,7 +212,7 @@ class InferenceScheduler:
|
||||||
Args:
|
Args:
|
||||||
idx: Slot index to free.
|
idx: Slot index to free.
|
||||||
"""
|
"""
|
||||||
self._free_slots |= 1 << idx
|
self.slot_allocator.free(idx)
|
||||||
self.seq_mask[idx, :] = False
|
self.seq_mask[idx, :] = False
|
||||||
|
|
||||||
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
|
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
|
||||||
|
|
@ -445,9 +228,9 @@ class InferenceScheduler:
|
||||||
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
|
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
|
||||||
"""
|
"""
|
||||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||||
if cached_slot >= 0 and (self._free_slots >> cached_slot) & 1:
|
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
|
||||||
if cached_ver == self._slot_ver[cached_slot]:
|
if cached_ver == self.slot_allocator.version(cached_slot):
|
||||||
self._free_slots ^= 1 << cached_slot
|
self.slot_allocator.occupy(cached_slot)
|
||||||
return cached_slot, True
|
return cached_slot, True
|
||||||
return -1, False
|
return -1, False
|
||||||
|
|
||||||
|
|
@ -588,7 +371,9 @@ class InferenceScheduler:
|
||||||
if task.prefix_len > 0 and not reused:
|
if task.prefix_len > 0 and not reused:
|
||||||
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
prefix = tuple(task.prompt_ids[: task.prefix_len])
|
||||||
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
|
||||||
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
|
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
|
||||||
|
cached_slot
|
||||||
|
):
|
||||||
self.prefix_cache.pin(prefix)
|
self.prefix_cache.pin(prefix)
|
||||||
self.prefix_cache.copy_kv(
|
self.prefix_cache.copy_kv(
|
||||||
prefix, slot, self.kv_cache, self._n_layers
|
prefix, slot, self.kv_cache, self._n_layers
|
||||||
|
|
@ -663,7 +448,7 @@ class InferenceScheduler:
|
||||||
t.input_tokens = len(t.prompt_ids)
|
t.input_tokens = len(t.prompt_ids)
|
||||||
t.output_tokens = 0
|
t.output_tokens = 0
|
||||||
self.prefix_cache.insert(
|
self.prefix_cache.insert(
|
||||||
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
|
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
|
||||||
)
|
)
|
||||||
if t.slot >= 0:
|
if t.slot >= 0:
|
||||||
self.seq_mask[t.slot, : t.input_tokens] = True
|
self.seq_mask[t.slot, : t.input_tokens] = True
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ def chat():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
|
||||||
model.to(device="cuda", dtype=torch.bfloat16)
|
model.to(device="cuda", dtype=torch.bfloat16)
|
||||||
|
|
||||||
messages = []
|
messages = [{"role": "system", "content": "You are a helpful assistant."}]
|
||||||
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
engine = InferenceEngine(model=model, tokenizer=tokenizer)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
|
|
||||||
|
|
@ -6,9 +6,9 @@ from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from astrai.inference.cache import PrefixCacheManager
|
||||||
from astrai.inference.scheduler import (
|
from astrai.inference.scheduler import (
|
||||||
InferenceScheduler,
|
InferenceScheduler,
|
||||||
PrefixCacheManager,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue