diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index b3f3e2e..e867d36 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -1,25 +1,46 @@ -"""Inference module for continuous batching.""" +"""Inference module for continuous batching. + +Layers: + - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) + - scheduler.py: Continuous-batching loop, Task state machine, TaskStatus enum + - cache.py: Object Pool (SlotAllocator), PrefixCacheManager + - sampling.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) + - server.py: FastAPI HTTP server (OpenAI-compatible endpoints) +""" from astrai.inference.engine import ( + GenerationParams, GenerationRequest, InferenceEngine, ) +from astrai.inference.sampling import ( + BaseSamplingStrategy, + SamplingPipeline, + TemperatureStrategy, + TopKStrategy, + TopPStrategy, + apply_sampling_strategies, +) from astrai.inference.scheduler import ( InferenceScheduler, Task, TaskStatus, - apply_sampling_strategies, ) __all__ = [ - # Engine + # Engine / Requests "InferenceEngine", + "GenerationRequest", + "GenerationParams", # Scheduler "InferenceScheduler", "Task", "TaskStatus", - # Request - "GenerationRequest", - # Sampling + # Sampling (Strategy pattern) "apply_sampling_strategies", + "BaseSamplingStrategy", + "TemperatureStrategy", + "TopKStrategy", + "TopPStrategy", + "SamplingPipeline", ] diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py new file mode 100644 index 0000000..28c96cd --- /dev/null +++ b/astrai/inference/cache.py @@ -0,0 +1,241 @@ +"""KV cache slot allocation and prefix cache management. + +Provides: + - SlotAllocator: Object Pool pattern for O(1) KV cache slot alloc/free via bitmask. + - PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache reuse. 
import time
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

from torch import Tensor

# Unique sentinel object. engine.py and scheduler.py import it from this
# module; presumably it marks end-of-stream for consumers -- confirm at the
# use sites.
_STOP = object()


class _RadixNode:
    """Internal node for the radix tree prefix cache.

    Attributes:
        children: Mapping from token ID to child node.
        slot: KV cache slot index for the prefix ending at this node
            (-1 when unset, or after the node has been evicted).
        slot_ver: Version counter of the slot at insertion time.
        ref_count: Number of tasks currently referencing this node.
        last_access: Timestamp of the most recent access.
    """

    __slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")

    def __init__(self):
        self.children: Dict[int, "_RadixNode"] = {}
        self.slot: int = -1
        self.slot_ver: int = 0
        self.ref_count: int = 0
        self.last_access: float = 0.0


class SlotAllocator:
    """KV cache slot allocator backed by an integer bitmask.

    Bit ``i`` of the mask is 1 while slot ``i`` is free. Every slot also has
    a version counter that is bumped on each ``alloc``, so holders of an old
    ``(slot, version)`` pair can detect that the slot has been handed out
    again in the meantime.
    """

    def __init__(self, max_slots: int):
        """Creates an allocator with ``max_slots`` slots, all initially free."""
        self._max_slots = max_slots
        self._free_mask = (1 << max_slots) - 1
        self._versions: List[int] = [0] * max_slots

    def alloc(self) -> int:
        """Allocates the lowest-numbered free slot and bumps its version.

        Returns:
            Slot index on success, -1 if all slots are occupied.
        """
        # x & -x isolates the lowest set bit, i.e. the first free slot.
        lsb = self._free_mask & -self._free_mask
        if lsb == 0:
            return -1
        idx = lsb.bit_length() - 1
        self._free_mask ^= lsb
        self._versions[idx] += 1
        return idx

    def free(self, idx: int) -> None:
        """Releases a slot back to the free pool (idempotent)."""
        self._free_mask |= 1 << idx

    def occupy(self, idx: int) -> None:
        """Marks a slot as occupied without bumping its version.

        Used for direct slot reuse when a prefix-cached slot is still valid.
        """
        # Clear the bit instead of toggling it: with XOR, occupying a slot
        # that was already occupied would silently put it back in the free
        # pool and let it be handed out twice.
        self._free_mask &= ~(1 << idx)

    def is_free(self, idx: int) -> bool:
        """Checks whether a slot is currently free."""
        return (self._free_mask >> idx) & 1 == 1

    def version(self, idx: int) -> int:
        """Returns the current version counter for a slot."""
        return self._versions[idx]

    @property
    def free_count(self) -> int:
        """Returns the number of currently free slots."""
        return self._free_mask.bit_count()


class PrefixCacheManager:
    """Radix-tree prefix cache with LRU eviction.

    Maps token ID sequences to KV cache slots. Every node on an inserted
    path records the slot and its version, so a partial prefix match can
    still point at a slot.

    An evicted node stays linked from its parent with ``slot == -1`` and no
    children; a later ``insert`` through it revives it.
    """

    def __init__(self, max_capacity: int = 1000):
        """Initializes the prefix cache.

        Args:
            max_capacity: Maximum number of nodes tracked in the LRU list.
                This is a soft limit: pinned nodes are never evicted, so
                the list can exceed it while enough nodes are pinned.
        """
        self.root = _RadixNode()
        self.max_capacity = max_capacity
        # Keyed by id(node); ordered oldest -> most recently used.
        self._lru: "OrderedDict[int, _RadixNode]" = OrderedDict()

    def _walk(self, token_ids) -> Optional[_RadixNode]:
        """Returns the node reached by ``token_ids``, or None if the path
        is not fully present. An empty sequence yields the root."""
        node = self.root
        for tid in token_ids:
            node = node.children.get(tid)
            if node is None:
                return None
        return node

    def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
        """Inserts a token sequence into the prefix cache.

        Every node along the path records the slot and its version. The
        final node's reference count is incremented, after which eviction
        runs if the LRU list is over capacity.

        Args:
            token_ids: The token ID sequence to cache.
            slot: The KV cache slot holding this prefix's keys/values.
            slot_ver: The slot version at insertion time.
        """
        node = self.root
        for tid in token_ids:
            nxt = node.children.get(tid)
            if nxt is None:
                nxt = _RadixNode()
                node.children[tid] = nxt
            node = nxt
            node.slot = slot
            node.slot_ver = slot_ver
            node.last_access = time.time()
            self._lru[id(node)] = node
        node.ref_count += 1
        self._evict_if_needed()

    def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
        """Finds the longest matching prefix in the cache.

        Walks the tree token by token and reports the deepest node reached,
        refreshing the LRU position of every tracked node on the way.

        Args:
            token_ids: The token sequence to match against.

        Returns:
            Tuple of (prefix_len, slot, slot_ver) for the deepest node
            reached; (0, -1, 0) if not even the first token matches. The
            slot is -1 when the deepest node has been evicted.
        """
        node = self.root
        best_len, best_slot, best_ver = 0, -1, 0
        for depth, tid in enumerate(token_ids, start=1):
            nxt = node.children.get(tid)
            if nxt is None:
                break
            node = nxt
            best_len, best_slot, best_ver = depth, node.slot, node.slot_ver
            node.last_access = time.time()
            key = id(node)
            # Evicted nodes are still reachable from their parent but are
            # no longer tracked; an unconditional move_to_end() on them
            # raised KeyError.
            if key in self._lru:
                self._lru.move_to_end(key)
        return best_len, best_slot, best_ver

    def pin(self, token_ids: Tuple[int, ...]) -> None:
        """Increments the reference count of a cached prefix.

        Does nothing if the full path is not present.

        Args:
            token_ids: The token sequence whose node's ref_count to increment.
        """
        node = self._walk(token_ids)
        if node is not None:
            node.ref_count += 1

    def release(self, token_ids: Tuple[int, ...]) -> None:
        """Decrements the reference count of a cached prefix (floor 0).

        The node's slot is left untouched when the count reaches zero.
        Does nothing if the full path is not present.

        Args:
            token_ids: The token sequence whose node's ref_count to decrement.
        """
        node = self._walk(token_ids)
        if node is not None and node.ref_count > 0:
            node.ref_count -= 1

    def copy_kv(
        self,
        token_ids: Tuple[int, ...],
        target_slot: int,
        kv_cache: Tuple[Tensor, Tensor],
        n_layers: int,
    ) -> None:
        """Copies cached KV data from the source slot to a target slot.

        Silently does nothing when the prefix is not cached or its node has
        no valid slot.

        Args:
            token_ids: The prefix token sequence identifying the source node.
            target_slot: The destination KV cache slot to copy into.
            kv_cache: Tuple of (k_cache, v_cache) tensors, indexed here as
                [slot, position, layer, ...].
            n_layers: Number of transformer layers to copy.
        """
        node = self._walk(token_ids)
        if node is None:
            return
        src_slot = node.slot
        # Nothing to copy from, or source and destination are the same rows.
        if src_slot < 0 or src_slot == target_slot:
            return
        prefix_len = len(token_ids)
        k_cache, v_cache = kv_cache
        for li in range(n_layers):
            k_cache[target_slot, :prefix_len, li].copy_(
                k_cache[src_slot, :prefix_len, li]
            )
            v_cache[target_slot, :prefix_len, li].copy_(
                v_cache[src_slot, :prefix_len, li]
            )

    def _forget_subtree(self, node: _RadixNode) -> None:
        """Stops tracking every descendant of ``node`` in the LRU list."""
        stack = list(node.children.values())
        while stack:
            child = stack.pop()
            self._lru.pop(id(child), None)
            stack.extend(child.children.values())

    def _evict_if_needed(self) -> None:
        """Evicts least-recently-used nodes until under capacity.

        Makes a single pass over the LRU list, oldest first. Nodes with
        ref_count > 0 are skipped (and moved to the recent end); if only
        pinned nodes remain the pass simply ends, leaving the list over
        capacity instead of spinning forever. Evicted nodes have their slot
        and children cleared, and their detached descendants are dropped
        from the LRU list as well.
        """
        if len(self._lru) <= self.max_capacity:
            return
        # Iterate a snapshot so each entry is examined at most once.
        for key in list(self._lru):
            if len(self._lru) <= self.max_capacity:
                break
            node = self._lru.get(key)
            if node is None:
                # Already dropped as a descendant of an earlier eviction.
                continue
            if node.ref_count > 0:
                self._lru.move_to_end(key)
                continue
            del self._lru[key]
            # NOTE(review): only the node's own ref_count is checked, so a
            # pinned descendant is still lost together with its ancestor.
            # Protecting it needs path-wide pinning -- confirm with callers.
            self._forget_subtree(node)
            node.slot = -1
            node.slot_ver = 0
            node.children.clear()
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List

import torch
from torch import Tensor


@dataclass(frozen=True)
class GenerationParams:
    """Immutable value object for sampling hyperparameters.

    Attributes:
        top_k: Keep only the top-k logits (0 disables the filter).
        top_p: Nucleus probability threshold (1.0 disables the filter).
        temperature: Temperature scaling factor (1.0 = no scaling).
        max_tokens: Length limit that GenerationRequest fills from its
            ``max_len`` argument and forwards to the engine as
            ``max_tokens``.
    """

    top_k: int = 50
    top_p: float = 1.0
    temperature: float = 1.0
    max_tokens: int = 1024


class BaseSamplingStrategy(ABC):
    """Abstract base for a logit transformation strategy."""

    @abstractmethod
    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Applies the strategy to logits.

        Args:
            logits: Raw logits tensor whose last dimension is the vocabulary.
            filter_value: Value assigned to filtered-out positions.

        Returns:
            Transformed logits tensor (may be the same or a new tensor).
        """


class TemperatureStrategy(BaseSamplingStrategy):
    """Divides logits by temperature to control randomness."""

    def __init__(self, temperature: float = 1.0):
        self.temperature = temperature

    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Returns ``logits / temperature``; the input itself when the
        temperature is exactly 1.0."""
        if self.temperature != 1.0:
            logits = logits / self.temperature
        return logits


class TopKStrategy(BaseSamplingStrategy):
    """Keeps only the top-k logits, setting the rest to filter_value."""

    def __init__(self, top_k: int = 0):
        self.top_k = top_k

    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Filters ``logits`` in place along the last dimension.

        Entries strictly below the k-th largest value are overwritten, so
        ties with the k-th value are kept. ``top_k <= 0`` disables the
        filter.
        """
        vocab = logits.size(-1)
        # An empty vocabulary axis has no k-th element to threshold on.
        if self.top_k > 0 and vocab > 0:
            k = min(self.top_k, vocab)
            topk_vals = torch.topk(logits, k, dim=-1)[0]
            threshold = topk_vals[..., -1, None]
            logits[logits < threshold] = filter_value
        return logits


class TopPStrategy(BaseSamplingStrategy):
    """Nucleus (top-p) filtering: keeps the smallest set of tokens whose
    cumulative probability exceeds top_p."""

    def __init__(self, top_p: float = 1.0):
        self.top_p = top_p

    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Filters ``logits`` in place along the last dimension.

        Works for any rank: every operation, including the scatter back to
        vocabulary order, runs over ``dim=-1``. ``top_p >= 1.0`` disables
        the filter. The highest-probability token is always kept.
        """
        if self.top_p < 1.0 and logits.size(-1) > 0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
            remove_sorted = cum_probs > self.top_p
            # Shift right by one so the token that crosses the threshold
            # is itself kept.
            remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
            remove_sorted[..., 0] = False
            remove = torch.zeros_like(logits, dtype=torch.bool)
            # dim=-1 (not a hard-coded 1) so rank-1 and rank-3 inputs map
            # back to vocabulary order correctly.
            remove.scatter_(dim=-1, index=sorted_indices, src=remove_sorted)
            logits[remove] = filter_value
        return logits


class SamplingPipeline(BaseSamplingStrategy):
    """Composes multiple sampling strategies into a single transformation.

    Strategies are applied sequentially in the order they are provided.
    """

    def __init__(self, strategies: List[BaseSamplingStrategy]):
        self.strategies = strategies

    def apply(self, logits: Tensor, filter_value: float = -float("inf")) -> Tensor:
        """Runs every strategy on a clone of ``logits``.

        The clone keeps the caller's tensor untouched even though individual
        strategies may write in place.
        """
        logits = logits.clone()
        for strategy in self.strategies:
            logits = strategy.apply(logits, filter_value)
        return logits


def apply_sampling_strategies(
    logits: Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
    filter_value: float = -float("inf"),
) -> Tensor:
    """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering.

    Backward-compatible wrapper that builds a
    TemperatureStrategy -> TopKStrategy -> TopPStrategy pipeline.

    Args:
        logits: Raw logits tensor; the last dimension is the vocabulary.
        temperature: Temperature scaling factor (1.0 = no scaling).
        top_k: Keep only top-k logits (0 disables).
        top_p: Nucleus probability threshold (1.0 disables).
        filter_value: Value to assign to filtered-out positions.

    Returns:
        A new logits tensor with the same shape as the input.
    """
    pipeline = SamplingPipeline(
        [
            TemperatureStrategy(temperature),
            TopKStrategy(top_k),
            TopPStrategy(top_p),
        ]
    )
    return pipeline.apply(logits, filter_value)
+ +Splits scheduling concerns across modules: + - cache.py: SlotAllocator (Object Pool), PrefixCacheManager + - sampling.py: Strategy-pattern logit transformations +""" import logging import threading import time import uuid -from collections import OrderedDict +from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import Tensor +from astrai.inference.cache import _STOP, PrefixCacheManager, SlotAllocator +from astrai.inference.sampling import apply_sampling_strategies from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) -_STOP = object() - -class _RadixNode: - """Internal node for the radix tree prefix cache. - - Attributes: - children: Mapping from token ID to child node. - slot: KV cache slot index for the prefix ending at this node. - slot_ver: Version counter of the slot at insertion time. - ref_count: Number of tasks currently referencing this node. - last_access: Timestamp of the most recent access (for LRU ordering). - """ - - __slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access") - - def __init__(self): - self.children: Dict[int, "_RadixNode"] = {} - self.slot: int = -1 - self.slot_ver: int = 0 - self.ref_count: int = 0 - self.last_access: float = 0.0 - - -class PrefixCacheManager: - """Radix-tree prefix cache with LRU eviction. - - Maps token ID sequences to KV cache slots. Intermediate tree nodes - also store slot information, allowing direct slot reuse when the - cached slot is free and its version matches (no intervening writes). - """ - - def __init__(self, max_capacity: int = 1000): - """Initializes the prefix cache. - - Args: - max_capacity: Maximum number of nodes in the LRU list. 
- """ - self.root = _RadixNode() - self.max_capacity = max_capacity - self._lru: OrderedDict[int, _RadixNode] = OrderedDict() - - def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None: - """Inserts a token sequence into the prefix cache. - - Every node along the path records the slot and its version, - enabling direct slot reuse for partial prefix matches. - - Args: - token_ids: The token ID sequence to cache. - slot: The KV cache slot containing this prefix's computed keys/values. - slot_ver: The slot version at insertion time, used for staleness detection. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - nxt = _RadixNode() - node.children[tid] = nxt - node = nxt - node.slot = slot - node.slot_ver = slot_ver - node.last_access = time.time() - self._lru[id(node)] = node - node.ref_count += 1 - self._evict_if_needed() - - def find(self, token_ids: List[int]) -> Tuple[int, int, int]: - """Finds the longest matching prefix in the cache. - - Walks the radix tree token by token, recording the deepest match. - - Args: - token_ids: The token sequence to match against. - - Returns: - Tuple of (prefix_len, slot, slot_ver): - prefix_len: Number of matching tokens (0 if no match). - slot: KV cache slot of the matched prefix (-1 if no match). - slot_ver: Version of that slot when the prefix was inserted. - """ - node = self.root - best_len, best_slot, best_ver = 0, -1, 0 - for i, tid in enumerate(token_ids): - nxt = node.children.get(tid) - if nxt is None: - break - node = nxt - best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver - node.last_access = time.time() - self._lru.move_to_end(id(node)) - return best_len, best_slot, best_ver - - def pin(self, token_ids: Tuple[int, ...]) -> None: - """Increments the reference count of a cached prefix. - - Called when a task reuses a cached prefix to prevent eviction. - - Args: - token_ids: The token sequence whose node's ref_count to increment. 
- """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - node.ref_count += 1 - - def release(self, token_ids: Tuple[int, ...]) -> None: - """Decrements the reference count of a cached prefix. - - The node's slot is preserved even when ref_count reaches zero, - allowing future tasks to reuse the slot directly if it remains free. - - Args: - token_ids: The token sequence whose node's ref_count to decrement. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - if node.ref_count > 0: - node.ref_count -= 1 - - def copy_kv( - self, - token_ids: Tuple[int, ...], - target_slot: int, - kv_cache: Tuple[Tensor, Tensor], - n_layers: int, - ) -> None: - """Copies cached KV data from the source slot to a target slot. - - Args: - token_ids: The prefix token sequence identifying the source cache node. - target_slot: The destination KV cache slot to copy into. - kv_cache: Tuple of (k_cache, v_cache) tensors. - n_layers: Number of transformer layers to copy. - """ - node = self.root - for tid in token_ids: - nxt = node.children.get(tid) - if nxt is None: - return - node = nxt - src_slot = node.slot - if src_slot < 0: - return - prefix_len = len(token_ids) - k_cache, v_cache = kv_cache - for li in range(n_layers): - k_cache[target_slot, :prefix_len, li].copy_( - k_cache[src_slot, :prefix_len, li] - ) - v_cache[target_slot, :prefix_len, li].copy_( - v_cache[src_slot, :prefix_len, li] - ) - - def _evict_if_needed(self) -> None: - """Evicts least-recently-used nodes until under capacity. - - Skips nodes with ref_count > 0 (still in use by active tasks). - Evicted nodes have their slot and children cleared. 
- """ - while len(self._lru) > self.max_capacity: - key, node = next(iter(self._lru.items())) - if node.ref_count > 0: - self._lru.move_to_end(key) - continue - self._lru.pop(key) - node.slot = -1 - node.slot_ver = 0 - node.children.clear() - - -class TaskStatus: - """Enum-like task states in the continuous batching lifecycle.""" +class TaskStatus(Enum): + """Task states in the continuous batching lifecycle.""" PENDING = "pending" RUNNING = "running" @@ -286,46 +116,6 @@ class Task: return False -def apply_sampling_strategies( - logits: Tensor, - temperature: float, - top_k: int, - top_p: float, - filter_value: float = -float("inf"), -) -> Tensor: - """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. - - Args: - logits: Raw logits tensor of shape (batch, vocab_size). - temperature: Temperature scaling factor (1.0 = no scaling). - top_k: Keep only top-k logits (0 disables). - top_p: Nucleus probability threshold (1.0 disables). - filter_value: Value to assign to filtered-out positions. - - Returns: - Modified logits tensor with same shape as input. - """ - logits = logits.clone() - if temperature != 1.0: - logits = logits / temperature - if top_k > 0: - top_k = min(top_k, logits.size(-1)) - indices = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] - logits[indices] = filter_value - if top_p < 1.0: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cum_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() - sorted_indices_to_remove[..., 0] = 0 - indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) - indices_to_remove.scatter_( - dim=1, index=sorted_indices, src=sorted_indices_to_remove - ) - logits[indices_to_remove] = filter_value - return logits - - class InferenceScheduler: """Continuous batching scheduler for single-GPU inference. 
@@ -397,8 +187,7 @@ class InferenceScheduler: dtype=torch.bool, ) - self._free_slots = (1 << max_batch_size) - 1 - self._slot_ver: List[int] = [0] * max_batch_size + self.slot_allocator = SlotAllocator(max_batch_size) self.waiting_queue: List[Task] = [] self.active_tasks: List[Task] = [] @@ -410,18 +199,12 @@ class InferenceScheduler: self._total_tokens = 0 def _alloc_slot(self) -> int: - """Allocates a free KV cache slot using a bitmask. + """Allocates a free KV cache slot using the Object Pool. Returns: Slot index on success, -1 if all slots are occupied. """ - lsb = self._free_slots & -self._free_slots - if lsb == 0: - return -1 - idx = lsb.bit_length() - 1 - self._free_slots ^= lsb - self._slot_ver[idx] += 1 - return idx + return self.slot_allocator.alloc() def _free_slot(self, idx: int) -> None: """Releases a KV cache slot back to the free pool. @@ -429,7 +212,7 @@ class InferenceScheduler: Args: idx: Slot index to free. """ - self._free_slots |= 1 << idx + self.slot_allocator.free(idx) self.seq_mask[idx, :] = False def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]: @@ -445,9 +228,9 @@ class InferenceScheduler: Tuple of (slot, True) on success, or (-1, False) if reuse is not possible. 
""" _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) - if cached_slot >= 0 and (self._free_slots >> cached_slot) & 1: - if cached_ver == self._slot_ver[cached_slot]: - self._free_slots ^= 1 << cached_slot + if cached_slot >= 0 and self.slot_allocator.is_free(cached_slot): + if cached_ver == self.slot_allocator.version(cached_slot): + self.slot_allocator.occupy(cached_slot) return cached_slot, True return -1, False @@ -588,7 +371,9 @@ class InferenceScheduler: if task.prefix_len > 0 and not reused: prefix = tuple(task.prompt_ids[: task.prefix_len]) _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) - if cached_slot >= 0 and cached_ver == self._slot_ver[cached_slot]: + if cached_slot >= 0 and cached_ver == self.slot_allocator.version( + cached_slot + ): self.prefix_cache.pin(prefix) self.prefix_cache.copy_kv( prefix, slot, self.kv_cache, self._n_layers @@ -663,7 +448,7 @@ class InferenceScheduler: t.input_tokens = len(t.prompt_ids) t.output_tokens = 0 self.prefix_cache.insert( - tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] + tuple(t.prompt_ids), t.slot, self.slot_allocator.version(t.slot) ) if t.slot >= 0: self.seq_mask[t.slot, : t.input_tokens] = True diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index a87a685..88bda06 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -15,7 +15,7 @@ def chat(): tokenizer = AutoTokenizer.from_pretrained(PARAMETER_ROOT) model.to(device="cuda", dtype=torch.bfloat16) - messages = [] + messages = [{"role": "system", "content": "You are a helpful assistant."}] engine = InferenceEngine(model=model, tokenizer=tokenizer) while True: diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index 28e4967..d006bbd 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -6,9 +6,9 @@ from unittest.mock import MagicMock, patch import pytest 
+from astrai.inference.cache import PrefixCacheManager from astrai.inference.scheduler import ( InferenceScheduler, - PrefixCacheManager, )