"""Unified inference engine for continuous batching. Layers: - GenerationParams: Immutable value object for sampling parameters. - GenerationRequest: User-facing request DTO with validation. - _Result: Thread-safe token accumulator (Observer pattern). - InferenceEngine: Facade over InferenceScheduler + async wrapper. """ import asyncio import gc import threading from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union import torch import torch.nn as nn from astrai.inference.cache import STOP from astrai.inference.scheduler import InferenceScheduler from astrai.tokenize import AutoTokenizer @dataclass(frozen=True) class GenerationParams: """Immutable value object for sampling hyperparameters.""" top_k: int = 50 top_p: float = 1.0 temperature: float = 1.0 max_tokens: int = 1024 class GenerationRequest: """Request parameters for text generation. Encapsulates messages, sampling parameters (via GenerationParams), and streaming preference for a single generation request. """ def __init__( self, messages: List[Dict[str, str]], top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, max_len: int = 1024, stream: bool = False, ): """Initializes a generation request. Args: messages: Conversation history as list of {"role": ..., "content": ...}. top_k: Top-k sampling count (0 disables). top_p: Nucleus sampling probability threshold. temperature: Sampling temperature. max_len: Maximum tokens to generate. stream: Whether to return output as a token stream. """ self.messages = messages self.params = GenerationParams( top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_len, ) self.stream = stream self._validate() @property def top_k(self) -> int: return self.params.top_k @property def top_p(self) -> float: return self.params.top_p @property def temperature(self) -> float: return self.params.temperature @property def max_len(self) -> int: return self.params.max_tokens def _validate(self): """Validates sampling parameter ranges.""" if not (isinstance(self.top_k, int) and self.top_k >= 0): raise ValueError("top_k must be a non-negative integer") if not (0.0 <= self.top_p <= 1.0): raise ValueError("top_p must be a float between 0.0 and 1.0") if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): raise ValueError("temperature must be a non-negative number") class _Result: """Thread-safe token accumulator for streaming and non-streaming modes. Supports multiple concurrent generation tasks with per-index result tracking. Uses a threading.Condition for efficient completion notification and a threading.Event for streaming wakeup. """ def __init__(self, count: int = 1): """Initializes the accumulator. Args: count: Number of concurrent generation tasks to track. """ self._cond = threading.Condition() self._event = threading.Event() self.tokens: List[str] = [] self.results: List[str] = [""] * count self._done: List[bool] = [False] * count self._completed = 0 self._total = count def append(self, token: str, idx: int = 0): """Appends a token to the result buffer. In non-streaming mode, tokens are concatenated into results[idx]. The sentinel STOP marks a task as complete. Args: token: The decoded token string, or STOP sentinel. idx: Index of the generation task this token belongs to. """ with self._cond: self.tokens.append((idx, token)) if token is not STOP: self.results[idx] += token else: if not self._done[idx]: self._done[idx] = True self._completed += 1 self._cond.notify_all() self._event.set() def pop_all(self) -> List[Tuple[int, str]]: """Returns and clears all accumulated (idx, token) pairs. Returns: List of (index, token_string) tuples since the last call. """ with self._cond: out = self.tokens.copy() self.tokens.clear() if not out: self._event.clear() return out def wait(self, timeout: Optional[float] = None) -> bool: """Blocks until new tokens arrive or the timeout expires. Args: timeout: Maximum wait time in seconds (None = infinite). Returns: True if the event was set (new data available), False on timeout. """ return self._event.wait(timeout=timeout) def wait_completion(self) -> None: """Blocks until all tasks complete (non-streaming). Uses a Condition to sleep efficiently instead of busy-waiting. The calling thread is parked until a STOP signal arrives. """ with self._cond: self._cond.wait_for(lambda: self._completed >= self._total) def get_results(self) -> List[str]: """Returns all accumulated results for non-streaming mode. Returns: List of complete generated strings, one per task index. """ with self._cond: return self.results.copy() class InferenceEngine: """Unified inference engine backed by continuous-batching scheduler. Usage: with InferenceEngine(model, tokenizer) as engine: for token in engine.generate("hello", stream=True): print(token, end="") text = engine.generate("hello") """ def __init__( self, model: nn.Module, tokenizer: AutoTokenizer, max_batch_size: int = 1, max_queue_size: int = 64, request_timeout: float = 60.0, max_seq_len: Optional[int] = None, max_prompt_len: int = 2048, page_size: int = 128, ): """Initializes the inference engine. Args: model: The model instance. tokenizer: The tokenizer instance. max_batch_size: Maximum number of concurrent tasks. max_seq_len: Maximum sequence length. max_prompt_len: Maximum prompt tokens. page_size: Number of tokens per KV cache page. """ self.model = model self.tokenizer = tokenizer self.scheduler = InferenceScheduler( model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, max_queue_size=max_queue_size, request_timeout=request_timeout, max_seq_len=max_seq_len, max_prompt_len=max_prompt_len, page_size=page_size, ) self.scheduler.start() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.shutdown() return False def generate( self, prompt: Union[str, List[str]], stream: bool = False, max_tokens: int = 1024, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, timeout: Optional[float] = None, ) -> Union[Generator, str, List[str]]: """Generates text from a prompt. Args: prompt: Single string or list of strings for batch generation. stream: If True, returns a generator yielding tokens. max_tokens: Maximum number of tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling probability threshold. top_k: Top-k sampling count (0 disables). timeout: Per-request timeout in seconds (None = use scheduler default). Returns: stream=False, single prompt: str stream=False, batch: List[str] stream=True, single prompt: Generator[str, None, None] stream=True, batch: Generator[Tuple[int, str], None, None] """ is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] if stream: return self._generate_streaming( prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout ) else: return self._generate_non_streaming( prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout ) def generate_async( self, prompt: str, max_tokens: int = 1024, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, timeout: Optional[float] = None, ) -> AsyncGenerator[str, None]: """Async streaming generator that does not block the event loop. Runs the synchronous generator in a background thread pool executor, yielding tokens to the async consumer as they arrive. Args: prompt: Input text to generate from. max_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. timeout: Per-request timeout in seconds. Yields: Decoded token strings as they are generated. """ sync_gen = self._generate_streaming( [prompt], False, max_tokens, temperature, top_p, top_k, timeout ) async def _agen(): loop = asyncio.get_event_loop() while True: token = await loop.run_in_executor(None, self._next_token, sync_gen) if token is None: break yield token return _agen() @staticmethod def _next_token(gen: Generator) -> Optional[str]: """Retrieves the next token from a synchronous generator. Args: gen: A synchronous generator yielding token strings. Returns: The next token, or None if the generator is exhausted. """ try: return next(gen) except StopIteration: return None def generate_with_request( self, request: GenerationRequest ) -> Union[Generator[str, None, None], str, List[str]]: """Generates text from a structured GenerationRequest. Applies the chat template to the request's messages before generation. Args: request: A GenerationRequest with messages and parameters. Returns: Generator, string, or list of strings (see generate()). """ prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) return self.generate( prompt=prompt, stream=request.stream, max_tokens=request.params.max_tokens, temperature=request.params.temperature, top_p=request.params.top_p, top_k=request.params.top_k, ) def _generate_streaming( self, prompts: List[str], is_batch: bool, max_tokens: int, temperature: float, top_p: float, top_k: int, timeout: Optional[float] = None, ) -> Generator: """Internal streaming generator. Polls the _Result accumulator in a loop, yielding tokens as they arrive. Single prompt yields raw token strings; batch yields (idx, token) tuples. Args: prompts: List of prompts. is_batch: If True, yields (idx, token) tuples; else yields raw tokens. max_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. timeout: Per-request timeout in seconds. Yields: Single prompt: decoded token strings. Batch: (sequence_index, token_string) tuples. """ n = len(prompts) result = _Result(count=n) task_ids = [] try: for i, p in enumerate(prompts): task_id = self.scheduler.add_task( prompt=p, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, stream_callback=lambda tok, idx=i: result.append(tok, idx), timeout=timeout, ) task_ids.append(task_id) except RuntimeError: for tid in task_ids: self.scheduler.remove_task(tid) raise remaining = n finished = [False] * n def gen(): nonlocal remaining try: while remaining > 0: items = result.pop_all() for idx, token in items: if token is STOP: if not finished[idx]: finished[idx] = True remaining -= 1 else: yield (idx, token) if is_batch else token if remaining > 0: if not result.wait(timeout=0.05): pass finally: for tid in task_ids: self.scheduler.remove_task(tid) return gen() def _generate_non_streaming( self, prompts: List[str], is_batch: bool, max_tokens: int, temperature: float, top_p: float, top_k: int, timeout: Optional[float] = None, ) -> Union[str, List[str]]: """Internal non-streaming generator. Submits all prompts to the scheduler and waits for all to complete. Args: prompts: List of prompt strings. is_batch: Whether multiple prompts were provided. max_tokens: Maximum tokens to generate. temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. timeout: Per-request timeout in seconds. Returns: Single string for one prompt, list of strings for batch. """ result = _Result(count=len(prompts)) task_ids = [] try: for i, p in enumerate(prompts): def make_cb(idx): return lambda tok: result.append(tok, idx) task_id = self.scheduler.add_task( prompt=p, max_tokens=max_tokens, temperature=temperature, top_p=top_p, top_k=top_k, stream_callback=make_cb(i), timeout=timeout, ) task_ids.append(task_id) except RuntimeError: for tid in task_ids: self.scheduler.remove_task(tid) raise result.wait_completion() for task_id in task_ids: self.scheduler.remove_task(task_id) res = result.get_results() return res if is_batch else res[0] def get_stats(self) -> Dict[str, Any]: """Returns current engine statistics. Returns: Dict with total_tasks, total_tokens, active_tasks, waiting_queue. """ return self.scheduler.get_stats() def shutdown(self) -> None: """Shuts down the engine, stops the scheduler, and frees GPU memory.""" self.scheduler.stop() if torch.cuda.is_available(): torch.cuda.empty_cache() gc.collect()