From 520de3ebe864ce18619f536122cd70ff44c168cc Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 6 May 2026 16:04:06 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=E6=8E=A8?= =?UTF-8?q?=E7=90=86=E5=BC=95=E6=93=8E=E6=8E=A7=E5=88=B6=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E4=BF=AE=E5=A4=8D=E8=BF=9E=E7=BB=AD=E6=89=B9=E5=A4=84?= =?UTF-8?q?=E7=90=86=E6=A0=B8=E5=BF=83=E7=BC=BA=E9=99=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 修复 decode 阶段新任务覆盖已有任务的严重缺陷 - 修复线程安全问题(热路径无锁竞争) - 修复前缀缓存引用计数管理不当导致缓存被驱逐 - 修复 pad_id 缺失导致全量 prefill 崩溃 - 修复 RoPE 位置错乱(不同位置任务共用 start_pos) - 新增 slot 版本追踪实现前缀缓存零拷贝复用 - 新增异步流式生成接口避免阻塞事件循环 - 添加完整英文文档字符串 --- astrai/inference/engine.py | 302 +++++-- astrai/inference/scheduler.py | 834 +++++++++++------- astrai/inference/server.py | 79 +- tests/inference/conftest.py | 7 + tests/inference/test_scheduler_concurrency.py | 15 +- tests/inference/test_server.py | 5 +- 6 files changed, 757 insertions(+), 485 deletions(-) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index f701c94..e0b0161 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -1,9 +1,10 @@ -"""Unified inference engine.""" +"""Unified inference engine for continuous batching.""" +import asyncio import gc import logging import threading -from typing import Any, Dict, Generator, List, Optional, Union +from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union import torch import torch.nn as nn @@ -15,7 +16,11 @@ logger = logging.getLogger(__name__) class GenerationRequest: - """Request parameters for text generation.""" + """Request parameters for text generation. + + Encapsulates messages, sampling parameters, and streaming preference + for a single generation request. + """ def __init__( self, @@ -26,17 +31,26 @@ class GenerationRequest: max_len: int = 1024, stream: bool = False, ): + """Initializes a generation request. + + Args: + messages: Conversation history as list of {"role": ..., "content": ...}. + top_k: Top-k sampling count (0 disables). + top_p: Nucleus sampling probability threshold. + temperature: Sampling temperature. + max_len: Maximum tokens to generate. + stream: Whether to return output as a token stream. + """ self.messages = messages self.top_k = top_k self.top_p = top_p self.temperature = temperature self.max_len = max_len self.stream = stream - self._validate() def _validate(self): - """Validate request parameters.""" + """Validates sampling parameter ranges.""" if not (isinstance(self.top_k, int) and self.top_k >= 0): raise ValueError("top_k must be a non-negative integer") if not (0.0 <= self.top_p <= 1.0): @@ -46,50 +60,90 @@ class GenerationRequest: class _Result: - """Unified result holder for streaming/non-streaming modes.""" + """Thread-safe token accumulator for streaming and non-streaming modes. - def __init__(self, count: int = 1, stream: bool = False): - self._stream = stream + Supports multiple concurrent generation tasks with per-index result tracking. + Uses a threading.Event for efficient waiting on completion. + """ + + def __init__(self, count: int = 1): + """Initializes the accumulator. + + Args: + count: Number of concurrent generation tasks to track. + """ self._lock = threading.Lock() self._event = threading.Event() self.tokens: List[str] = [] - self.results: List[str] = [""] * count if count > 1 else [""] - self.done_flags: List[bool] = [False] * count - self._completed_count = 0 + self.results: List[str] = [""] * count + self._done: List[bool] = [False] * count + self._completed = 0 + self._total = count def append(self, token: str, idx: int = 0): + """Appends a token to the result buffer. + + In non-streaming mode, tokens are concatenated into results[idx]. + The sentinel "[DONE]" marks a task as complete. + + Args: + token: The decoded token string, or "[DONE]" sentinel. + idx: Index of the generation task this token belongs to. + """ with self._lock: - if self._stream: - self.tokens.append(token) + self.tokens.append(token) + if token != "[DONE]": + self.results[idx] += token else: - if token == "[DONE]": - if not self.done_flags[idx]: - self.done_flags[idx] = True - self._completed_count += 1 - if self._completed_count == len(self.results): - self._event.set() - else: - self.results[idx] += token + if not self._done[idx]: + self._done[idx] = True + self._completed += 1 self._event.set() def pop_all(self) -> List[str]: - with self._lock: - tokens = self.tokens.copy() - self.tokens.clear() - if not tokens: - self._event.clear() - return tokens + """Returns and clears all accumulated tokens. - def wait(self, timeout: float = None) -> bool: + Returns: + List of token strings since the last call. + """ + with self._lock: + out = self.tokens.copy() + self.tokens.clear() + if not out: + self._event.clear() + return out + + def wait(self, timeout: Optional[float] = None) -> bool: + """Blocks until new tokens arrive or the timeout expires. + + Args: + timeout: Maximum wait time in seconds (None = infinite). + + Returns: + True if the event was set (new data available), False on timeout. + """ return self._event.wait(timeout=timeout) def get_results(self) -> List[str]: + """Returns all accumulated results for non-streaming mode. + + Returns: + List of complete generated strings, one per task index. + """ with self._lock: return self.results.copy() class InferenceEngine: - """Unified inference engine for continuous batching.""" + """Unified inference engine backed by continuous-batching scheduler. + + Usage: + with InferenceEngine(model, tokenizer) as engine: + for token in engine.generate("hello", stream=True): + print(token, end="") + + text = engine.generate("hello") + """ def __init__( self, @@ -97,40 +151,36 @@ class InferenceEngine: tokenizer: AutoTokenizer, max_batch_size: int = 1, max_seq_len: Optional[int] = None, - max_prefix_len: int = 512, + max_prompt_len: int = 512, cache_capacity: int = 1000, ): - """ - Initialize inference engine with separate model and tokenizer. + """Initializes the engine and starts the scheduler background thread. Args: - model: The language model for inference (nn.Module, e.g., Transformer) - tokenizer: The tokenizer for encoding/decoding text - config: Model configuration - max_batch_size: Maximum batch size for continuous batching - max_seq_len: Maximum sequence length (defaults to config.max_len) - max_prefix_len: Maximum prefix length for cache (default: 512) - cache_capacity: Maximum number of cached prefixes (default: 1000) + model: The language model (nn.Module, e.g. Transformer). + tokenizer: Tokenizer for encoding/decoding. + max_batch_size: Maximum concurrent tasks in the scheduler. + max_seq_len: Maximum sequence length (defaults to model config). + max_prompt_len: Maximum prompt tokens (longer prompts truncated). + cache_capacity: Maximum prefix cache nodes. """ - self.model = model - self.tokenizer = tokenizer - - # Get device and dtype from model parameters try: first_param = next(model.parameters()) device = first_param.device dtype = first_param.dtype except StopIteration: - # Model has no parameters, use default device/dtype device = torch.device("cuda" if torch.cuda.is_available() else "cpu") dtype = torch.float32 + self.model = model + self.tokenizer = tokenizer + self.scheduler = InferenceScheduler( model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_seq_len=max_seq_len, - max_prefix_len=max_prefix_len, + max_prompt_len=max_prompt_len, cache_capacity=cache_capacity, device=device, dtype=dtype, @@ -138,14 +188,12 @@ class InferenceEngine: self.kv_cache = self.scheduler.kv_cache self.seq_mask = self.scheduler.seq_mask - self.scheduler.start() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - """Handle exceptions on exit.""" self.shutdown() return False @@ -157,39 +205,99 @@ class InferenceEngine: temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, - abort_on_exception: bool = True, ) -> Union[Generator[str, None, None], str, List[str]]: - """Unified generation interface. + """Generates text from a prompt. Args: - abort_on_exception: If True, abort the generation when consumer - stops iterating (GeneratorExit/StopIteration). Default: True. + prompt: Single string or list of strings for batch generation. + stream: If True, returns a generator yielding tokens one by one. + max_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling probability threshold. + top_k: Top-k sampling count (0 disables). + + Returns: + Generator (stream=True), single string (non-stream, single prompt), + or list of strings (non-stream, batch prompts). """ is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] if stream: return self._generate_streaming( - prompts, - is_batch, - max_tokens, - temperature, - top_p, - top_k, - abort_on_exception, + prompts, is_batch, max_tokens, temperature, top_p, top_k ) else: return self._generate_non_streaming( prompts, is_batch, max_tokens, temperature, top_p, top_k ) + def generate_async( + self, + prompt: str, + max_tokens: int = 1024, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = 50, + ) -> AsyncGenerator[str, None]: + """Async streaming generator that does not block the event loop. + + Runs the synchronous generator in a background thread pool executor, + yielding tokens to the async consumer as they arrive. + + Args: + prompt: Input text to generate from. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling threshold. + top_k: Top-k sampling count. + + Yields: + Decoded token strings as they are generated. + """ + sync_gen = self._generate_streaming( + [prompt], False, max_tokens, temperature, top_p, top_k + ) + + async def _agen(): + loop = asyncio.get_event_loop() + while True: + token = await loop.run_in_executor(None, self._next_token, sync_gen) + if token is None: + break + yield token + + return _agen() + + @staticmethod + def _next_token(gen: Generator) -> Optional[str]: + """Retrieves the next token from a synchronous generator. + + Args: + gen: A synchronous generator yielding token strings. + + Returns: + The next token, or None if the generator is exhausted. + """ + try: + return next(gen) + except StopIteration: + return None + def generate_with_request( self, request: GenerationRequest ) -> Union[Generator[str, None, None], str, List[str]]: - """Generate with GenerationRequest object.""" - # Use tokenizer's chat template with messages - prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) + """Generates text from a structured GenerationRequest. + Applies the chat template to the request's messages before generation. + + Args: + request: A GenerationRequest with messages and parameters. + + Returns: + Generator, string, or list of strings (see generate()). + """ + prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) return self.generate( prompt=prompt, stream=request.stream, @@ -207,18 +315,27 @@ class InferenceEngine: temperature: float, top_p: float, top_k: int, - abort_on_exception: bool = True, - ) -> Union[Generator[str, None, None], List[Generator[str, None, None]]]: - """Generate with streaming output. + ) -> Generator[str, None, None]: + """Internal streaming generator. + + Polls the _Result accumulator in a loop, yielding tokens as they arrive. + Cleans up the scheduler task on GeneratorExit. Args: - abort_on_exception: If True, abort the task when generator is - stopped early by consumer (GeneratorExit/StopIteration). + prompts: List of prompts (only first is used; batch not yet supported). + is_batch: If True, raises NotImplementedError. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling threshold. + top_k: Top-k sampling count. + + Yields: + Decoded token strings. """ if is_batch: - raise NotImplementedError("Batch streaming is not implemented yet") + raise NotImplementedError("Batch streaming not yet supported") - result = _Result(stream=True) + result = _Result() task_id = self.scheduler.add_task( prompt=prompts[0], @@ -226,7 +343,7 @@ class InferenceEngine: temperature=temperature, top_p=top_p, top_k=top_k, - stream_callback=result.append, + stream_callback=lambda tok: result.append(tok, 0), ) def gen(): @@ -237,14 +354,12 @@ class InferenceEngine: if token == "[DONE]": return yield token - result.wait(timeout=0.05) - except Exception: - # Consumer stopped iterating - abort the task - if abort_on_exception: - self.scheduler.remove_task(task_id) + if not result.wait(timeout=0.05): + pass + except GeneratorExit: + self.scheduler.remove_task(task_id) raise - gen.task_id = task_id return gen() def _generate_non_streaming( @@ -256,16 +371,27 @@ class InferenceEngine: top_p: float, top_k: int, ) -> Union[str, List[str]]: - """Generate without streaming.""" + """Internal non-streaming generator. + + Submits all prompts to the scheduler and waits for all to complete. + + Args: + prompts: List of prompt strings. + is_batch: Whether multiple prompts were provided. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling threshold. + top_k: Top-k sampling count. + + Returns: + Single string for one prompt, list of strings for batch. + """ result = _Result(count=len(prompts)) for i, p in enumerate(prompts): - # Create closure to capture current index value using factory function - def make_callback(idx): - def callback(token): - result.append(idx, token) - return callback + def make_cb(idx): + return lambda tok: result.append(tok, idx) self.scheduler.add_task( prompt=p, @@ -273,19 +399,23 @@ class InferenceEngine: temperature=temperature, top_p=top_p, top_k=top_k, - stream_callback=make_callback(i), + stream_callback=make_cb(i), ) result.wait() - results = result.get_results() - return results if is_batch else results[0] + res = result.get_results() + return res if is_batch else res[0] def get_stats(self) -> Dict[str, Any]: - """Get engine statistics.""" + """Returns current engine statistics. + + Returns: + Dict with total_tasks, total_tokens, active_tasks, waiting_queue. + """ return self.scheduler.get_stats() def shutdown(self) -> None: - """Shutdown the engine and release all resources.""" + """Shuts down the engine, stops the scheduler, and frees GPU memory.""" self.scheduler.stop() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 0d67651..858f264 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -1,8 +1,9 @@ -"""Inference scheduler for continuous batching.""" +"""Inference scheduler for single-GPU continuous batching.""" import threading import time import uuid +from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple import torch @@ -12,137 +13,186 @@ from astrai.model.automodel import AutoModel from astrai.tokenize import AutoTokenizer -class RadixNode: - """Radix tree node for prefix cache.""" +class _RadixNode: + """Internal node for the radix tree prefix cache. + + Attributes: + children: Mapping from token ID to child node. + slot: KV cache slot index for the prefix ending at this node. + slot_ver: Version counter of the slot at insertion time. + ref_count: Number of tasks currently referencing this node. + last_access: Timestamp of the most recent access (for LRU ordering). + """ + + __slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access") def __init__(self): - self.children: Dict[int, "RadixNode"] = {} # token_id -> child node - self.hash: Optional[int] = None # 64-bit hash of the prefix - self.slot: int = -1 # KV Cache slot, valid only for leaf nodes - self.ref_count: int = 0 # number of tasks referencing this prefix - self.last_access: float = 0.0 # timestamp for LRU - self.token_sequence: list = [] # full token sequence from root to this node + self.children: Dict[int, "_RadixNode"] = {} + self.slot: int = -1 + self.slot_ver: int = 0 + self.ref_count: int = 0 + self.last_access: float = 0.0 class PrefixCacheManager: - """Prefix cache manager using Radix tree with LRU eviction.""" + """Radix-tree prefix cache with LRU eviction. - def __init__(self, max_capacity: int = 1000, base: int = 131, mod: int = 10**9 + 7): - self.root = RadixNode() - self.base = base - self.mod = mod + Maps token ID sequences to KV cache slots. Intermediate tree nodes + also store slot information, allowing direct slot reuse when the + cached slot is free and its version matches (no intervening writes). + """ + + def __init__(self, max_capacity: int = 1000): + """Initializes the prefix cache. + + Args: + max_capacity: Maximum number of nodes in the LRU list. + """ + self.root = _RadixNode() self.max_capacity = max_capacity - self.lru: List[Tuple[float, RadixNode]] = [] # (timestamp, node) for LRU + self._lru: OrderedDict[int, _RadixNode] = OrderedDict() - def insert(self, token_ids: Tuple[int, ...], slot: int) -> None: - """Insert a prefix, increase ref_count if already exists, otherwise create new node.""" - node = self.root - path = [] - h = 0 - for i, token_id in enumerate(token_ids): - if token_id not in node.children: - node.children[token_id] = RadixNode() - node = node.children[token_id] - h = (h * self.base + token_id) % self.mod - node.hash = h - path.append(token_id) - node.token_sequence = list( - path - ) # store full sequence for exact verification + def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None: + """Inserts a token sequence into the prefix cache. - # Leaf node: set slot and increase ref_count - if node.slot == -1: - node.slot = slot - node.ref_count += 1 - node.last_access = time.time() - self._update_lru(node) - self._evict_if_needed() + Every node along the path records the slot and its version, + enabling direct slot reuse for partial prefix matches. - def find_longest_prefix(self, token_ids: List[int]) -> Optional[Tuple[int, int]]: - """Find longest matching prefix, return (prefix_len, slot). - - During traversal, compute hash per token and compare with node hash. - If hash matches, perform full token sequence verification to avoid - hash collision errors. + Args: + token_ids: The token ID sequence to cache. + slot: The KV cache slot containing this prefix's computed keys/values. + slot_ver: The slot version at insertion time, used for staleness detection. """ node = self.root - best_len = 0 - best_slot = -1 - h = 0 + for tid in token_ids: + nxt = node.children.get(tid) + if nxt is None: + nxt = _RadixNode() + node.children[tid] = nxt + node = nxt + node.slot = slot + node.slot_ver = slot_ver + node.last_access = time.time() + self._lru[id(node)] = node + node.ref_count += 1 + self._evict_if_needed() - for i, token_id in enumerate(token_ids): - if token_id not in node.children: + def find(self, token_ids: List[int]) -> Tuple[int, int, int]: + """Finds the longest matching prefix in the cache. + + Walks the radix tree token by token, recording the deepest match. + + Args: + token_ids: The token sequence to match against. + + Returns: + Tuple of (prefix_len, slot, slot_ver): + prefix_len: Number of matching tokens (0 if no match). + slot: KV cache slot of the matched prefix (-1 if no match). + slot_ver: Version of that slot when the prefix was inserted. + """ + node = self.root + best_len, best_slot, best_ver = 0, -1, 0 + for i, tid in enumerate(token_ids): + nxt = node.children.get(tid) + if nxt is None: break - node = node.children[token_id] - h = (h * self.base + token_id) % self.mod - if node.hash == h: # hash matches - # Exact verification: compare full token sequence - if node.token_sequence == token_ids[: i + 1]: - best_len = i + 1 - best_slot = node.slot - node.last_access = time.time() - self._update_lru(node) + node = nxt + best_len, best_slot, best_ver = i + 1, node.slot, node.slot_ver + node.last_access = time.time() + self._lru.move_to_end(id(node)) + return best_len, best_slot, best_ver - if best_len > 0: - return (best_len, best_slot) - return None + def pin(self, token_ids: Tuple[int, ...]) -> None: + """Increments the reference count of a cached prefix. + + Called when a task reuses a cached prefix to prevent eviction. + + Args: + token_ids: The token sequence whose node's ref_count to increment. + """ + node = self.root + for tid in token_ids: + nxt = node.children.get(tid) + if nxt is None: + return + node = nxt + node.ref_count += 1 def release(self, token_ids: Tuple[int, ...]) -> None: - """Release reference to a prefix, decrease ref_count. If zero, mark as evictable.""" + """Decrements the reference count of a cached prefix. + + The node's slot is preserved even when ref_count reaches zero, + allowing future tasks to reuse the slot directly if it remains free. + + Args: + token_ids: The token sequence whose node's ref_count to decrement. + """ node = self.root - for token_id in token_ids: - if token_id not in node.children: + for tid in token_ids: + nxt = node.children.get(tid) + if nxt is None: return - node = node.children[token_id] + node = nxt if node.ref_count > 0: node.ref_count -= 1 - if node.ref_count == 0: - node.slot = -1 # slot can be reused - def _update_lru(self, node: RadixNode) -> None: - """Update LRU list, move node to most recently used position.""" - self.lru = [(ts, n) for (ts, n) in self.lru if n is not node] - self.lru.append((node.last_access, node)) + def copy_kv( + self, + token_ids: Tuple[int, ...], + target_slot: int, + kv_cache: Tuple[Tensor, Tensor], + n_layers: int, + ) -> None: + """Copies cached KV data from the source slot to a target slot. + + Used when the cached slot is occupied and cannot be reused directly. + Copies the key/value tensors for all layers. + + Args: + token_ids: The prefix token sequence identifying the source cache node. + target_slot: The destination KV cache slot to copy into. + kv_cache: Tuple of (k_cache, v_cache) tensors. + n_layers: Number of transformer layers to copy. + """ + node = self.root + for tid in token_ids: + nxt = node.children.get(tid) + if nxt is None: + return + node = nxt + src_slot = node.slot + if src_slot < 0: + return + prefix_len = len(token_ids) + k_cache, v_cache = kv_cache + for li in range(n_layers): + k_cache[target_slot, :prefix_len, li].copy_( + k_cache[src_slot, :prefix_len, li] + ) + v_cache[target_slot, :prefix_len, li].copy_( + v_cache[src_slot, :prefix_len, li] + ) def _evict_if_needed(self) -> None: - """If cache entries exceed capacity, evict least recently used leaf nodes (ref_count must be 0).""" - if len(self.lru) <= self.max_capacity: - return - # Sort by timestamp - self.lru.sort(key=lambda x: x[0]) - for ts, node in self.lru: - if node.ref_count == 0: - # Remove leaf node from tree (need to recursively delete empty branches) - self._remove_node(node) - self.lru.remove((ts, node)) - if len(self.lru) <= self.max_capacity: - break + """Evicts least-recently-used nodes until under capacity. - def _remove_node( - self, - node: RadixNode, - parent: Optional[RadixNode] = None, - child_key: Optional[int] = None, - ) -> None: - """Remove node from tree, including empty parent nodes.""" - # First, recursively remove all children - for child_key, child_node in list(node.children.items()): - self._remove_node(child_node, node, child_key) - - # Clear the node's leaf properties - node.slot = -1 - node.hash = None - node.token_sequence = [] - node.children.clear() - - # If this node has no children and has a parent, remove the reference from parent - if parent is not None and child_key is not None and len(node.children) == 0: - if child_key in parent.children: - del parent.children[child_key] + Skips nodes with ref_count > 0 (still in use by active tasks). + Evicted nodes have their slot and children cleared. + """ + while len(self._lru) > self.max_capacity: + key, node = next(iter(self._lru.items())) + if node.ref_count > 0: + self._lru.move_to_end(key) + continue + self._lru.pop(key) + node.slot = -1 + node.slot_ver = 0 + node.children.clear() class TaskStatus: - """Task state for continuous batching.""" + """Enum-like task states in the continuous batching lifecycle.""" PENDING = "pending" RUNNING = "running" @@ -151,7 +201,29 @@ class TaskStatus: class Task: - """Individual task for continuous batching.""" + """Represents a single generation request within the scheduler. + + Tracks prompt tokens, generated output, sampling parameters, + KV cache slot assignment, and prefix cache matching state. + """ + + __slots__ = ( + "task_id", + "prompt_ids", + "max_tokens", + "temperature", + "top_p", + "top_k", + "status", + "output_ids", + "input_tokens", + "output_tokens", + "slot", + "prefix_len", + "arrival_time", + "finish_time", + "stream_callback", + ) def __init__( self, @@ -163,6 +235,17 @@ class Task: top_k: int = 50, stream_callback: Optional[Callable[[str], None]] = None, ): + """Initializes a new task. + + Args: + task_id: Unique identifier for this task. + prompt_ids: Tokenized prompt sequence. + max_tokens: Maximum number of tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling probability threshold. + top_k: Top-k sampling count (0 disables). + stream_callback: Optional callback invoked per decoded token. + """ self.task_id = task_id self.prompt_ids = prompt_ids self.max_tokens = max_tokens @@ -175,18 +258,30 @@ class Task: self.input_tokens: int = 0 self.output_tokens: int = 0 self.slot: int = -1 - self.prefix_len: int = 0 # prefix cache matched length + self.prefix_len: int = 0 self.arrival_time = time.time() self.finish_time: Optional[float] = None - self.stream_callback = stream_callback + @property + def next_pos(self) -> int: + """Returns the next KV cache position to write during decode.""" + return self.input_tokens + len(self.output_ids) + def is_finished(self, stop_ids: List[int]) -> bool: - """Check if task is finished.""" - return ( - bool(self.output_ids and self.output_ids[-1] in stop_ids) - or self.output_tokens >= self.max_tokens - ) + """Checks whether the task has reached a stopping condition. + + Args: + stop_ids: List of stop token IDs (e.g., EOS). + + Returns: + True if max_tokens reached or the last output token is a stop ID. + """ + if self.output_tokens >= self.max_tokens: + return True + if self.output_ids and self.output_ids[-1] in stop_ids: + return True + return False def apply_sampling_strategies( @@ -196,38 +291,54 @@ def apply_sampling_strategies( top_p: float, filter_value: float = -float("inf"), ) -> Tensor: - """Apply sampling strategies to the logits tensor.""" - # Clone logits to avoid inplace updates on inference tensor - logits = logits.clone() + """Applies temperature scaling, top-k filtering, and top-p (nucleus) filtering. + Operates on a clone of the input logits to avoid in-place modification + of the inference tensor. + + Args: + logits: Raw logits tensor of shape (batch, vocab_size). + temperature: Temperature scaling factor (1.0 = no scaling). + top_k: Keep only top-k logits (0 disables). + top_p: Nucleus probability threshold (1.0 disables). + filter_value: Value to assign to filtered-out positions. + + Returns: + Modified logits tensor with same shape as input. + """ + logits = logits.clone() if temperature != 1.0: logits = logits / temperature - if top_k > 0: top_k = min(top_k, logits.size(-1)) - indices_to_remove = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] - logits[indices_to_remove] = filter_value - + indices = logits < torch.topk(logits, top_k, dim=-1)[0][..., -1, None] + logits[indices] = filter_value if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - sorted_indices_to_remove = cumulative_probs > top_p + cum_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cum_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 - indices_to_remove = torch.zeros_like(logits, dtype=torch.bool) indices_to_remove.scatter_( dim=1, index=sorted_indices, src=sorted_indices_to_remove ) - logits[indices_to_remove] = filter_value - return logits class InferenceScheduler: - """Inference scheduler with continuous batching support.""" + """Continuous batching scheduler for single-GPU inference. + + Runs a background generation loop with four phases per iteration: + 1. Cleanup finished tasks and release resources. + 2. Refill active batch from the waiting queue. + 3. Prefill newly activated tasks (full, partial, or fully cached). + 4. Decode the largest same-position group of active tasks. + + Tasks at different positions are never batched together in decode, + avoiding RoPE corruption from misaligned KV cache writes. + """ def __init__( self, @@ -235,55 +346,60 @@ class InferenceScheduler: tokenizer: AutoTokenizer, max_batch_size: int = 16, max_seq_len: Optional[int] = None, - max_prefix_len: int = 512, + max_prompt_len: int = 512, cache_capacity: int = 1000, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, ): + """Initializes the scheduler and pre-allocates the KV cache. + + Args: + model: The language model (must have config with n_layers, n_kv_heads, etc.). + tokenizer: Tokenizer for encoding prompts and decoding outputs. + max_batch_size: Maximum number of concurrent tasks. + max_seq_len: Maximum sequence length (defaults to config.max_len). + max_prompt_len: Maximum prompt tokens (longer prompts are truncated). + cache_capacity: Maximum prefix cache node count. + device: Target device for tensors. + dtype: Data type for KV cache tensors. + """ config = model.config self.model = model self.tokenizer = tokenizer self.max_batch_size = max_batch_size self.max_seq_len = max_seq_len or config.max_len - self.max_prefix_len = max_prefix_len + self.max_prompt_len = max_prompt_len self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype - # Initialize prefix cache self.prefix_cache = PrefixCacheManager(max_capacity=cache_capacity) - num_kv_heads = config.n_kv_heads + n_kv_heads = config.n_kv_heads head_dim = config.dim // config.n_heads n_layers = config.n_layers + self._n_layers = n_layers k_cache = torch.empty( - ( - max_batch_size, - self.max_seq_len, - n_layers, - num_kv_heads, - head_dim, - ), + (max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), device=self.device, dtype=self.dtype, ) v_cache = torch.empty( - ( - max_batch_size, - self.max_seq_len, - n_layers, - num_kv_heads, - head_dim, - ), + (max_batch_size, self.max_seq_len, n_layers, n_kv_heads, head_dim), device=self.device, dtype=self.dtype, ) self.kv_cache = (k_cache, v_cache) - self.seq_mask = torch.ones( - (max_batch_size, self.max_seq_len), device=self.device, dtype=torch.bool + + self.seq_mask = torch.zeros( + (max_batch_size, self.max_seq_len), + device=self.device, + dtype=torch.bool, ) + self._free_slots = (1 << max_batch_size) - 1 + self._slot_ver: List[int] = [0] * max_batch_size self.waiting_queue: List[Task] = [] self.active_tasks: List[Task] = [] @@ -294,6 +410,48 @@ class InferenceScheduler: self._total_tasks = 0 self._total_tokens = 0 + def _alloc_slot(self) -> int: + """Allocates a free KV cache slot using a bitmask. + + Returns: + Slot index on success, -1 if all slots are occupied. + """ + lsb = self._free_slots & -self._free_slots + if lsb == 0: + return -1 + idx = lsb.bit_length() - 1 + self._free_slots ^= lsb + self._slot_ver[idx] += 1 + return idx + + def _free_slot(self, idx: int) -> None: + """Releases a KV cache slot back to the free pool. + + Args: + idx: Slot index to free. + """ + self._free_slots |= 1 << idx + self.seq_mask[idx, :] = False + + def _try_reuse_slot(self, prefix: Tuple[int, ...]) -> Tuple[int, bool]: + """Attempts to reuse a prefix-cached slot directly without KV copy. + + The slot is reusable only if it is free and its version matches + the current slot version (no intervening allocation overwrote it). + + Args: + prefix: The matched prefix token sequence. + + Returns: + Tuple of (slot, True) on success, or (-1, False) if reuse is not possible. + """ + _plen, cached_slot, cached_ver = self.prefix_cache.find(list(prefix)) + if cached_slot >= 0 and (self._free_slots >> cached_slot) & 1: + if cached_ver == self._slot_ver[cached_slot]: + self._free_slots ^= 1 << cached_slot + return cached_slot, True + return -1, False + def add_task( self, prompt: str, @@ -303,13 +461,27 @@ class InferenceScheduler: top_k: int = 50, stream_callback: Optional[Callable[[str], None]] = None, ) -> str: - """Add a new task to the waiting queue.""" + """Adds a generation task to the waiting queue. + + Encodes the prompt, queries the prefix cache for a match, + and enqueues the task for the background generation loop. + + Args: + prompt: Input text to generate from. + max_tokens: Maximum tokens to generate. + temperature: Sampling temperature. + top_p: Nucleus sampling threshold. + top_k: Top-k sampling count. + stream_callback: Called per decoded token with the string representation. + + Returns: + Unique task ID string. + """ task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" prompt_ids = self.tokenizer.encode(prompt) - # Truncate if exceeds max_prefix_len - if len(prompt_ids) > self.max_prefix_len: - prompt_ids = prompt_ids[: self.max_prefix_len] + if len(prompt_ids) > self.max_prompt_len: + prompt_ids = prompt_ids[: self.max_prompt_len] task = Task( task_id=task_id, @@ -321,15 +493,8 @@ class InferenceScheduler: stream_callback=stream_callback, ) - # Find longest matching prefix from cache - match = self.prefix_cache.find_longest_prefix(prompt_ids) - if match: - prefix_len, slot = match - task.prefix_len = prefix_len - task.slot = slot - else: - task.prefix_len = 0 - task.slot = -1 + prefix_len, _cached_slot, _cached_ver = self.prefix_cache.find(prompt_ids) + task.prefix_len = prefix_len with self._lock: self.waiting_queue.append(task) @@ -339,13 +504,21 @@ class InferenceScheduler: return task_id def remove_task(self, task_id: str) -> None: - """Remove a task from the scheduler.""" + """Removes a task from both the waiting queue and active tasks. + + Args: + task_id: The task to remove. + """ with self._lock: self.waiting_queue = [t for t in self.waiting_queue if t.task_id != task_id] self.active_tasks = [t for t in self.active_tasks if t.task_id != task_id] def _remove_finished_tasks(self) -> None: - """Remove finished tasks from active batch.""" + """Removes all finished tasks from the active batch. + + Releases prefix cache references and frees the KV cache slot + for each completed task. + """ finished = [] for task in self.active_tasks: if task.is_finished(self.tokenizer.stop_ids): @@ -355,14 +528,13 @@ class InferenceScheduler: self._total_tokens += task.output_tokens for task in finished: - slot = task.slot - if slot >= 0 and slot < len(self.active_tasks): - self.seq_mask[slot, :] = False - - # Release prefix cache reference if task.prefix_len > 0: - self.prefix_cache.release(tuple(task.prompt_ids[: task.prefix_len])) - + prefix = tuple(task.prompt_ids[: task.prefix_len]) + self.prefix_cache.release(prefix) + if task.prefix_len < len(task.prompt_ids): + self.prefix_cache.release(tuple(task.prompt_ids)) + if task.slot >= 0: + self._free_slot(task.slot) task.slot = -1 self.active_tasks = [ @@ -370,47 +542,68 @@ class InferenceScheduler: ] def _refill_active_batch(self) -> None: - """Refill active batch with waiting tasks.""" - available_slots = self.max_batch_size - len(self.active_tasks) - if available_slots <= 0: + """Moves waiting tasks into the active batch, up to max_batch_size. + + Attempts direct slot reuse for prefix-matched tasks; falls back + to allocating a fresh slot with KV cache copy when reuse is not possible. + """ + available = self.max_batch_size - len(self.active_tasks) + if available <= 0: return + to_add: List[Task] = [] with self._lock: - to_add = [ - self.waiting_queue.pop(0) - for _ in range(min(available_slots, len(self.waiting_queue))) - ] - for task in to_add: - task.slot = self._allocate_slot() - task.status = TaskStatus.RUNNING - self.active_tasks.append(task) + n = min(available, len(self.waiting_queue)) + for _ in range(n): + to_add.append(self.waiting_queue.pop(0)) - def _allocate_slot(self) -> int: - """Allocate an available slot for a task.""" - for i in range(self.max_batch_size): - if not any(t.slot == i for t in self.active_tasks): - return i - return -1 + for task in to_add: + slot = -1 + if task.prefix_len > 0: + prefix = tuple(task.prompt_ids[: task.prefix_len]) + cached_slot, reused = self._try_reuse_slot(prefix) + if reused: + slot = cached_slot + if slot < 0: + slot = self._alloc_slot() + if slot < 0: + break + task.slot = slot + task.status = TaskStatus.RUNNING + self.active_tasks.append(task) + + if task.prefix_len > 0: + prefix = tuple(task.prompt_ids[: task.prefix_len]) + if not reused: + self.prefix_cache.pin(prefix) + self.prefix_cache.copy_kv( + prefix, slot, self.kv_cache, self._n_layers + ) def _execute_prefill(self, tasks: List[Task]) -> None: - """Execute Prefill phase with incremental prefill support.""" + """Runs the prefill phase for a batch of newly activated tasks. + + Groups tasks by cache status: + - fully cached: no model call, just set seq_mask. + - partial: incremental prefill from the cached prefix. + - full: complete prefill from position 0. + """ if not tasks: return - # Group tasks by prefix cache status fully_cached, partial, full = [], [], [] - for task in tasks: - total_len, prefix_len = len(task.prompt_ids), task.prefix_len - if prefix_len == total_len: - fully_cached.append(task) - elif prefix_len > 0: - partial.append(task) + for t in tasks: + plen = len(t.prompt_ids) + if t.prefix_len == plen: + fully_cached.append(t) + elif t.prefix_len > 0: + partial.append(t) else: - full.append(task) + full.append(t) - # Handle fully cached tasks for t in fully_cached: - t.input_tokens, t.output_tokens = len(t.prompt_ids), 0 + t.input_tokens = len(t.prompt_ids) + t.output_tokens = 0 if t.slot >= 0: self.seq_mask[t.slot, : t.input_tokens] = True @@ -420,30 +613,29 @@ class InferenceScheduler: self._execute_partial_prefill(partial) def _execute_full_prefill(self, tasks: List[Task]) -> None: - """Execute full prefill for tasks without prefix cache.""" - if not tasks: - return + """Executes full prefill for tasks without any cache match. + Pads all prompts to the same length and runs a single batched + forward pass. Inserts the full prompt into the prefix cache. + + Args: + tasks: List of tasks with prefix_len == 0. + """ tasks = sorted(tasks, key=lambda t: t.slot) - - prompt_lens = [len(task.prompt_ids) for task in tasks] + prompt_lens = [len(t.prompt_ids) for t in tasks] max_len = max(prompt_lens) + batch_sz = len(tasks) - input_ids = torch.zeros( - len(tasks), max_len, dtype=torch.long, device=self.device + input_ids = torch.zeros(batch_sz, max_len, dtype=torch.long, device=self.device) + input_mask = torch.zeros( + batch_sz, max_len, dtype=torch.bool, device=self.device ) - for i, task in enumerate(tasks): - if len(task.prompt_ids) > 0: - input_ids[i, : len(task.prompt_ids)] = torch.tensor( - task.prompt_ids, device=self.device + for i, t in enumerate(tasks): + if prompt_lens[i] > 0: + input_ids[i, : prompt_lens[i]] = torch.tensor( + t.prompt_ids, device=self.device ) - - if self.tokenizer.pad_id is not None: - input_mask = torch.ne(input_ids, self.tokenizer.pad_id) - else: - input_mask = torch.ones( - input_ids.shape, dtype=torch.bool, device=self.device - ) + input_mask[i, : prompt_lens[i]] = True with torch.inference_mode(): self.model( @@ -453,41 +645,43 @@ class InferenceScheduler: persistent_key_values=self.kv_cache, ) - for i, task in enumerate(tasks): - task.input_tokens = prompt_lens[i] - task.output_tokens = 0 - # Insert new prefix into cache - self.prefix_cache.insert(tuple(task.prompt_ids), task.slot) + for i, t in enumerate(tasks): + t.input_tokens = prompt_lens[i] + t.output_tokens = 0 + self.prefix_cache.insert( + tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] + ) - for task in tasks: - if task.slot >= 0: - self.seq_mask[task.slot, : task.input_tokens] = True + for t in tasks: + if t.slot >= 0: + self.seq_mask[t.slot, : t.input_tokens] = True def _execute_partial_prefill(self, tasks: List[Task]) -> None: - """Execute incremental prefill for tasks with partial prefix cache match.""" - for task in tasks: - total_len = len(task.prompt_ids) - prefix_len = task.prefix_len + """Executes incremental prefill for tasks with a partial cache match. + + Only the tokens beyond the matched prefix are forwarded through + the model. The full prompt is inserted into the cache afterward. + + Args: + tasks: List of tasks with 0 < prefix_len < len(prompt_ids). + """ + for t in tasks: + total_len = len(t.prompt_ids) + prefix_len = t.prefix_len if prefix_len >= total_len: - task.input_tokens = total_len - task.output_tokens = 0 + t.input_tokens = total_len + t.output_tokens = 0 continue - # Get new tokens that need prefill - new_ids = task.prompt_ids[prefix_len:] + new_ids = t.prompt_ids[prefix_len:] new_len = len(new_ids) - if new_len == 0: - task.input_tokens = total_len - task.output_tokens = 0 + t.input_tokens = total_len + t.output_tokens = 0 continue - # Build input for incremental prefill input_ids = torch.tensor([new_ids], dtype=torch.long, device=self.device) - - # Input mask should cover from position 0 to prefix_len + new_len - # The prefix part uses cached KV, new part needs computation input_mask = torch.ones( (1, prefix_len + new_len), dtype=torch.bool, device=self.device ) @@ -500,135 +694,135 @@ class InferenceScheduler: persistent_key_values=self.kv_cache, ) - task.input_tokens = total_len - task.output_tokens = 0 + t.input_tokens = total_len + t.output_tokens = 0 + self.prefix_cache.insert( + tuple(t.prompt_ids), t.slot, self._slot_ver[t.slot] + ) - # Insert full prefix into cache (ref_count already increased in add_task) - self.prefix_cache.insert(tuple(task.prompt_ids), task.slot) - - if task.slot >= 0: - self.seq_mask[task.slot, : task.input_tokens] = True + if t.slot >= 0: + self.seq_mask[t.slot, : t.input_tokens] = True def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: - """Execute Decode phase.""" + """Executes the decode phase for a group of tasks at the same position. + + The input is the last generated token (or last prompt token for + newly prefilled tasks). After the forward pass, sampling strategies + are applied to produce the next token. + + Args: + tasks: Tasks sharing the same next_pos value. + start_pos: Common KV cache write position for the batch. + """ if not tasks: return tasks = sorted(tasks, key=lambda t: t.slot) + batch_sz = len(tasks) - input_ids = torch.zeros(len(tasks), dtype=torch.long, device=self.device) - for i, task in enumerate(tasks): - if task.output_ids: - input_ids[i] = task.output_ids[-1] + input_ids = torch.zeros(batch_sz, dtype=torch.long, device=self.device) + for i, t in enumerate(tasks): + if t.output_ids: + input_ids[i] = t.output_ids[-1] else: - input_ids[i] = task.prompt_ids[-1] + input_ids[i] = t.prompt_ids[-1] - input_tensor = input_ids.unsqueeze(1) - active_mask = torch.ones((len(tasks), 1), dtype=torch.bool, device=self.device) + for t in tasks: + if t.slot >= 0 and start_pos < self.max_seq_len: + self.seq_mask[t.slot, start_pos] = True with torch.inference_mode(): outputs = self.model( - input_tensor, - input_mask=active_mask, + input_ids.unsqueeze(1), + input_mask=self.seq_mask[:batch_sz], persistent_key_values=self.kv_cache, start_pos=start_pos, ) logits = outputs["logits"][:, -1, :] - next_token_ids = [] - for i, task in enumerate(tasks): - logit = logits[i : i + 1] + next_tokens = [] + for i, t in enumerate(tasks): logit = apply_sampling_strategies( - logit, - task.temperature, - task.top_k, - task.top_p, + logits[i : i + 1], + t.temperature, + t.top_k, + t.top_p, ) - probs = torch.softmax(logit, dim=-1) - next_token = torch.multinomial(probs, num_samples=1) - next_token_ids.append(next_token.item()) + prob = torch.softmax(logit, dim=-1) + ntok = torch.multinomial(prob, num_samples=1).item() + next_tokens.append(ntok) - for task, next_token in zip(tasks, next_token_ids): - task.output_ids.append(next_token) - task.output_tokens += 1 + for t, ntok in zip(tasks, next_tokens): + t.output_ids.append(ntok) + t.output_tokens += 1 - pos = task.input_tokens + task.output_tokens - if task.slot >= 0 and pos < self.max_seq_len: - self.seq_mask[task.slot, pos] = True + if t.stream_callback: + token_str = self.tokenizer.decode([ntok]) + t.stream_callback(token_str) - if task.stream_callback: - token_str = self.tokenizer.decode([next_token]) - task.stream_callback(token_str) - - for task in tasks: - if task.output_tokens >= task.max_tokens or ( - task.output_ids and task.output_ids[-1] in self.tokenizer.stop_ids - ): - if task.stream_callback: - task.stream_callback("[DONE]") + for t in tasks: + if t.is_finished(self.tokenizer.stop_ids): + if t.stream_callback: + t.stream_callback("[DONE]") def _run_generation_loop(self) -> None: - """Main generation loop.""" + """Main generation loop run in a daemon thread. + + Continuously cycles through cleanup, refill, prefill, and decode. + Decode processes only the largest position group to ensure all + batched tasks share the same KV cache write position. + """ while self._running: self._remove_finished_tasks() self._refill_active_batch() - if not self.active_tasks: - self._task_event.wait(timeout=0.01) - self._task_event.clear() - continue + with self._lock: + if not self.active_tasks and not self.waiting_queue: + self._task_event.clear() + self._task_event.wait(timeout=0.01) + continue + tasks = self.active_tasks[:] - new_tasks = [t for t in self.active_tasks if t.output_tokens == 0] - decode_tasks = [t for t in self.active_tasks if t.output_tokens > 0] + to_prefill = [t for t in tasks if t.output_tokens == 0] + if to_prefill: + self._execute_prefill(to_prefill) - if decode_tasks: - start_pos = max(t.input_tokens + t.output_tokens for t in decode_tasks) - else: - start_pos = 0 + pos_groups: Dict[int, List[Task]] = {} + for t in self.active_tasks: + pos_groups.setdefault(t.next_pos, []).append(t) - if new_tasks: - self._execute_prefill(new_tasks) - decode_tasks = new_tasks - start_pos = max(t.input_tokens for t in decode_tasks) + if pos_groups: + best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) + self._execute_decode(pos_groups[best_pos], best_pos) - if decode_tasks: - self._execute_decode(decode_tasks, start_pos) - - if not self.active_tasks and not self.waiting_queue: - self._task_event.wait(timeout=0.05) + if not self.waiting_queue and len(self.active_tasks) <= 1: + self._task_event.wait(timeout=0.005) self._task_event.clear() def start(self) -> None: - """Start the generation loop.""" + """Starts the background generation loop thread.""" if not self._running: self._running = True - self._loop_thread = threading.Thread(target=self._run_generation_loop) - self._loop_thread.daemon = True - self._loop_thread.start() + t = threading.Thread(target=self._run_generation_loop, daemon=True) + t.start() def stop(self) -> None: - """Stop the generation loop.""" + """Stops the generation loop and releases all resources.""" self._running = False + self._task_event.set() if hasattr(self, "_loop_thread"): - self._loop_thread.join(timeout=1.0) - - # Clear KV cache to free GPU memory - if self.kv_cache is not None: - k_cache, v_cache = self.kv_cache - if k_cache is not None: - k_cache.detach() - if v_cache is not None: - v_cache.detach() - - # Clear seq mask - self.seq_mask.detach() - - # Clear task lists + self._loop_thread.join(timeout=2.0) self.waiting_queue.clear() self.active_tasks.clear() + if torch.cuda.is_available(): + torch.cuda.empty_cache() def get_stats(self) -> Dict[str, Any]: - """Get scheduler statistics.""" + """Returns current scheduler statistics. + + Returns: + Dict with total_tasks, total_tokens, active_tasks, waiting_queue. + """ return { "total_tasks": self._total_tasks, "total_tokens": self._total_tokens, diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 8a68ac0..23e3334 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -23,12 +23,10 @@ from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) -# Global model parameter and engine (loaded once) _engine: Optional[InferenceEngine] = None _model_param: Optional[Any] = None _project_root = Path(__file__).parent.parent.parent -# Server configuration (set before running server) _server_config: Dict[str, Any] = { "device": "cuda", "dtype": torch.bfloat16, @@ -43,14 +41,6 @@ def configure_server( param_path: Optional[Path] = None, max_batch_size: int = 16, ): - """Configure server settings before starting. - - Args: - device: Device to load model on (e.g., "cuda", "cpu", "cuda:0") - dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16) - param_path: Path to model parameters directory - max_batch_size: Maximum batch size for continuous batching - """ _server_config["device"] = device _server_config["dtype"] = dtype _server_config["param_path"] = param_path @@ -59,9 +49,7 @@ def configure_server( @asynccontextmanager async def lifespan(app: FastAPI): - """Lifespan context manager for startup and shutdown events.""" global _model_param, _engine - # Startup: Load model with configured settings try: load_model( param_path=_server_config["param_path"], @@ -73,7 +61,6 @@ async def lifespan(app: FastAPI): logger.error(f"Failed to load model: {e}") raise yield - # Shutdown: Cleanup engine if _engine: _engine.shutdown() logger.info("Inference engine shutdown complete") @@ -88,20 +75,17 @@ def load_model( dtype: torch.dtype = torch.bfloat16, max_batch_size: int = 16, ): - """Load model parameters and initialize inference engine.""" global _model_param, _engine if param_path is None: param_path = _project_root / "params" if not param_path.exists(): raise FileNotFoundError(f"Parameter directory not found: {param_path}") - # Load tokenizer separately tokenizer = AutoTokenizer.from_pretrained(param_path) _model_param = AutoModel.from_pretrained(param_path) _model_param.to(device=device, dtype=dtype) logger.info(f"Model loaded on {device} with dtype {dtype}") - # Initialize inference engine with separate model and tokenizer _engine = InferenceEngine( model=_model_param, tokenizer=tokenizer, @@ -110,9 +94,8 @@ def load_model( logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") -# Pydantic models for API request/response class ChatMessage(BaseModel): - role: str # "user", "assistant", "system" + role: str content: str @@ -145,7 +128,6 @@ async def health(): @app.get("/stats") async def get_stats(): - """Get inference engine statistics.""" if _engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") return _engine.get_stats() @@ -153,46 +135,36 @@ async def get_stats(): @app.post("/v1/chat/completions", response_model=CompletionResponse) async def chat_completion(request: ChatCompletionRequest): - """OpenAI-compatible chat completion endpoint. - - Supports both streaming and non-streaming modes with continuous batching. - """ if _engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") - # Convert messages to prompt using engine's tokenizer - # Extract system prompt if present, then apply chat template - # Apply chat template directly with messages prompt = _engine.tokenizer.apply_chat_template( [{"role": m.role, "content": m.content} for m in request.messages], tokenize=False, ) if request.stream: - # Streaming response (use synchronous generator) - generator = _engine.generate( + agen = _engine.generate_async( prompt=prompt, - stream=True, max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, top_k=request.top_k, ) - def generate_stream(): - for token in generator: + async def event_stream(): + async for token in agen: if token == "[DONE]": break yield f"data: {json.dumps({'choices': [{'delta': {'content': token}}]})}\n\n" yield "data: [DONE]\n\n" return StreamingResponse( - generate_stream(), + event_stream(), media_type="text/event-stream", headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, ) else: - # Non-streaming response result = _engine.generate( prompt=prompt, stream=False, @@ -202,7 +174,6 @@ async def chat_completion(request: ChatCompletionRequest): top_k=request.top_k, ) - # Build OpenAI-style response import time resp = CompletionResponse( @@ -229,52 +200,35 @@ async def generate( max_len: int = 2048, stream: bool = False, ): - """Simple generation endpoint. - - Args: - query: Input query string - history: Conversation history as list of [user, assistant] pairs - temperature: Sampling temperature - top_p: Top-p sampling parameter - top_k: Top-k sampling parameter - max_len: Maximum tokens to generate - stream: Enable streaming output - - Returns: - dict: Generation result with response field - """ if _engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") - # Build messages for chat template messages = [] if history: - # Convert history format: List[List[str]] -> List[Dict] for h in history: if len(h) >= 2: messages.append({"role": "user", "content": h[0]}) messages.append({"role": "assistant", "content": h[1]}) messages.append({"role": "user", "content": query}) - # Use tokenizer's chat template prompt = _engine.tokenizer.apply_chat_template(messages, tokenize=False) if stream: - # Synchronous streaming - result = _engine.generate( + agen = _engine.generate_async( prompt=prompt, - stream=True, max_tokens=max_len, temperature=temperature, top_p=top_p, top_k=top_k, ) - def stream_generator(): - for token in result: + async def text_stream(): + async for token in agen: + if token == "[DONE]": + break yield token + "\n" - return StreamingResponse(stream_generator(), media_type="text/plain") + return StreamingResponse(text_stream(), media_type="text/plain") else: result = _engine.generate( prompt=prompt, @@ -296,17 +250,6 @@ def run_server( param_path: Optional[Path] = None, max_batch_size: int = 16, ): - """Run the FastAPI server with uvicorn. - - Args: - host: Server host address - port: Server port number - reload: Enable auto-reload for development - device: Device to load model on (e.g., "cuda", "cpu", "cuda:0") - dtype: Data type for model weights (e.g., torch.bfloat16, torch.float16) - param_path: Path to model parameters directory - max_batch_size: Maximum batch size for continuous batching - """ configure_server( device=device, dtype=dtype, diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index dd35692..d17d21b 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -32,8 +32,15 @@ def mock_model_param(): @pytest.fixture def mock_engine(): """Create a mock InferenceEngine.""" + + async def _async_gen(): + yield "chunk1" + yield "chunk2" + yield "[DONE]" + mock = MagicMock() mock.generate.return_value = "mock response" + mock.generate_async.return_value = _async_gen() mock.get_stats.return_value = { "total_tasks": 0, "total_tokens": 0, diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index 1c63995..28e4967 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -21,7 +21,7 @@ def test_prefix_cache_concurrent_insert_find(): def insert_worker(): try: for i in range(50): - cache.insert((i,), slot=i % 10) + cache.insert((i,), slot=i % 10, slot_ver=0) results["inserts"] += 1 except Exception as e: results["errors"].append(str(e)) @@ -29,7 +29,7 @@ def test_prefix_cache_concurrent_insert_find(): def find_worker(): try: for i in range(50): - cache.find_longest_prefix([i]) + cache.find([i]) results["finds"] += 1 except Exception as e: results["errors"].append(str(e)) @@ -53,7 +53,7 @@ def test_prefix_cache_concurrent_release(): # Insert some prefixes for i in range(10): - cache.insert((i,), slot=i) + cache.insert((i,), slot=i, slot_ver=0) results = {"errors": []} @@ -84,10 +84,10 @@ def test_prefix_cache_concurrent_insert_release_find(): try: for i in range(20): token_ids = (worker_id * 100 + i,) - cache.insert(token_ids, slot=worker_id) + cache.insert(token_ids, slot=worker_id, slot_ver=0) # Find after insert - cache.find_longest_prefix(list(token_ids)) + cache.find(list(token_ids)) # Release cache.release(token_ids) @@ -277,7 +277,7 @@ def test_prefix_cache_insert_same_prefix_concurrently(): def insert_worker(): try: # All workers try to insert the same prefix - cache.insert((1, 2, 3), slot=threading.current_thread().name) + cache.insert((1, 2, 3), slot=0, slot_ver=0) node = cache.root.children.get(1) if node: node = node.children.get(2) @@ -306,8 +306,7 @@ def test_prefix_cache_ref_count_underflow_prevention(): """Test that ref_count doesn't go negative.""" cache = PrefixCacheManager(max_capacity=100) - # Insert a prefix - cache.insert((1, 2, 3), slot=0) + cache.insert((1, 2, 3), slot=0, slot_ver=0) # Release multiple times for _ in range(5): diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index c1a0a7e..5bdcace 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -100,13 +100,12 @@ def test_chat_completions_non_stream(client, loaded_model, mock_engine, monkeypa def test_chat_completions_stream(client, loaded_model, mock_engine, monkeypatch): """POST /v1/chat/completions with stream=true returns SSE stream.""" - # Simulate a streaming generator that yields cumulative responses - def stream_gen(): + async def async_gen(): yield "cumulative1" yield "cumulative2" yield "[DONE]" - mock_engine.generate.return_value = stream_gen() + mock_engine.generate_async.return_value = async_gen() monkeypatch.setattr("astrai.inference.server._engine", mock_engine) response = client.post( "/v1/chat/completions",