From df0845e91670f5582e5c272dfbe9a75e0d92b916 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 12 May 2026 13:44:55 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=A7=A3=E8=80=A6=20Executor/Schedule?= =?UTF-8?q?r/TaskManager=EF=BC=8C=E4=BF=AE=E5=A4=8D=20stop=20=E9=A1=B5?= =?UTF-8?q?=E6=B3=84=E6=BC=8F=EF=BC=8C=E7=A7=BB=E9=99=A4=20ServerState=20?= =?UTF-8?q?=E5=85=A8=E5=B1=80=E5=8D=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/inference/cache.py | 25 ++- astrai/inference/engine.py | 358 +++++++++------------------------ astrai/inference/executor.py | 42 +--- astrai/inference/scheduler.py | 64 ++++-- astrai/inference/server.py | 182 +++++++---------- astrai/inference/task.py | 42 ++-- tests/inference/conftest.py | 12 +- tests/inference/test_cache.py | 51 +++-- tests/inference/test_engine.py | 28 +-- tests/inference/test_server.py | 41 ++-- 10 files changed, 336 insertions(+), 509 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 968516f..5522c9b 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -104,8 +104,7 @@ class PrefixCache: class TaskTable: """Maps task_ids to page tables and cached token counts.""" - def __init__(self, pool: PagePool, page_size: int): - self._pool = pool + def __init__(self, page_size: int): self._page_size = page_size self._pages: Dict[str, List[int]] = {} self._cached: Dict[str, int] = {} @@ -125,15 +124,8 @@ class TaskTable: cached = self._cached.pop(task_id, 0) return pages, cached - def extend(self, task_id: str, pos: int) -> bool: - page_table = self._pages[task_id] - needed = (pos + 1 + self._page_size - 1) // self._page_size - while len(page_table) < needed: - p = self._pool.alloc() - if p < 0: - return False - page_table.append(p) - return True + def get_ref(self, task_id: str) -> List[int]: + return self._pages.setdefault(task_id, []) def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: states = [self._pages.get(tid, []) for tid in task_ids] @@ -158,7 +150,7 @@ class PagedCache: self.page_size = page_size self._prefix = PrefixCache(page_size) self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict) - self._table = TaskTable(self._pool, page_size) + self._table = TaskTable(page_size) self.k_cache = torch.empty( (n_layers, n_pages, page_size, n_kv_heads, head_dim), @@ -219,7 +211,14 @@ class PagedCache: self.free(idx) def task_extend(self, task_id: str, pos: int) -> bool: - return self._table.extend(task_id, pos) + page_table = self._table.get(task_id) + needed = (pos + 1 + self.page_size - 1) // self.page_size + while len(page_table) < needed: + p = self._pool.alloc() + if p < 0: + return False + page_table.append(p) + return True def task_cached(self, task_id: str) -> int: return self._table.get_cached(task_id) diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 040ca03..00a73d2 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -1,11 +1,4 @@ -"""Unified inference engine for continuous batching. - -Layers: - - GenerationParams: Immutable value object for sampling parameters. - - GenerationRequest: User-facing request DTO with validation. - - _Result: Thread-safe token accumulator (Observer pattern). - - InferenceEngine: Facade over InferenceScheduler + async wrapper. -""" +"""Unified inference engine for continuous batching.""" import asyncio import gc @@ -21,6 +14,59 @@ from astrai.inference.task import STOP from astrai.tokenize import AutoTokenizer +class GenerateResult: + """Thread-safe token accumulator for streaming and non-streaming modes.""" + + def __init__(self, count: int = 1): + self._cond = threading.Condition() + self._event = threading.Event() + self.tokens: List[Tuple[int, str]] = [] + self.results: List[str] = [""] * count + self._done: List[bool] = [False] * count + self._completed = 0 + self._total = count + + def append(self, token: str, idx: int = 0): + with self._cond: + self.tokens.append((idx, token)) + if token is not STOP: + self.results[idx] += token + else: + if not self._done[idx]: + self._done[idx] = True + self._completed += 1 + self._cond.notify_all() + self._event.set() + + def pop_all(self) -> List[Tuple[int, str]]: + with self._cond: + out = self.tokens.copy() + self.tokens.clear() + if not out: + self._event.clear() + return out + + def wait(self, timeout: Optional[float] = None) -> bool: + return self._event.wait(timeout=timeout) + + def wait_completion(self) -> None: + with self._cond: + self._cond.wait_for(lambda: self._completed >= self._total) + + def get_results(self) -> List[str]: + with self._cond: + return self.results.copy() + + +def _validate_params(top_k: int, top_p: float, temperature: float) -> None: + if not (isinstance(top_k, int) and top_k >= 0): + raise ValueError("top_k must be a non-negative integer") + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be a float between 0.0 and 1.0") + if not (isinstance(temperature, (int, float)) and temperature >= 0): + raise ValueError("temperature must be a non-negative number") + + @dataclass(frozen=True) class GenerationParams: """Immutable value object for sampling hyperparameters.""" @@ -32,11 +78,7 @@ class GenerationParams: class GenerationRequest: - """Request parameters for text generation. - - Encapsulates messages, sampling parameters (via GenerationParams), - and streaming preference for a single generation request. - """ + """Request parameters for text generation.""" def __init__( self, @@ -47,16 +89,6 @@ class GenerationRequest: max_len: int = 1024, stream: bool = False, ): - """Initializes a generation request. - - Args: - messages: Conversation history as list of {"role": ..., "content": ...}. - top_k: Top-k sampling count (0 disables). - top_p: Nucleus sampling probability threshold. - temperature: Sampling temperature. - max_len: Maximum tokens to generate. - stream: Whether to return output as a token stream. - """ self.messages = messages self.params = GenerationParams( top_k=top_k, @@ -65,7 +97,7 @@ class GenerationRequest: max_tokens=max_len, ) self.stream = stream - self._validate() + _validate_params(top_k, top_p, temperature) @property def top_k(self) -> int: @@ -83,112 +115,9 @@ class GenerationRequest: def max_len(self) -> int: return self.params.max_tokens - def _validate(self): - """Validates sampling parameter ranges.""" - if not (isinstance(self.top_k, int) and self.top_k >= 0): - raise ValueError("top_k must be a non-negative integer") - if not (0.0 <= self.top_p <= 1.0): - raise ValueError("top_p must be a float between 0.0 and 1.0") - if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): - raise ValueError("temperature must be a non-negative number") - - -class _Result: - """Thread-safe token accumulator for streaming and non-streaming modes. - - Supports multiple concurrent generation tasks with per-index result tracking. - Uses a threading.Condition for efficient completion notification - and a threading.Event for streaming wakeup. - """ - - def __init__(self, count: int = 1): - """Initializes the accumulator. - - Args: - count: Number of concurrent generation tasks to track. - """ - self._cond = threading.Condition() - self._event = threading.Event() - self.tokens: List[str] = [] - self.results: List[str] = [""] * count - self._done: List[bool] = [False] * count - self._completed = 0 - self._total = count - - def append(self, token: str, idx: int = 0): - """Appends a token to the result buffer. - - In non-streaming mode, tokens are concatenated into results[idx]. - The sentinel STOP marks a task as complete. - - Args: - token: The decoded token string, or STOP sentinel. - idx: Index of the generation task this token belongs to. - """ - with self._cond: - self.tokens.append((idx, token)) - if token is not STOP: - self.results[idx] += token - else: - if not self._done[idx]: - self._done[idx] = True - self._completed += 1 - self._cond.notify_all() - self._event.set() - - def pop_all(self) -> List[Tuple[int, str]]: - """Returns and clears all accumulated (idx, token) pairs. - - Returns: - List of (index, token_string) tuples since the last call. - """ - with self._cond: - out = self.tokens.copy() - self.tokens.clear() - if not out: - self._event.clear() - return out - - def wait(self, timeout: Optional[float] = None) -> bool: - """Blocks until new tokens arrive or the timeout expires. - - Args: - timeout: Maximum wait time in seconds (None = infinite). - - Returns: - True if the event was set (new data available), False on timeout. - """ - return self._event.wait(timeout=timeout) - - def wait_completion(self) -> None: - """Blocks until all tasks complete (non-streaming). - - Uses a Condition to sleep efficiently instead of busy-waiting. - The calling thread is parked until a STOP signal arrives. - """ - with self._cond: - self._cond.wait_for(lambda: self._completed >= self._total) - - def get_results(self) -> List[str]: - """Returns all accumulated results for non-streaming mode. - - Returns: - List of complete generated strings, one per task index. - """ - with self._cond: - return self.results.copy() - class InferenceEngine: - """Unified inference engine backed by continuous-batching scheduler. - - Usage: - with InferenceEngine(model, tokenizer) as engine: - for token in engine.generate("hello", stream=True): - print(token, end="") - - text = engine.generate("hello") - """ + """Unified inference engine backed by continuous-batching scheduler.""" def __init__( self, @@ -199,17 +128,6 @@ class InferenceEngine: max_prompt_len: int = 2048, page_size: int = 128, ): - """Initializes the inference engine. - - Args: - model: The model instance. - tokenizer: The tokenizer instance. - max_batch_size: Maximum number of concurrent tasks. - max_seq_len: Maximum sequence length. - max_prompt_len: Maximum prompt tokens. - compile: Whether to compile the model with torch.compile. - page_size: Number of tokens per KV cache page. - """ self.model = model self.tokenizer = tokenizer self.scheduler = InferenceScheduler( @@ -239,22 +157,8 @@ class InferenceEngine: top_p: float = 1.0, top_k: int = 50, ) -> Union[Generator, str, List[str]]: - """Generates text from a prompt. + _validate_params(top_k, top_p, temperature) - Args: - prompt: Single string or list of strings for batch generation. - stream: If True, returns a generator yielding tokens. - max_tokens: Maximum number of tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling probability threshold. - top_k: Top-k sampling count (0 disables). - - Returns: - stream=False, single prompt: str - stream=False, batch: List[str] - stream=True, single prompt: Generator[str, None, None] - stream=True, batch: Generator[Tuple[int, str], None, None] - """ is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] @@ -275,21 +179,6 @@ class InferenceEngine: top_p: float = 1.0, top_k: int = 50, ) -> AsyncGenerator[str, None]: - """Async streaming generator that does not block the event loop. - - Runs the synchronous generator in a background thread pool executor, - yielding tokens to the async consumer as they arrive. - - Args: - prompt: Input text to generate from. - max_tokens: Maximum tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling threshold. - top_k: Top-k sampling count. - - Yields: - Decoded token strings as they are generated. - """ sync_gen = self._generate_streaming( [prompt], False, max_tokens, temperature, top_p, top_k ) @@ -306,14 +195,6 @@ class InferenceEngine: @staticmethod def _next_token(gen: Generator) -> Optional[str]: - """Retrieves the next token from a synchronous generator. - - Args: - gen: A synchronous generator yielding token strings. - - Returns: - The next token, or None if the generator is exhausted. - """ try: return next(gen) except StopIteration: @@ -322,16 +203,6 @@ class InferenceEngine: def generate_with_request( self, request: GenerationRequest ) -> Union[Generator[str, None, None], str, List[str]]: - """Generates text from a structured GenerationRequest. - - Applies the chat template to the request's messages before generation. - - Args: - request: A GenerationRequest with messages and parameters. - - Returns: - Generator, string, or list of strings (see generate()). - """ prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) return self.generate( prompt=prompt, @@ -342,6 +213,37 @@ class InferenceEngine: top_k=request.params.top_k, ) + def _submit_tasks( + self, + prompts: List[str], + max_tokens: int, + temperature: float, + top_p: float, + top_k: int, + ) -> Tuple[GenerateResult, List[str]]: + n = len(prompts) + result = GenerateResult(count=n) + task_ids = [] + for i, p in enumerate(prompts): + cb = self._make_callback(result, i) + task_id = self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=cb, + ) + task_ids.append(task_id) + return result, task_ids + + @staticmethod + def _make_callback(result: GenerateResult, idx: int): + def cb(token): + result.append(token, idx) + + return cb + def _generate_streaming( self, prompts: List[str], @@ -351,38 +253,10 @@ class InferenceEngine: top_p: float, top_k: int, ) -> Generator: - """Internal streaming generator. - - Polls the _Result accumulator in a loop, yielding tokens as they arrive. - Single prompt yields raw token strings; batch yields (idx, token) tuples. - - Args: - prompts: List of prompts. - is_batch: If True, yields (idx, token) tuples; else yields raw tokens. - max_tokens: Maximum tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling threshold. - top_k: Top-k sampling count. - - Yields: - Single prompt: decoded token strings. - Batch: (sequence_index, token_string) tuples. - """ + result, task_ids = self._submit_tasks( + prompts, max_tokens, temperature, top_p, top_k + ) n = len(prompts) - result = _Result(count=n) - task_ids = [] - - for i, p in enumerate(prompts): - task_id = self.scheduler.add_task( - prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=lambda tok, idx=i: result.append(tok, idx), - ) - task_ids.append(task_id) - remaining = n finished = [False] * n @@ -399,8 +273,7 @@ class InferenceEngine: else: yield (idx, token) if is_batch else token if remaining > 0: - if not result.wait(timeout=0.05): - pass + result.wait(timeout=0.05) finally: for tid in task_ids: self.scheduler.remove_task(tid) @@ -416,57 +289,22 @@ class InferenceEngine: top_p: float, top_k: int, ) -> Union[str, List[str]]: - """Internal non-streaming generator. - - Submits all prompts to the scheduler and waits for all to complete. - - Args: - prompts: List of prompt strings. - is_batch: Whether multiple prompts were provided. - max_tokens: Maximum tokens to generate. - temperature: Sampling temperature. - top_p: Nucleus sampling threshold. - top_k: Top-k sampling count. - - Returns: - Single string for one prompt, list of strings for batch. - """ - result = _Result(count=len(prompts)) - task_ids = [] - - for i, p in enumerate(prompts): - - def make_cb(idx): - return lambda tok: result.append(tok, idx) - - task_id = self.scheduler.add_task( - prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=make_cb(i), - ) - task_ids.append(task_id) + result, task_ids = self._submit_tasks( + prompts, max_tokens, temperature, top_p, top_k + ) result.wait_completion() - for task_id in task_ids: - self.scheduler.remove_task(task_id) + for tid in task_ids: + self.scheduler.remove_task(tid) res = result.get_results() return res if is_batch else res[0] def get_stats(self) -> Dict[str, Any]: - """Returns current engine statistics. - - Returns: - Dict with total_tasks, total_tokens, active_tasks, waiting_queue. - """ return self.scheduler.get_stats() def shutdown(self) -> None: - """Shuts down the engine, stops the scheduler, and frees GPU memory.""" self.scheduler.stop() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py index 62657f2..f6bb110 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/executor.py @@ -5,7 +5,7 @@ import torch from astrai.inference.cache import PagedCache from astrai.inference.sample import sample -from astrai.inference.task import STOP, Task, TaskStatus +from astrai.inference.task import Task from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer @@ -60,31 +60,10 @@ class Executor: paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) - start_logical_page = start_pos // self.page_cache.page_size - for t in tasks: - self.page_cache.task_record_hashes( - t.task_id, t.prompt_ids, start_logical_page=start_logical_page - ) - - def execute_decode(self, tasks: List[Task], start_pos: int) -> None: + def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]: if not tasks: - return + return [] - tasks = sorted(tasks, key=lambda t: t.task_id) - - valid: List[Task] = [] - for t in tasks: - if self.page_cache.task_extend(t.task_id, start_pos): - valid.append(t) - else: - t.status = TaskStatus.ABORTED - if t.stream_callback: - t.stream_callback(STOP) - - if not valid: - return - - tasks = valid batch_sz = len(tasks) input_ids = torch.tensor( @@ -112,22 +91,9 @@ class Executor: ) logits = outputs["logits"][:, -1, :] - next_tokens = sample( + return sample( logits, temperature=temperatures, top_k=top_ks, top_p=top_ps, ).tolist() - - for t, ntok in zip(tasks, next_tokens): - t.output_ids.append(ntok) - t.output_tokens += 1 - pos = t.input_tokens + t.output_tokens - self.page_cache.task_extend(t.task_id, pos) - if t.stream_callback: - t.stream_callback(self.tokenizer.decode([ntok])) - - for t in tasks: - if t.is_finished(self.tokenizer.stop_ids): - if t.stream_callback: - t.stream_callback(STOP) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 87c278e..c6dc5e0 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -6,7 +6,7 @@ import torch from astrai.inference.cache import PagedCache from astrai.inference.executor import Executor -from astrai.inference.task import STOP, Task, TaskManager +from astrai.inference.task import STOP, Task, TaskManager, TaskStatus from astrai.model.automodel import AutoModel from astrai.tokenize.tokenizer import AutoTokenizer @@ -75,17 +75,15 @@ class InferenceScheduler: return self._task_mgr.get_stats() def _run_generation_loop(self) -> None: + stop_ids = self._task_mgr.tokenizer.stop_ids try: while self._running: - finished = self._task_mgr.remove_finished_tasks( - self._task_mgr.tokenizer.stop_ids - ) + finished = self._task_mgr.remove_finished_tasks(stop_ids) for task in finished: self._page_cache.task_free(task.task_id) - available = self._task_mgr.max_batch_size - len( - self._task_mgr.active_tasks - ) + active = self._task_mgr.get_active_tasks() + available = self._task_mgr.max_batch_size - len(active) if available > 0: candidates = self._task_mgr.pull_candidates(available) failed = [] @@ -102,7 +100,7 @@ class InferenceScheduler: continue to_prefill = [ - t for t in self._task_mgr.active_tasks if t.output_tokens == 0 + t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0 ] if to_prefill: for t in to_prefill: @@ -118,23 +116,56 @@ class InferenceScheduler: for (prompt_len, start_pos), group in groups.items(): self._executor.execute_prefill(group, prompt_len, start_pos) + start_logical_page = start_pos // self._page_cache.page_size + for t in group: + self._page_cache.task_record_hashes( + t.task_id, + t.prompt_ids, + start_logical_page=start_logical_page, + ) pos_groups: Dict[int, List[Task]] = {} - for t in self._task_mgr.active_tasks: + for t in self._task_mgr.get_active_tasks(): pos_groups.setdefault(t.next_pos, []).append(t) if pos_groups: best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) - self._executor.execute_decode(pos_groups[best_pos], best_pos) + group = sorted(pos_groups[best_pos], key=lambda t: t.task_id) + + valid: List[Task] = [] + for t in group: + if self._page_cache.task_extend(t.task_id, best_pos): + valid.append(t) + else: + t.status = TaskStatus.ABORTED + if t.stream_callback: + t.stream_callback(STOP) + + if valid: + next_tokens = self._executor.execute_decode(valid, best_pos) + + for t, ntok in zip(valid, next_tokens): + t.output_ids.append(ntok) + t.output_tokens += 1 + pos = t.input_tokens + t.output_tokens + self._page_cache.task_extend(t.task_id, pos) + if t.stream_callback: + t.stream_callback( + self._task_mgr.tokenizer.decode([ntok]) + ) + + for t in valid: + if t.is_finished(stop_ids): + if t.stream_callback: + t.stream_callback(STOP) except Exception as e: logger.error(f"Scheduler loop crashed: {e}", exc_info=True) - for task in self._task_mgr.active_tasks: - if task.stream_callback: - task.stream_callback(STOP) - for task in self._task_mgr.waiting_queue: + for task in self._task_mgr.get_active_tasks(): if task.stream_callback: task.stream_callback(STOP) + self._page_cache.task_free(task.task_id) + self._task_mgr.clear_queues() raise def start(self) -> None: @@ -149,7 +180,8 @@ class InferenceScheduler: self._task_mgr.wake() if hasattr(self, "_loop_thread"): self._loop_thread.join(timeout=2.0) - self._task_mgr.waiting_queue.clear() - self._task_mgr.active_tasks.clear() + for task in self._task_mgr.get_active_tasks(): + self._page_cache.task_free(task.task_id) + self._task_mgr.clear_queues() if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 7216eaa..370eedc 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union import torch import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from fastapi.responses import StreamingResponse from pydantic import BaseModel, Field @@ -25,20 +25,6 @@ logger = logging.getLogger(__name__) _project_root = Path(__file__).parent.parent.parent -class ServerState: - def __init__(self): - self.engine: Optional[InferenceEngine] = None - self.config: Dict[str, Any] = { - "device": "cuda", - "dtype": torch.bfloat16, - "param_path": None, - "max_batch_size": 16, - } - - -_state = ServerState() - - class ChatMessage(BaseModel): role: str content: str @@ -81,47 +67,12 @@ class MessagesRequest(BaseModel): stop_sequences: Optional[List[str]] = None -def configure_server( - device: str = "cuda", - dtype: torch.dtype = torch.bfloat16, - param_path: Optional[Path] = None, - max_batch_size: int = 16, -): - _state.config.update( - device=device, - dtype=dtype, - param_path=param_path, - max_batch_size=max_batch_size, - ) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - try: - load_model( - param_path=_state.config["param_path"], - device=_state.config["device"], - dtype=_state.config["dtype"], - max_batch_size=_state.config["max_batch_size"], - ) - except Exception as e: - logger.error(f"Failed to load model: {e}") - raise - yield - if _state.engine: - _state.engine.shutdown() - logger.info("Inference engine shutdown complete") - - -app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) - - -def load_model( +def _create_engine( param_path: Optional[Path] = None, device: str = "cuda", dtype: torch.dtype = torch.bfloat16, max_batch_size: int = 16, -): +) -> InferenceEngine: if param_path is None: param_path = _project_root / "params" if not param_path.exists(): @@ -132,18 +83,38 @@ def load_model( model.to(device=device, dtype=dtype) logger.info(f"Model loaded on {device} with dtype {dtype}") - _state.engine = InferenceEngine( + engine = InferenceEngine( model=model, tokenizer=tokenizer, max_batch_size=max_batch_size, ) logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") + return engine -def _get_engine() -> InferenceEngine: - if _state.engine is None: +@asynccontextmanager +async def lifespan(app: FastAPI): + config = app.state.server_config + if not config.get("_test", False): + try: + app.state.engine = _create_engine(**config) + except Exception as e: + logger.error(f"Failed to load model: {e}") + raise + yield + if app.state.engine: + app.state.engine.shutdown() + logger.info("Inference engine shutdown complete") + + +app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan) + + +def _get_engine(request: Request) -> InferenceEngine: + engine = request.app.state.engine + if engine is None: raise HTTPException(status_code=503, detail="Engine not initialized") - return _state.engine + return engine def _make_chunk( @@ -155,7 +126,6 @@ def _make_chunk( model: str, index: int = 0, ) -> str: - """Build a single SSE ``data:`` chunk matching OpenAI streaming format.""" data = { "id": resp_id, "object": "chat.completion.chunk", @@ -172,23 +142,56 @@ def _make_chunk( return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" +def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str: + return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" + + +def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]: + for seq in stop_sequences: + if seq and seq in text: + return seq + return None + + +def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: + if isinstance(content, str): + return content + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + return block.get("text", "") + return "" + + +def _build_anthropic_messages( + messages: List[AnthropicMessage], system: Optional[str] +) -> List[Dict[str, str]]: + result: List[Dict[str, str]] = [] + if system: + result.append({"role": "system", "content": system}) + for m in messages: + content = _extract_text_content(m.content) + if content: + result.append({"role": m.role, "content": content}) + return result + + @app.get("/health") -async def health(): +async def health(request: Request): return { "status": "ok", - "model_loaded": _state.engine is not None, + "model_loaded": request.app.state.engine is not None, } @app.get("/stats") -async def get_stats(): - return _get_engine().get_stats() +async def get_stats(request: Request): + return _get_engine(request).get_stats() @app.post("/v1/chat/completions") -async def chat_completion(request: ChatCompletionRequest): - """OpenAI-compatible chat completion endpoint (streaming + non-streaming).""" - engine = _get_engine() +async def chat_completion(request: ChatCompletionRequest, req: Request): + engine = _get_engine(req) resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" created = int(time.time()) model = request.model @@ -284,44 +287,9 @@ async def chat_completion(request: ChatCompletionRequest): } -def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str: - return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - - -def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]: - for seq in stop_sequences: - if seq and seq in text: - return seq - return None - - -def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str: - if isinstance(content, str): - return content - if isinstance(content, list): - for block in content: - if isinstance(block, dict) and block.get("type") == "text": - return block.get("text", "") - return "" - - -def _build_anthropic_messages( - messages: List[AnthropicMessage], system: Optional[str] -) -> List[Dict[str, str]]: - result: List[Dict[str, str]] = [] - if system: - result.append({"role": "system", "content": system}) - for m in messages: - content = _extract_text_content(m.content) - if content: - result.append({"role": m.role, "content": content}) - return result - - @app.post("/v1/messages") -async def create_message(request: MessagesRequest): - """Anthropic-compatible Messages API endpoint (streaming + non-streaming).""" - engine = _get_engine() +async def create_message(request: MessagesRequest, req: Request): + engine = _get_engine(req) resp_id = f"msg_{uuid.uuid4().hex[:24]}" model = request.model @@ -472,12 +440,12 @@ def run_server( param_path: Optional[Path] = None, max_batch_size: int = 16, ): - configure_server( - device=device, - dtype=dtype, - param_path=param_path, - max_batch_size=max_batch_size, - ) + app.state.server_config = { + "device": device, + "dtype": dtype, + "param_path": param_path, + "max_batch_size": max_batch_size, + } uvicorn.run( "astrai.inference.server:app", host=host, diff --git a/astrai/inference/task.py b/astrai/inference/task.py index 76a571e..d31fcfb 100644 --- a/astrai/inference/task.py +++ b/astrai/inference/task.py @@ -139,23 +139,24 @@ class TaskManager: } def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]: - finished = [] - for task in self.active_tasks: - if task.status == TaskStatus.ABORTED: - task.finish_time = time.time() - finished.append(task) - elif task.is_finished(stop_ids): - task.status = TaskStatus.FINISHED - task.finish_time = time.time() - finished.append(task) - self._total_tokens += task.output_tokens + with self._lock: + finished = [] + for task in self.active_tasks: + if task.status == TaskStatus.ABORTED: + task.finish_time = time.time() + finished.append(task) + elif task.is_finished(stop_ids): + task.status = TaskStatus.FINISHED + task.finish_time = time.time() + finished.append(task) + self._total_tokens += task.output_tokens - self.active_tasks = [ - t - for t in self.active_tasks - if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) - ] - return finished + self.active_tasks = [ + t + for t in self.active_tasks + if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED) + ] + return finished def pull_candidates(self, n: int) -> List[Task]: to_add: List[Task] = [] @@ -180,5 +181,14 @@ class TaskManager: self._task_event.clear() self._task_event.wait(timeout=timeout) + def get_active_tasks(self) -> List[Task]: + with self._lock: + return list(self.active_tasks) + + def clear_queues(self) -> None: + with self._lock: + self.waiting_queue.clear() + self.active_tasks.clear() + def wake(self) -> None: self._task_event.set() diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py index 0ffa4a3..c782ba1 100644 --- a/tests/inference/conftest.py +++ b/tests/inference/conftest.py @@ -11,6 +11,14 @@ from astrai.inference.server import app @pytest.fixture def client(): """Provide a test client for the FastAPI app.""" + app.state.server_config = { + "device": "cpu", + "dtype": "bfloat16", + "param_path": None, + "max_batch_size": 1, + "_test": True, + } + app.state.engine = None return TestClient(app) @@ -39,7 +47,7 @@ def mock_engine(): @pytest.fixture -def loaded_model(mock_engine, monkeypatch): +def loaded_model(client, mock_engine): """Simulate that the engine is loaded.""" - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = mock_engine return mock_engine diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index f9c80d3..cc410e4 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -162,23 +162,20 @@ def test_prefix_cache_has_page(): def test_task_table_set_get(): - pool = PagePool(n_pages=8) - table = TaskTable(pool, page_size=64) + table = TaskTable(page_size=64) table.set("task1", [0, 1, 2], 128) assert table.get("task1") == [0, 1, 2] assert table.get_cached("task1") == 128 def test_task_table_get_missing(): - pool = PagePool(n_pages=8) - table = TaskTable(pool, page_size=64) + table = TaskTable(page_size=64) assert table.get("nonexistent") == [] assert table.get_cached("nonexistent") == 0 def test_task_table_pop(): - pool = PagePool(n_pages=8) - table = TaskTable(pool, page_size=64) + table = TaskTable(page_size=64) table.set("task1", [0, 1], 64) pages, cached = table.pop("task1") assert pages == [0, 1] @@ -186,26 +183,39 @@ def test_task_table_pop(): assert table.get("task1") == [] -def test_task_table_extend_allocates_pages(): - pool = PagePool(n_pages=8) - table = TaskTable(pool, page_size=64) - table.set("task1", [], 0) - ok = table.extend("task1", 200) +def test_paged_cache_task_extend_allocates(): + cache = PagedCache( + n_layers=1, + n_pages=8, + page_size=64, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + cache._table.set("task1", [], 0) + ok = cache.task_extend("task1", 200) assert ok - assert len(table.get("task1")) == 4 + assert len(cache._table.get("task1")) == 4 -def test_task_table_extend_fails_when_pool_full(): - pool = PagePool(n_pages=2) - table = TaskTable(pool, page_size=64) - table.set("task1", [pool.alloc(), pool.alloc()], 0) - ok = table.extend("task1", 300) +def test_paged_cache_task_extend_fails_when_pool_full(): + cache = PagedCache( + n_layers=1, + n_pages=2, + page_size=64, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + cache._table.set("task1", [0, 1], 0) + ok = cache.task_extend("task1", 300) assert not ok def test_task_table_table_tensor(): - pool = PagePool(n_pages=16) - table = TaskTable(pool, page_size=64) + table = TaskTable(page_size=64) table.set("a", [0, 1], 0) table.set("b", [2, 3, 4], 0) t = table.table_tensor(["a", "b"], torch.device("cpu")) @@ -215,8 +225,7 @@ def test_task_table_table_tensor(): def test_task_table_table_tensor_empty_input(): - pool = PagePool(n_pages=4) - table = TaskTable(pool, page_size=64) + table = TaskTable(page_size=64) t = table.table_tensor([], torch.device("cpu")) assert t.numel() == 0 diff --git a/tests/inference/test_engine.py b/tests/inference/test_engine.py index 11180ba..9573357 100644 --- a/tests/inference/test_engine.py +++ b/tests/inference/test_engine.py @@ -1,20 +1,20 @@ -"""Unit tests for _Result accumulator and InferenceEngine.generate().""" +"""Unit tests for GenerateResult accumulator and InferenceEngine.generate().""" import threading from unittest.mock import MagicMock, patch -from astrai.inference.engine import _Result +from astrai.inference.engine import GenerateResult from astrai.inference.task import STOP def test_result_append_single(): - r = _Result(count=1) + r = GenerateResult(count=1) r.append("hello", 0) assert r.results[0] == "hello" def test_result_append_multiple_tasks(): - r = _Result(count=3) + r = GenerateResult(count=3) r.append("a", 0) r.append("b", 1) r.append("c", 2) @@ -24,7 +24,7 @@ def test_result_append_multiple_tasks(): def test_result_stop_marks_complete(): - r = _Result(count=2) + r = GenerateResult(count=2) r.append("text", 0) r.append(STOP, 0) r.append("more", 1) @@ -34,14 +34,14 @@ def test_result_stop_marks_complete(): def test_result_stop_does_not_double_count(): - r = _Result(count=1) + r = GenerateResult(count=1) r.append(STOP, 0) r.append(STOP, 0) assert r._completed == 1 def test_result_pop_all_returns_and_clears(): - r = _Result(count=2) + r = GenerateResult(count=2) r.append("a", 0) r.append("b", 1) out = r.pop_all() @@ -52,7 +52,7 @@ def test_result_pop_all_returns_and_clears(): def test_result_wait_blocks_until_data(): - r = _Result(count=1) + r = GenerateResult(count=1) def delayed_append(): import time @@ -69,13 +69,13 @@ def test_result_wait_blocks_until_data(): def test_result_wait_timeout(): - r = _Result(count=1) + r = GenerateResult(count=1) ok = r.wait(timeout=0.01) assert not ok def test_result_wait_completion_non_streaming(): - r = _Result(count=2) + r = GenerateResult(count=2) def finish_later(): import time @@ -93,7 +93,7 @@ def test_result_wait_completion_non_streaming(): def test_result_get_results(): - r = _Result(count=2) + r = GenerateResult(count=2) r.append("hello", 0) r.append("world", 1) results = r.get_results() @@ -148,9 +148,9 @@ def test_engine_generate_streaming_yields_tokens(): gen = eng.generate("hello", stream=True) cb = callbacks_saved[0] - cb("t1", 0) - cb("t2", 0) - cb(STOP, 0) + cb("t1") + cb("t2") + cb(STOP) tokens = list(gen) assert tokens == ["t1", "t2"] diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 63fbc67..ef77d29 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -2,10 +2,12 @@ import pytest +from astrai.inference.server import app -def test_health_no_model(client, monkeypatch): + +def test_health_no_model(client): """GET /health should return 200 even when engine not loaded.""" - monkeypatch.setattr("astrai.inference.server._state.engine", None) + app.state.engine = None response = client.get("/health") assert response.status_code == 200 data = response.json() @@ -22,15 +24,14 @@ def test_health_with_model(client, loaded_model): assert data["model_loaded"] is True -def test_chat_completions_non_stream(client, loaded_model, monkeypatch): +def test_chat_completions_non_stream(client, loaded_model): """POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" async def async_gen(): yield "Assistant reply" - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions", json={ @@ -48,16 +49,15 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch): assert "prompt_tokens" in data["usage"] -def test_chat_completions_stream(client, loaded_model, monkeypatch): +def test_chat_completions_stream(client, loaded_model): """POST /v1/chat/completions with stream=true returns SSE stream.""" async def async_gen(): yield "cumulative1" yield "cumulative2" - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/chat/completions", json={ @@ -77,15 +77,14 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch): assert any("[DONE]" in line for line in lines) -def test_messages_non_stream(client, loaded_model, monkeypatch): +def test_messages_non_stream(client, loaded_model): """POST /v1/messages with stream=false returns Anthropic-style JSON.""" async def async_gen(): yield "Assistant reply" - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", json={ @@ -105,16 +104,15 @@ def test_messages_non_stream(client, loaded_model, monkeypatch): assert "input_tokens" in data["usage"] -def test_messages_stream(client, loaded_model, monkeypatch): +def test_messages_stream(client, loaded_model): """POST /v1/messages with stream=true returns Anthropic SSE stream.""" async def async_gen(): yield "cumulative1" yield "cumulative2" - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", json={ @@ -137,15 +135,14 @@ def test_messages_stream(client, loaded_model, monkeypatch): assert "message_stop" in content -def test_messages_with_system(client, loaded_model, monkeypatch): +def test_messages_with_system(client, loaded_model): """POST /v1/messages with system prompt.""" async def async_gen(): yield "Reply" - mock_engine = loaded_model - mock_engine.generate_async.return_value = async_gen() - monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) + app.state.engine = loaded_model + loaded_model.generate_async.return_value = async_gen() response = client.post( "/v1/messages", json={