refactor: 设计模式优化 inference 模块导入结构

- 新建 cache.py:SlotAllocator 对象池 + PrefixCacheManager

- 新建 sampling.py:Temperature/TopK/TopP 可组合策略

- TaskStatus 改用 Enum,GenerationParams 值对象模式

- _STOP 移至 cache.py,解除 engine→scheduler 轻量耦合

- 更新测试导入路径,ruff 格式检查通过
This commit is contained in:
ViperEkura 2026-05-08 16:55:24 +08:00
parent c4401512f2
commit 44d7a4e959
7 changed files with 470 additions and 257 deletions

View File

@ -1,25 +1,46 @@
"""Inference module for continuous batching."""
"""Inference module for continuous batching.
Layers:
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum
- cache.py: Object Pool (SlotAllocator), PrefixCacheManager
- sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
- server.py: FastAPI HTTP server (OpenAI-compatible endpoints)
"""
from astrai.inference.engine import (
GenerationParams,
GenerationRequest,
InferenceEngine,
)
from astrai.inference.sampling import (
BaseSamplingStrategy,
SamplingPipeline,
TemperatureStrategy,
TopKStrategy,
TopPStrategy,
apply_sampling_strategies,
)
from astrai.inference.scheduler import (
InferenceScheduler,
Task,
TaskStatus,
apply_sampling_strategies,
)
__all__ = [
# Engine
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Scheduler
"InferenceScheduler",
"Task",
"TaskStatus",
# Request
"GenerationRequest",
# Sampling
# Sampling (Strategy pattern)
"apply_sampling_strategies",
"BaseSamplingStrategy",
"TemperatureStrategy",
"TopKStrategy",
"TopPStrategy",
"SamplingPipeline",
]

241
astrai/inference/cache.py Normal file
View File

