497 lines
16 KiB
Python
497 lines
16 KiB
Python
"""Unified inference engine for continuous batching.
|
|
|
|
Layers:

- GenerationParams: Immutable value object holding sampling parameters.
- GenerationRequest: User-facing request DTO that validates its parameters.
- _Result: Thread-safe token accumulator (Observer pattern) fed by scheduler stream callbacks.
- InferenceEngine: Facade over InferenceScheduler, plus an async streaming wrapper.
|
|
"""
|
|
|
|
import asyncio
|
|
import gc
|
|
import threading
|
|
from dataclasses import dataclass
|
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from astrai.inference.cache import STOP
|
|
from astrai.inference.scheduler import InferenceScheduler
|
|
from astrai.tokenize import AutoTokenizer
|
|
|
|
|
|
@dataclass(frozen=True)
class GenerationParams:
    """Frozen bundle of sampling settings for one generation request.

    Because the dataclass is frozen, fields cannot be reassigned after
    construction; create a new instance to use different settings.
    """

    # Top-k sampling count; GenerationRequest documents 0 as "disabled".
    top_k: int = 50
    # Nucleus (top-p) sampling probability threshold.
    top_p: float = 1.0
    # Sampling temperature.
    temperature: float = 1.0
    # Maximum number of tokens to generate.
    max_tokens: int = 1024
