AstrAI/astrai/inference/engine.py

474 lines
15 KiB
Python

"""Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from astrai.inference.cache import STOP
from astrai.inference.scheduler import InferenceScheduler
from astrai.tokenize import AutoTokenizer
@dataclass(frozen=True)
class GenerationParams:
    """Frozen bundle of sampling hyperparameters for a single generation.

    Because the dataclass is frozen, fields cannot be reassigned after
    construction, so one instance can be handed around without defensive
    copies.
    """

    # Top-k sampling count (0 disables).
    top_k: int = 50
    # Nucleus-sampling probability threshold.
    top_p: float = 1.0
    # Sampling temperature.
    temperature: float = 1.0
    # Maximum number of tokens to generate.
    max_tokens: int = 1024
class GenerationRequest:
    """Request parameters for text generation.

    Encapsulates messages, sampling parameters (via GenerationParams),
    and streaming preference for a single generation request.
    """

    def __init__(
        self,
        messages: List[Dict[str, str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_len: int = 1024,
        stream: bool = False,
    ):
        """Initializes and validates a generation request.

        Args:
            messages: Conversation history as list of {"role": ..., "content": ...}.
            top_k: Top-k sampling count (0 disables); non-negative integer.
            top_p: Nucleus sampling probability threshold in [0.0, 1.0].
            temperature: Sampling temperature; non-negative number.
            max_len: Maximum tokens to generate; positive integer.
            stream: Whether to return output as a token stream.

        Raises:
            ValueError: If a sampling parameter has the wrong type or is out
                of range.
        """
        self.messages = messages
        self.params = GenerationParams(
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            max_tokens=max_len,
        )
        self.stream = stream
        self._validate()

    @property
    def top_k(self) -> int:
        """Top-k sampling count."""
        return self.params.top_k

    @property
    def top_p(self) -> float:
        """Nucleus sampling probability threshold."""
        return self.params.top_p

    @property
    def temperature(self) -> float:
        """Sampling temperature."""
        return self.params.temperature

    @property
    def max_len(self) -> int:
        """Maximum number of tokens to generate."""
        return self.params.max_tokens

    @staticmethod
    def _is_count(value) -> bool:
        """Returns True for a genuine int.

        bool is a subclass of int, so it is excluded explicitly; otherwise
        True/False would silently be accepted as 1/0.
        """
        return isinstance(value, int) and not isinstance(value, bool)

    def _validate(self) -> None:
        """Validates sampling parameter types and ranges.

        Raises:
            ValueError: On the first invalid parameter found.
        """
        if not (self._is_count(self.top_k) and self.top_k >= 0):
            raise ValueError("top_k must be a non-negative integer")
        try:
            top_p_ok = 0.0 <= self.top_p <= 1.0
        except TypeError:
            # Non-comparable values (strings, None, ...) would otherwise escape
            # as TypeError; report them with the documented ValueError.
            top_p_ok = False
        if not top_p_ok:
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
            raise ValueError("temperature must be a non-negative number")
        if not (self._is_count(self.max_len) and self.max_len > 0):
            raise ValueError("max_len must be a positive integer")
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of (index, token_string) tuples since the last call.
"""
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until a STOP signal arrives.
"""
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._cond:
return self.results.copy()
class InferenceEngine:
    """Unified inference engine backed by continuous-batching scheduler.

    Usage:
        with InferenceEngine(model, tokenizer) as engine:
            for token in engine.generate("hello", stream=True):
                print(token, end="")
            text = engine.generate("hello")
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 1,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 128,
    ):
        """Initializes the inference engine and starts its scheduler.

        Args:
            model: The model instance.
            tokenizer: The tokenizer instance.
            max_batch_size: Maximum number of concurrent tasks.
            max_seq_len: Maximum sequence length (forwarded to the scheduler as-is).
            max_prompt_len: Maximum prompt tokens.
            page_size: Number of tokens per KV cache page.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            page_size=page_size,
        )
        self.scheduler.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
        # Returning False lets exceptions raised inside the with-block propagate.
        return False

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> Union[Generator, str, List[str]]:
        """Generates text from a prompt.

        Args:
            prompt: Single string or list of strings for batch generation.
            stream: If True, returns a generator yielding tokens.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling probability threshold.
            top_k: Top-k sampling count (0 disables).

        Returns:
            stream=False, single prompt: str
            stream=False, batch: List[str]
            stream=True, single prompt: Generator[str, None, None]
            stream=True, batch: Generator[Tuple[int, str], None, None]
        """
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]
        if stream:
            return self._generate_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )
        return self._generate_non_streaming(
            prompts, is_batch, max_tokens, temperature, top_p, top_k
        )

    def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> AsyncGenerator[str, None]:
        """Async streaming generator that does not block the event loop.

        The scheduler task is submitted immediately, when this method is
        called. Each blocking wait for new tokens runs in the default
        thread-pool executor, so the event loop stays free. The task is passed
        to ``scheduler.remove_task`` as soon as the generator is exhausted,
        closed early (``break`` / ``aclose()``), or cancelled, instead of
        waiting for a wrapped sync generator to be garbage-collected.

        NOTE(review): if the returned async generator is never iterated, its
        ``finally`` never runs and the submitted task is not removed.

        Args:
            prompt: Input text to generate from.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.

        Yields:
            Decoded token strings as they are generated.
        """
        result, task_ids = self._submit_tasks(
            [prompt], max_tokens, temperature, top_p, top_k
        )
        finished = [False]

        async def _agen():
            try:
                loop = asyncio.get_running_loop()
                while not all(finished):
                    for _, token in self._drain(result, finished):
                        yield token
                    if not all(finished):
                        # Block in a worker thread, never on the event loop.
                        await loop.run_in_executor(None, result.wait, 0.05)
            finally:
                self._remove_tasks(task_ids)

        return _agen()

    @staticmethod
    def _next_token(gen: Generator) -> Optional[str]:
        """Retrieves the next token from a synchronous generator.

        No longer used by generate_async; kept for backward compatibility.

        Args:
            gen: A synchronous generator yielding token strings.

        Returns:
            The next token, or None if the generator is exhausted.
        """
        try:
            return next(gen)
        except StopIteration:
            return None

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Generates text from a structured GenerationRequest.

        Applies the chat template to the request's messages before generation.

        Args:
            request: A GenerationRequest with messages and parameters.

        Returns:
            Generator, string, or list of strings (see generate()).
        """
        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.params.max_tokens,
            temperature=request.params.temperature,
            top_p=request.params.top_p,
            top_k=request.params.top_k,
        )

    @staticmethod
    def _make_callback(result: "_Result", idx: int):
        """Builds the scheduler callback that routes tokens into slot ``idx``.

        A factory binds ``idx`` at creation time, avoiding the late-binding
        closure pitfall of defining the callback directly inside a loop.
        """

        def _callback(tok):
            result.append(tok, idx)

        return _callback

    def _submit_tasks(
        self,
        prompts: List[str],
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Tuple["_Result", List[Any]]:
        """Submits one scheduler task per prompt, sharing a single _Result.

        If ``add_task`` raises part-way through, every task submitted so far
        is passed to ``scheduler.remove_task`` before the exception propagates.

        Returns:
            The accumulator and the task ids returned by ``add_task``, in
            prompt order.
        """
        result = _Result(count=len(prompts))
        task_ids: List[Any] = []
        try:
            for idx, prompt in enumerate(prompts):
                task_ids.append(
                    self.scheduler.add_task(
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        stream_callback=self._make_callback(result, idx),
                    )
                )
        except BaseException:
            self._remove_tasks(task_ids)
            raise
        return result, task_ids

    def _remove_tasks(self, task_ids: List[Any]) -> None:
        """Passes every id in ``task_ids`` to ``scheduler.remove_task``."""
        for task_id in task_ids:
            self.scheduler.remove_task(task_id)

    @staticmethod
    def _drain(result: "_Result", finished: List[bool]) -> List[Tuple[int, str]]:
        """Pops buffered items from ``result`` and returns the non-STOP ones.

        A STOP sentinel for index ``idx`` sets ``finished[idx] = True`` in place
        and is not included in the returned list.
        """
        tokens: List[Tuple[int, str]] = []
        for idx, token in result.pop_all():
            if token is STOP:
                finished[idx] = True
            else:
                tokens.append((idx, token))
        return tokens

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Generator:
        """Internal streaming generator.

        Tasks are submitted eagerly (at call time); the returned generator
        drains the _Result accumulator and yields tokens as they arrive. The
        tasks are removed when the generator finishes or is closed.

        NOTE(review): if the returned generator is never started, its
        ``finally`` never runs and the submitted tasks are not removed.

        Args:
            prompts: List of prompts.
            is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.

        Yields:
            Single prompt: decoded token strings.
            Batch: (sequence_index, token_string) tuples.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )
        finished = [False] * len(prompts)

        def gen():
            try:
                while not all(finished):
                    for idx, token in self._drain(result, finished):
                        yield (idx, token) if is_batch else token
                    if not all(finished):
                        # append() sets the event on every new item; the short
                        # timeout bounds each wait so the loop re-checks anyway.
                        result.wait(timeout=0.05)
            finally:
                self._remove_tasks(task_ids)

        return gen()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Internal non-streaming generator.

        Submits all prompts to the scheduler and waits for all to complete.
        Tasks are removed even if the wait is interrupted (e.g. KeyboardInterrupt).

        Args:
            prompts: List of prompt strings.
            is_batch: Whether multiple prompts were provided.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.

        Returns:
            Single string for one prompt, list of strings for batch.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )
        try:
            result.wait_completion()
        finally:
            self._remove_tasks(task_ids)
        texts = result.get_results()
        return texts if is_batch else texts[0]

    def get_stats(self) -> Dict[str, Any]:
        """Returns current engine statistics.

        Returns:
            Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
        """
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Shuts down the engine, stops the scheduler, and frees GPU memory."""
        self.scheduler.stop()
        # Collect first so tensors reachable only through reference cycles are
        # freed; empty_cache() can only release blocks no live tensor occupies.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()