@ -0,0 +1,241 @@
"""KV cache slot allocation and prefix cache management.
Provides:
- SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask.
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse.
"""
import time
from collections import OrderedDict
from typing import Dict, List, Tuple
from torch import Tensor
_STOP = object()
class _RadixNode:
"""Internal node for the radix tree prefix cache.
Attributes:
children: Mapping from token ID to child node.
slot: KV cache slot index for the prefix ending at this node.
slot_ver: Version counter of the slot at insertion time.
ref_count: Number of tasks currently referencing this node.
last_access: Timestamp of the most recent access (for LRU ordering).
"""
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
def __init__(self):
self.children: Dict[int, "_RadixNode"] = {}
self.slot: int = -1
self.slot_ver: int = 0
self.ref_count: int = 0
self.last_access: float = 0.0
class SlotAllocator:
"""KV cache slot allocator using bitmask for O(1) alloc/free.
Implements the Object Pool pattern: pre-allocated KV cache slots
are managed via a bitmask, providing constant-time allocation and
deallocation with version counters for staleness detection.
"""
def __init__(self, max_slots: int):
self._max_slots = max_slots
self._free_mask = (1 << max_slots) - 1
self._versions: List[int] = [0] * max_slots
def alloc(self) -> int:
"""Allocates a free slot.
Returns:
Slot index on success, -1 if all slots are occupied.
"""
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._versions[idx] += 1
return idx
def free(self, idx: int) -> None:
"""Releases a slot back to the free pool."""
self._free_mask |= 1 << idx
def occupy(self, idx: int) -> None:
"""Marks a currently free slot as occupied without bumping its version.
Used for direct slot reuse when a prefix-cached slot is still valid.
"""
self._free_mask ^= 1 << idx
def is_free(self, idx: int) -> bool:
"""Checks whether a slot is currently free."""
return (self._free_mask >> idx) & 1 == 1
def version(self, idx: int) -> int:
"""Returns the current version counter for a slot."""
return self._versions[idx]
@property
def free_count(self) -> int:
"""Returns the number of currently free slots."""
return self._free_mask.bit_count()
class PrefixCacheManager:
"""Radix-tree prefix cache with LRU eviction.
Maps token ID sequences to KV cache slots. Intermediate tree nodes
also store slot information, allowing direct slot reuse when the
cached slot is free and its version matches (no intervening writes).
"""
def __init__(self, max_capacity: int = 1000):
"""Initializes the prefix cache.
Args:
max_capacity: Maximum number of nodes in the LRU list.
"""
self.root = _RadixNode()
self.max_capacity = max_capacity
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
"""Inserts a token sequence into the prefix cache.
Every node along the path records the slot and its version,
enabling direct slot reuse for partial prefix matches.
Args:
token_ids: The token ID sequence to cache.
slot: The KV cache slot containing this prefix's computed keys/values.
slot_ver: The slot version at insertion time, used for staleness detection.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
nxt = _RadixNode()
node.children[tid] = nxt
node = nxt
node.slot = slot
node.slot_ver = slot_ver
node.last_access = time.time()
self._lru[id(node)] = node
node.ref_count += 1
self._evict_if_needed()
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
"""Finds the longest matching prefix in the cache.
Walks the radix tree token by token, recording the deepest match.
Args:
token_ids: The token sequence to match against.
Returns:
Tuple of (prefix_len, slot, slot_ver):
prefix_len: Number of matching tokens (0 if no match).
slot: KV cache slot of the matched prefix (-1 if no match).
slot_ver: Version of that slot when the prefix was inserted.
"""
node = self.root
best_len, best_slot, best_ver = 0, -1, 0
for i, tid in enumerate(token_ids):
nxt = node.children.get(tid)
if nxt is None:
break
node = nxt
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
node.last_access = time.time()
self._lru.move_to_end(id(node))
return best_len, best_slot, best_ver
def pin(self, token_ids: Tuple[int, ...]) -> None:
"""Increments the reference count of a cached prefix.
Called when a task reuses a cached prefix to prevent eviction.
Args:
token_ids: The token sequence whose node's ref_count to increment.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
node.ref_count += 1
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Decrements the reference count of a cached prefix.
The node's slot is preserved even when ref_count reaches zero,
allowing future tasks to reuse the slot directly if it remains free.
Args:
token_ids: The token sequence whose node's ref_count to decrement.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
if node.ref_count > 0:
node.ref_count -= 1
def copy_kv(
self,
token_ids: Tuple[int, ...],
target_slot: int,
kv_cache: Tuple[Tensor, Tensor],
n_layers: int,
) -> None:
"""Copies cached KV data from the source slot to a target slot.
Args:
token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into.
kv_cache: Tuple of (k_cache, v_cache) tensors.
n_layers: Number of transformer layers to copy.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
src_slot = node.slot
if src_slot < 0:
return
prefix_len = len(token_ids)
k_cache, v_cache = kv_cache
for li in range(n_layers):
k_cache[target_slot, :prefix_len, li].copy_(
k_cache[src_slot, :prefix_len, li]
)
v_cache[target_slot, :prefix_len, li].copy_(
v_cache[src_slot, :prefix_len, li]
)
def _evict_if_needed(self) -> None:
"""Evicts least-recently-used nodes until under capacity.
Skips nodes with ref_count > 0 (still in use by active tasks).
Evicted nodes have their slot and children cleared.
"""
while len(self._lru) > self.max_capacity:
key, node = next(iter(self._lru.items()))
if node.ref_count > 0:
self._lru.move_to_end(key)
continue
self._lru.pop(key)
node.slot = -1
node.slot_ver = 0
node.children.clear()

View File

@ -1,22 +1,41 @@
"""Unified inference engine for continuous batching."""
"""Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
from astrai.inference.scheduler import _STOP, InferenceScheduler
from astrai.inference.cache import _STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
class GenerationRequest:
"""Request parameters for text generation.
Encapsulates messages, sampling parameters, and streaming preference
for a single generation request.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__(
@ -39,13 +58,31 @@ class GenerationRequest:
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_len = max_len
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.stream = stream
self._validate()
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
@ -296,10 +333,10 @@ class InferenceEngine:
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.max_len,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
max_tokens=request.params.max_tokens,
temperature=request.params.temperature,
top_p=request.params.top_p,
top_k=request.params.top_k,
)
def _generate_streaming(

View File

@ -0,0 +1,129 @@
"""Composable sampling strategies for logit transformation.
Implements the Strategy pattern: each sampling technique
(temperature, top-k, top-p) is a pluggable strategy that
can be composed into a pipeline.
"""
from abc import ABC, abstractmethod
from typing import List
import torch
from torch import Tensor
class BaseSamplingStrategy(ABC):
"""Abstract base for a logit transformation strategy."""
@abstractmethod
def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
"""Applies the strategy to logits.
Args:
logits: Raw logits tensor (batch, vocab_size).
filter_value: Value assigned to filtered-out positions.
Returns:
Transformed logits tensor (may be the same or a new tensor).
"""
class TemperatureStrategy(BaseSamplingStrategy):
"""Divides logits by temperature to control randomness."""
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def apply(self, logits, filter_value=-float("inf")):
if self.temperature != 1.0:
logits = logits / self.temperature
return logits
class TopKStrategy(BaseSamplingStrategy):
"""Keeps only the top-k logits, setting the rest to filter_value."""
def __init__(self, top_k: int = 0):
self.top_k = top_k
def apply(self, logits, filter_value=-float("inf")):
if self.top_k > 0:
k = min(self.top_k, logits.size(-1))
topk_vals = torch.topk(logits, k, dim=-1)[0]
threshold = topk_vals[..., -1, None]
indices = logits < threshold
logits[indices] = filter_value
return logits
class TopPStrategy(BaseSamplingStrategy):
"""Nucleus (top-p) filtering: keeps the smallest set of tokens whose
cumulative probability exceeds top_p."""
def __init__(self, top_p: float = 1.0):
self.top_p = top_p
def apply(self, logits, filter_value=-float("inf")):
if self.top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > self.top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
..., :-1
].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class SamplingPipeline(BaseSamplingStrategy):
"""Composes multiple sampling strategies into a single transformation.
Strategies are applied sequentially in the order they are provided,
matching the original temperature top-k top-p ordering.
"""
def __init__(self, strategies: List[BaseSamplingStrategy]):
self.strategies = strategies
def apply(self, logits, filter_value=-float("inf")):
logits = logits.clone()
for strategy in self.strategies:
logits = strategy.apply(logits, filter_value)
return logits
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
Backward-compatible function that delegates to the Strategy pattern
pipeline with TemperatureStrategy TopKStrategy TopPStrategy ordering.
Args:
logits: Raw logits tensor of shape (batch, vocab_size).
temperature: Temperature scaling factor (1.0 = no scaling).
top_k: Keep only top-k logits (0 disables).
top_p: Nucleus probability threshold (1.0 disables).
filter_value: Value to assign to filtered-out positions.
Returns:
Modified logits tensor with same shape as input.
"""
pipeline = SamplingPipeline(
[
TemperatureStrategy(temperature),
TopKStrategy(top_k),
TopPStrategy(top_p),
]
)
return pipeline.apply(logits, filter_value)