|
|
|
|
|
|
class GenerationRequest:
    """Request parameters for text generation.

    Encapsulates chat messages, sampling parameters (stored as an immutable
    GenerationParams in ``self.params``), and the streaming preference for a
    single generation request. Parameters are validated on construction.
    """

    def __init__(
        self,
        messages: List[Dict[str, str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_len: int = 1024,
        stream: bool = False,
    ):
        """Initializes a generation request.

        Args:
            messages: Conversation history as list of {"role": ..., "content": ...}.
            top_k: Top-k sampling count (0 disables).
            top_p: Nucleus sampling probability threshold.
            temperature: Sampling temperature.
            max_len: Maximum tokens to generate.
            stream: Whether to return output as a token stream.

        Raises:
            ValueError: If a sampling parameter has the wrong type or is out
                of range (see ``_validate``).
        """
        self.messages = messages
        self.params = GenerationParams(
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            max_tokens=max_len,
        )
        self.stream = stream
        self._validate()

    @property
    def top_k(self) -> int:
        return self.params.top_k

    @property
    def top_p(self) -> float:
        return self.params.top_p

    @property
    def temperature(self) -> float:
        return self.params.temperature

    @property
    def max_len(self) -> int:
        return self.params.max_tokens

    @staticmethod
    def _is_int(value: Any) -> bool:
        """Returns True for ints, excluding bool (which subclasses int)."""
        return isinstance(value, int) and not isinstance(value, bool)

    @staticmethod
    def _is_number(value: Any) -> bool:
        """Returns True for int/float values, excluding bool."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def _validate(self) -> None:
        """Validates sampling parameter types and ranges.

        Type checks run before range checks so that a non-numeric value
        raises ValueError instead of a TypeError from the comparison.

        Raises:
            ValueError: On the first invalid parameter.
        """
        if not (self._is_int(self.top_k) and self.top_k >= 0):
            raise ValueError("top_k must be a non-negative integer")
        if not (self._is_number(self.top_p) and 0.0 <= self.top_p <= 1.0):
            raise ValueError("top_p must be a float between 0.0 and 1.0")
        if not (self._is_number(self.temperature) and self.temperature >= 0):
            raise ValueError("temperature must be a non-negative number")
        # max_len bounds how many tokens are generated; zero or negative
        # values would otherwise be forwarded to the scheduler unchecked.
        if not (self._is_int(self.max_len) and self.max_len > 0):
            raise ValueError("max_len must be a positive integer")
|
|
|
|
|
|
class _Result:
|
|
"""Thread-safe token accumulator for streaming and non-streaming modes.
|
|
|
|
Supports multiple concurrent generation tasks with per-index result tracking.
|
|
Uses a threading.Condition for efficient completion notification
|
|
and a threading.Event for streaming wakeup.
|
|
"""
|
|
|
|
def __init__(self, count: int = 1):
|
|
"""Initializes the accumulator.
|
|
|
|
Args:
|
|
count: Number of concurrent generation tasks to track.
|
|
"""
|
|
self._cond = threading.Condition()
|
|
self._event = threading.Event()
|
|
self.tokens: List[str] = []
|
|
self.results: List[str] = [""] * count
|
|
self._done: List[bool] = [False] * count
|
|
self._completed = 0
|
|
self._total = count
|
|
|
|
def append(self, token: str, idx: int = 0):
|
|
"""Appends a token to the result buffer.
|
|
|
|
In non-streaming mode, tokens are concatenated into results[idx].
|
|
The sentinel STOP marks a task as complete.
|
|
|
|
Args:
|
|
token: The decoded token string, or STOP sentinel.
|
|
idx: Index of the generation task this token belongs to.
|
|
"""
|
|
with self._cond:
|
|
self.tokens.append((idx, token))
|
|
if token is not STOP:
|
|
self.results[idx] += token
|
|
else:
|
|
if not self._done[idx]:
|
|
self._done[idx] = True
|
|
self._completed += 1
|
|
self._cond.notify_all()
|
|
self._event.set()
|
|
|
|
def pop_all(self) -> List[Tuple[int, str]]:
|
|
"""Returns and clears all accumulated (idx, token) pairs.
|
|
|
|
Returns:
|
|
List of (index, token_string) tuples since the last call.
|
|
"""
|
|
with self._cond:
|
|
out = self.tokens.copy()
|
|
self.tokens.clear()
|
|
if not out:
|
|
self._event.clear()
|
|
return out
|
|
|
|
def wait(self, timeout: Optional[float] = None) -> bool:
|
|
"""Blocks until new tokens arrive or the timeout expires.
|
|
|
|
Args:
|
|
timeout: Maximum wait time in seconds (None = infinite).
|
|
|
|
Returns:
|
|
True if the event was set (new data available), False on timeout.
|
|
"""
|
|
return self._event.wait(timeout=timeout)
|
|
|
|
def wait_completion(self) -> None:
|
|
"""Blocks until all tasks complete (non-streaming).
|
|
|
|
Uses a Condition to sleep efficiently instead of busy-waiting.
|
|
The calling thread is parked until a STOP signal arrives.
|
|
"""
|
|
with self._cond:
|
|
self._cond.wait_for(lambda: self._completed >= self._total)
|
|
|
|
def get_results(self) -> List[str]:
|
|
"""Returns all accumulated results for non-streaming mode.
|
|
|
|
Returns:
|
|
List of complete generated strings, one per task index.
|
|
"""
|
|
with self._cond:
|
|
return self.results.copy()
|
|
|
|
|
|
class InferenceEngine:
    """Unified inference engine backed by a continuous-batching scheduler.

    Each generation call submits one scheduler task per prompt and collects
    that task's output through a ``_Result`` accumulator.

    Usage:
        with InferenceEngine(model, tokenizer) as engine:
            for token in engine.generate("hello", stream=True):
                print(token, end="")

            text = engine.generate("hello")
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 1,
        max_queue_size: int = 64,
        request_timeout: float = 60.0,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 128,
    ):
        """Initializes the inference engine and starts its scheduler.

        Args:
            model: The model instance.
            tokenizer: The tokenizer instance.
            max_batch_size: Maximum number of concurrent tasks.
            max_queue_size: Forwarded to InferenceScheduler; presumably the
                maximum number of waiting tasks (confirm in the scheduler).
            request_timeout: Forwarded to InferenceScheduler; used as the
                default per-request timeout in seconds when ``generate`` is
                called with ``timeout=None``.
            max_seq_len: Maximum sequence length.
            max_prompt_len: Maximum prompt tokens.
            page_size: Number of tokens per KV cache page.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_queue_size=max_queue_size,
            request_timeout=request_timeout,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            page_size=page_size,
        )

        self.scheduler.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shuts the engine down; returns False so exceptions propagate."""
        self.shutdown()
        return False

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        timeout: Optional[float] = None,
    ) -> Union[Generator, str, List[str]]:
        """Generates text from a prompt.

        Args:
            prompt: Single string or list of strings for batch generation.
            stream: If True, returns a generator yielding tokens.
            max_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling probability threshold.
            top_k: Top-k sampling count (0 disables).
            timeout: Per-request timeout in seconds (None = use scheduler default).

        Returns:
            stream=False, single prompt: str
            stream=False, batch: List[str]
            stream=True, single prompt: Generator[str, None, None]
            stream=True, batch: Generator[Tuple[int, str], None, None]

        Raises:
            Whatever ``InferenceScheduler.add_task`` raises when a task
            cannot be queued (presumably RuntimeError). Tasks already queued
            for this call are removed before the exception propagates.
        """
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]
        args = (prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout)

        if stream:
            return self._generate_streaming(*args)
        return self._generate_non_streaming(*args)

    def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
        timeout: Optional[float] = None,
    ) -> AsyncGenerator[str, None]:
        """Async streaming generator that does not block the event loop.

        Tasks are submitted immediately, when this method is called. Each
        ``next()`` on the underlying synchronous generator runs in the
        loop's default thread-pool executor.

        Args:
            prompt: Input text to generate from.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.
            timeout: Per-request timeout in seconds.

        Yields:
            Decoded token strings as they are generated.
        """
        sync_gen = self._generate_streaming(
            [prompt], False, max_tokens, temperature, top_p, top_k, timeout
        )

        async def _agen():
            # Called from inside a coroutine, so a running loop always exists.
            loop = asyncio.get_running_loop()
            try:
                while True:
                    token = await loop.run_in_executor(None, self._next_token, sync_gen)
                    if token is None:
                        break
                    yield token
            finally:
                # Close the sync generator so its cleanup (task removal) runs
                # promptly when the consumer stops early. If a worker thread
                # is still inside next(sync_gen) (e.g. the await above was
                # cancelled), close() raises ValueError. The generator then
                # cleans up when it is garbage-collected instead.
                try:
                    sync_gen.close()
                except ValueError:
                    pass

        return _agen()

    @staticmethod
    def _next_token(gen: Generator) -> Optional[str]:
        """Retrieves the next token from a synchronous generator.

        Args:
            gen: A synchronous generator yielding token strings.

        Returns:
            The next token, or None if the generator is exhausted.
        """
        try:
            return next(gen)
        except StopIteration:
            return None

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Generates text from a structured GenerationRequest.

        Applies the chat template to the request's messages before generation.

        Args:
            request: A GenerationRequest with messages and parameters.

        Returns:
            Generator, string, or list of strings (see generate()).
        """
        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.params.max_tokens,
            temperature=request.params.temperature,
            top_p=request.params.top_p,
            top_k=request.params.top_k,
        )

    def _submit_tasks(
        self,
        prompts: List[str],
        result: "_Result",
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        timeout: Optional[float],
    ) -> List[Any]:
        """Queues one scheduler task per prompt, routing its tokens into ``result``.

        All-or-nothing: if any ``add_task`` call raises, the tasks already
        queued by this call are removed before the exception propagates.

        Returns:
            Task ids in prompt order.
        """
        task_ids: List[Any] = []
        try:
            for idx, prompt in enumerate(prompts):
                task_ids.append(
                    self.scheduler.add_task(
                        prompt=prompt,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                        top_k=top_k,
                        # Bind idx now; a bare closure would see only the last value.
                        stream_callback=lambda tok, idx=idx: result.append(tok, idx),
                        timeout=timeout,
                    )
                )
        except BaseException:
            # Roll back on any failure (not only RuntimeError), then re-raise.
            self._remove_tasks(task_ids)
            raise
        return task_ids

    def _remove_tasks(self, task_ids: List[Any]) -> None:
        """Removes each of the given tasks from the scheduler."""
        for task_id in task_ids:
            self.scheduler.remove_task(task_id)

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        timeout: Optional[float] = None,
    ) -> Generator:
        """Internal streaming generator.

        Tasks are submitted when this method is called, not on first
        iteration, so submission errors surface immediately. The returned
        generator drains the ``_Result`` accumulator in a loop and removes
        the tasks when it is exhausted or closed.

        NOTE(review): if the returned generator is never started, its
        ``finally`` never runs and the tasks are not removed here.

        Args:
            prompts: List of prompts.
            is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.
            timeout: Per-request timeout in seconds.

        Yields:
            Single prompt: decoded token strings.
            Batch: (sequence_index, token_string) tuples.
        """
        result = _Result(count=len(prompts))
        task_ids = self._submit_tasks(
            prompts, result, max_tokens, temperature, top_p, top_k, timeout
        )

        def gen():
            remaining = len(prompts)
            finished = [False] * len(prompts)
            try:
                while remaining > 0:
                    for idx, token in result.pop_all():
                        if token is STOP:
                            if not finished[idx]:
                                finished[idx] = True
                                remaining -= 1
                        else:
                            yield (idx, token) if is_batch else token
                    if remaining > 0:
                        # Park until the next token arrives. The timeout only
                        # bounds each sleep; it does not end the stream.
                        result.wait(timeout=0.05)
            finally:
                self._remove_tasks(task_ids)

        return gen()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        timeout: Optional[float] = None,
    ) -> Union[str, List[str]]:
        """Internal non-streaming generator.

        Submits all prompts to the scheduler and blocks until every task has
        signalled STOP. Tasks are removed even if the wait is interrupted.

        Args:
            prompts: List of prompt strings.
            is_batch: Whether multiple prompts were provided.
            max_tokens: Maximum tokens to generate.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Top-k sampling count.
            timeout: Per-request timeout in seconds.

        Returns:
            Single string for one prompt, list of strings for batch.
        """
        result = _Result(count=len(prompts))
        task_ids = self._submit_tasks(
            prompts, result, max_tokens, temperature, top_p, top_k, timeout
        )

        try:
            result.wait_completion()
        finally:
            self._remove_tasks(task_ids)

        res = result.get_results()
        return res if is_batch else res[0]

    def get_stats(self) -> Dict[str, Any]:
        """Returns current engine statistics.

        Returns:
            Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
        """
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Stops the scheduler and frees GPU memory.

        ``gc.collect()`` runs before ``torch.cuda.empty_cache()``. Tensors
        that only become unreachable during collection are returned to
        PyTorch's caching allocator first, so ``empty_cache`` can then
        release them.
        """
        self.scheduler.stop()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|