View File

@ -1,200 +1,30 @@
"""Inference scheduler for single-GPU continuous batching."""
"""Inference scheduler for single-GPU continuous batching.
Splits scheduling concerns across modules:
- cache.py: SlotAllocator (Object Pool), PrefixCacheManager
- sampling.py: Strategy-pattern logit transformations
"""
import logging
import threading
import time
import uuid
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator
from astrai.inference.sampling import apply_sampling_strategies
from astrai.model.automodel import AutoModel
from astrai.tokenize import AutoTokenizer
logger = logging.getLogger(__name__)
_STOP = object()
class _RadixNode:
"""Internal node for the radix tree prefix cache.
Attributes:
children: Mapping from token ID to child node.
slot: KV cache slot index for the prefix ending at this node.
slot_ver: Version counter of the slot at insertion time.
ref_count: Number of tasks currently referencing this node.
last_access: Timestamp of the most recent access (for LRU ordering).
"""
__slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")
def __init__(self):
self.children: Dict[int, "_RadixNode"] = {}
self.slot: int = -1
self.slot_ver: int = 0
self.ref_count: int = 0
self.last_access: float = 0.0
class PrefixCacheManager:
"""Radix-tree prefix cache with LRU eviction.
Maps token ID sequences to KV cache slots. Intermediate tree nodes
also store slot information, allowing direct slot reuse when the
cached slot is free and its version matches (no intervening writes).
"""
def __init__(self, max_capacity: int = 1000):
"""Initializes the prefix cache.
Args:
max_capacity: Maximum number of nodes in the LRU list.
"""
self.root = _RadixNode()
self.max_capacity = max_capacity
self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
"""Inserts a token sequence into the prefix cache.
Every node along the path records the slot and its version,
enabling direct slot reuse for partial prefix matches.
Args:
token_ids: The token ID sequence to cache.
slot: The KV cache slot containing this prefix's computed keys/values.
slot_ver: The slot version at insertion time, used for staleness detection.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
nxt = _RadixNode()
node.children[tid] = nxt
node = nxt
node.slot = slot
node.slot_ver = slot_ver
node.last_access = time.time()
self._lru[id(node)] = node
node.ref_count += 1
self._evict_if_needed()
def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
"""Finds the longest matching prefix in the cache.
Walks the radix tree token by token, recording the deepest match.
Args:
token_ids: The token sequence to match against.
Returns:
Tuple of (prefix_len, slot, slot_ver):
prefix_len: Number of matching tokens (0 if no match).
slot: KV cache slot of the matched prefix (-1 if no match).
slot_ver: Version of that slot when the prefix was inserted.
"""
node = self.root
best_len, best_slot, best_ver = 0, -1, 0
for i, tid in enumerate(token_ids):
nxt = node.children.get(tid)
if nxt is None:
break
node = nxt
best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver
node.last_access = time.time()
self._lru.move_to_end(id(node))
return best_len, best_slot, best_ver
def pin(self, token_ids: Tuple[int, ...]) -> None:
"""Increments the reference count of a cached prefix.
Called when a task reuses a cached prefix to prevent eviction.
Args:
token_ids: The token sequence whose node's ref_count to increment.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
node.ref_count += 1
def release(self, token_ids: Tuple[int, ...]) -> None:
"""Decrements the reference count of a cached prefix.
The node's slot is preserved even when ref_count reaches zero,
allowing future tasks to reuse the slot directly if it remains free.
Args:
token_ids: The token sequence whose node's ref_count to decrement.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
if node.ref_count > 0:
node.ref_count -= 1
def copy_kv(
self,
token_ids: Tuple[int, ...],
target_slot: int,
kv_cache: Tuple[Tensor, Tensor],
n_layers: int,
) -> None:
"""Copies cached KV data from the source slot to a target slot.
Args:
token_ids: The prefix token sequence identifying the source cache node.
target_slot: The destination KV cache slot to copy into.
kv_cache: Tuple of (k_cache, v_cache) tensors.
n_layers: Number of transformer layers to copy.
"""
node = self.root
for tid in token_ids:
nxt = node.children.get(tid)
if nxt is None:
return
node = nxt
src_slot = node.slot
if src_slot < 0:
return
prefix_len = len(token_ids)
k_cache, v_cache = kv_cache
for li in range(n_layers):
k_cache[target_slot, :prefix_len, li].copy_(
k_cache[src_slot, :prefix_len, li]
)
v_cache[target_slot, :prefix_len, li].copy_(
v_cache[src_slot, :prefix_len, li]
)
def _evict_if_needed(self) -> None:
"""Evicts least-recently-used nodes until under capacity.
Skips nodes with ref_count > 0 (still in use by active tasks).
Evicted nodes have their slot and children cleared.
"""
while len(self._lru) > self.max_capacity:
key, node = next(iter(self._lru.items()))
if node.ref_count > 0:
self._lru.move_to_end(key)
continue
self._lru.pop(key)
node.slot = -1
node.slot_ver = 0
node.children.clear()
class TaskStatus:
"""Enum-like task states in the continuous batching lifecycle."""
class TaskStatus(Enum):
"""Task states in the continuous batching lifecycle."""
PENDING = "pending"
RUNNING = "running"
@ -286,46 +116,6 @@ class Task:
return False
def apply_sampling_strategies(
logits: Tensor,
temperature: float,
top_k: int,
top_p: float,
filter_value: float = -float("inf"),
) -> Tensor:
"""Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.
Args:
logits: Raw logits tensor of shape (batch, vocab_size).
temperature: Temperature scaling factor (1.0 = no scaling).
top_k: Keep only top-k logits (0 disables).
top_p: Nucleus probability threshold (1.0 disables).
filter_value: Value to assign to filtered-out positions.
Returns:
Modified logits tensor with same shape as input.
"""
logits = logits.clone()
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
top_k = min(top_k, logits.size(-1))
indices = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None]
logits[indices] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits[indices_to_remove] = filter_value
return logits
class InferenceScheduler:
"""Continuous batching scheduler for single-GPU inference.
@ -397,8 +187,7 @@ class InferenceScheduler:
dtype=torch.bool,
)
self._free_slots = (1 << max_batch_size) - 1
self._slot_ver: List[int] = [0] * max_batch_size
self.slot_allocator = SlotAllocator(max_batch_size)
self.waiting_queue: List[Task] = []
self.active_tasks: List[Task] = []
@ -410,18 +199,12 @@ class InferenceScheduler:
self._total_tokens = 0
def _alloc_slot(self) -> int:
"""Allocates a free KV cache slot using a bitmask.
"""Allocates a free KV cache slot using the Object Pool.
Returns:
Slot index on success, -1 if all slots are occupied.
"""
lsb = self._free_slots & -self._free_slots
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_slots ^= lsb
self._slot_ver[idx] += 1
return idx
return self.slot_allocator.alloc()
def _free_slot(self, idx: int) -> None:
"""Releases a KV cache slot back to the free pool.
@ -429,7 +212,7 @@ class InferenceScheduler:
Args:
idx: Slot index to free.
"""
self._free_slots |= 1 << idx
self.slot_allocator.free(idx)
self.seq_mask[idx, :] = False
def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]:
@ -445,9 +228,9 @@ class InferenceScheduler:
Tuple of (slot, True) on success, or (-1, False) if reuse is not possible.
"""
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and (self._free_slots >> cached_slot) & 1:
if cached_ver == self._slot_ver[cached_slot]:
self._free_slots ^= 1 << cached_slot
if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot):
if cached_ver == self.slot_allocator.version(cached_slot):
self.slot_allocator.occupy(cached_slot)
return cached_slot, True
return -1, False
@ -588,7 +371,9 @@ class InferenceScheduler:
if task.prefix_len > 0 and not reused:
prefix = tuple(task.prompt_ids[: task.prefix_len])
_plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix))
if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]:
if cached_slot >= 0 and cached_ver == self.slot_allocator.version(
cached_slot
):
self.prefix_cache.pin(prefix)
self.prefix_cache.copy_kv(
prefix, slot, self.kv_cache, self._n_layers
@ -663,7 +448,7 @@ class InferenceScheduler:
t.input_tokens = len(t.prompt_ids)
t.output_tokens = 0
self.prefix_cache.insert(
tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot]
tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot)
)
if t.slot >= 0:
self.seq_mask[t.slot, : t.input_tokens] = True

View File

@ -15,7 +15,7 @@ def chat():
tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT)
model.to(device="cuda", dtype=torch.bfloat16)
messages = []
messages = [{"role": "system", "content": "You are a helpful assistant."}]
engine = InferenceEngine(model=model, tokenizer=tokenizer)
while True:

View File

@ -6,9 +6,9 @@ from unittest.mock import MagicMock, patch
import pytest
from astrai.inference.cache import PrefixCacheManager
from astrai.inference.scheduler import (
InferenceScheduler,
PrefixCacheManager,
)