# AstrAI/astrai/inference/engine.py
#
# 312 lines
# 8.8 KiB
# Python

"""Unified inference engine for continuous batching."""
import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
class GenerateResult:
    """Collect tokens produced by one or more generation tasks.

    Producers push tokens via :meth:`append`, tagged with the index of the
    sequence they belong to. A consumer either drains them incrementally with
    :meth:`pop_all` / :meth:`wait` (streaming) or blocks in
    :meth:`wait_completion` and then reads :meth:`get_results`
    (non-streaming). Shared state is mutated under a single condition
    variable.
    """

    def __init__(self, count: int = 1):
        self._total = count
        self._completed = 0
        # Per-sequence "STOP already seen" flags, so a duplicate STOP is
        # counted only once.
        self._done: List[bool] = [False] * count
        # Accumulated text per sequence (STOP markers are not included).
        self.results: List[str] = [""] * count
        # Pending (idx, token) pairs not yet drained by pop_all().
        self.tokens: List[Tuple[int, str]] = []
        # Set whenever append() runs; cleared by pop_all() when nothing is
        # pending, so wait() can block until the next append.
        self._event = threading.Event()
        self._cond = threading.Condition()

    def append(self, token: str, idx: int = 0):
        """Record ``token`` for sequence ``idx``; ``STOP`` marks it finished."""
        with self._cond:
            self.tokens.append((idx, token))
            if token is STOP:
                if not self._done[idx]:
                    self._done[idx] = True
                    self._completed += 1
            else:
                self.results[idx] += token
            self._cond.notify_all()
            self._event.set()

    def pop_all(self) -> List[Tuple[int, str]]:
        """Remove and return every pending ``(idx, token)`` pair, oldest first."""
        with self._cond:
            drained = list(self.tokens)
            # Empty in place so the ``tokens`` list object keeps its identity.
            del self.tokens[:]
            if not drained:
                # Nothing pending: reset so the next wait() really blocks.
                self._event.clear()
            return drained

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Block until append() has signalled; return False on timeout."""
        return self._event.wait(timeout)

    def wait_completion(self) -> None:
        """Block until every sequence has received its ``STOP`` token."""

        def all_finished() -> bool:
            return self._completed >= self._total

        with self._cond:
            self._cond.wait_for(all_finished)

    def get_results(self) -> List[str]:
        """Return a snapshot (shallow copy) of the per-sequence texts."""
        with self._cond:
            return list(self.results)
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
@dataclass(frozen=True)
class GenerationParams:
    """Sampling settings for one generation call, held as a frozen value.

    ``frozen=True`` makes instances read-only. Combined with the generated
    ``__eq__``, it also gives them a ``__hash__``. No validation happens
    here; ``GenerationRequest`` calls ``_validate_params`` on the sampling
    fields.
    """

    # Top-k cutoff; _validate_params requires a non-negative int.
    top_k: int = 50
    # Top-p value; _validate_params requires it to lie in [0.0, 1.0].
    top_p: float = 1.0
    # Sampling temperature; _validate_params requires it to be >= 0.
    temperature: float = 1.0
    # Generation length limit; passed on to the scheduler as ``max_tokens``.
    max_tokens: int = 1024
class GenerationRequest:
    """A chat-style generation request: messages plus sampling settings.

    The sampling arguments are checked with ``_validate_params`` and stored
    on an immutable :class:`GenerationParams` (``max_len`` becomes its
    ``max_tokens``). The read-only properties expose them under the
    constructor's argument names.
    """

    def __init__(
        self,
        messages: List[Dict[str, str]],
        top_k: int = 50,
        top_p: float = 1.0,
        temperature: float = 1.0,
        max_len: int = 1024,
        stream: bool = False,
    ):
        # Fail fast, before any attribute is set.
        _validate_params(top_k, top_p, temperature)
        self.messages = messages
        self.stream = stream
        self.params = GenerationParams(
            top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_len
        )

    @property
    def top_k(self) -> int:
        """Top-k value from :attr:`params`."""
        return self.params.top_k

    @property
    def top_p(self) -> float:
        """Top-p value from :attr:`params`."""
        return self.params.top_p

    @property
    def temperature(self) -> float:
        """Temperature from :attr:`params`."""
        return self.params.temperature

    @property
    def max_len(self) -> int:
        """Length limit, stored on :attr:`params` as ``max_tokens``."""
        return self.params.max_tokens
class InferenceEngine:
    """High-level text-generation front end over ``InferenceScheduler``.

    The engine owns one scheduler. It is started in ``__init__`` and stopped
    by :meth:`shutdown`, which also runs when a ``with`` block exits. Each
    prompt becomes one scheduler task. That task's tokens are pushed through
    a per-prompt callback into a shared :class:`GenerateResult`. The public
    ``generate*`` methods differ only in how the tokens are handed back: all
    at once, as a sync generator, or as an async generator.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: AutoTokenizer,
        max_batch_size: int = 1,
        max_seq_len: Optional[int] = None,
        max_prompt_len: int = 2048,
        page_size: int = 128,
    ):
        """Build the scheduler and start it immediately.

        All arguments besides ``model`` and ``tokenizer`` are forwarded
        unchanged to ``InferenceScheduler``; their semantics are defined
        there.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.scheduler = InferenceScheduler(
            model=self.model,
            tokenizer=self.tokenizer,
            max_batch_size=max_batch_size,
            max_seq_len=max_seq_len,
            max_prompt_len=max_prompt_len,
            page_size=page_size,
        )
        self.scheduler.start()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.shutdown()
        # Never suppress an exception raised inside the ``with`` block.
        return False

    def generate(
        self,
        prompt: Union[str, List[str]],
        stream: bool = False,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> Union[Generator, str, List[str]]:
        """Generate text for one prompt or a list of prompts.

        Args:
            prompt: A single prompt string, or a list of prompts that are
                submitted as independent scheduler tasks.
            stream: If true, return a generator instead of the final text.
            max_tokens, temperature, top_p, top_k: Forwarded to the
                scheduler for every prompt.

        Returns:
            Non-streaming: the generated string for a single prompt, or a
            list of strings in prompt order for list input.
            Streaming: a generator that yields token strings for a single
            prompt, or ``(prompt_index, token)`` pairs for list input.

        Raises:
            ValueError: If ``_validate_params`` rejects a sampling argument.
        """
        _validate_params(top_k, top_p, temperature)
        is_batch = isinstance(prompt, list)
        prompts = prompt if is_batch else [prompt]
        if stream:
            return self._generate_streaming(
                prompts, is_batch, max_tokens, temperature, top_p, top_k
            )
        return self._generate_non_streaming(
            prompts, is_batch, max_tokens, temperature, top_p, top_k
        )

    def generate_async(
        self,
        prompt: str,
        max_tokens: int = 1024,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> AsyncGenerator[str, None]:
        """Stream the tokens of a single prompt through an async generator.

        The task is submitted when this method is called, not on first
        iteration. Each blocking ``next()`` on the underlying sync stream
        runs in the event loop's default executor, so the loop itself is not
        blocked while waiting for tokens.

        Raises:
            ValueError: If ``_validate_params`` rejects a sampling argument.
        """
        # generate() validates its arguments; do the same here so invalid
        # values never reach the scheduler through the async path.
        _validate_params(top_k, top_p, temperature)
        sync_gen = self._generate_streaming(
            [prompt], False, max_tokens, temperature, top_p, top_k
        )

        async def _agen():
            # Inside a coroutine a loop is always running; get_running_loop()
            # is the non-deprecated way to fetch it.
            loop = asyncio.get_running_loop()
            try:
                while True:
                    token = await loop.run_in_executor(
                        None, self._next_token, sync_gen
                    )
                    if token is None:
                        break
                    yield token
            finally:
                # On early exit (consumer stopped iterating, aclose(), or an
                # error), close the sync stream now. Its own ``finally`` then
                # removes the task from the scheduler instead of waiting for
                # garbage collection. This is a no-op if the stream is done.
                try:
                    sync_gen.close()
                except ValueError:
                    # "generator already executing": a cancelled executor
                    # call is still inside next(). Cleanup is then left to
                    # the generator's finalization.
                    pass

        return _agen()

    @staticmethod
    def _next_token(gen: Generator) -> Optional[str]:
        """Return ``next(gen)``, or ``None`` once it is exhausted.

        StopIteration is translated because it cannot be propagated through
        an executor future.
        """
        try:
            return next(gen)
        except StopIteration:
            return None

    def generate_with_request(
        self, request: GenerationRequest
    ) -> Union[Generator[str, None, None], str, List[str]]:
        """Render ``request.messages`` with the chat template, then generate.

        The template is applied with ``tokenize=False``, so the prompt text
        goes through :meth:`generate` with the request's stream flag and
        sampling parameters.
        """
        prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
        return self.generate(
            prompt=prompt,
            stream=request.stream,
            max_tokens=request.params.max_tokens,
            temperature=request.params.temperature,
            top_p=request.params.top_p,
            top_k=request.params.top_k,
        )

    def _submit_tasks(
        self,
        prompts: List[str],
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Tuple[GenerateResult, List[str]]:
        """Submit one scheduler task per prompt, all sharing one result sink.

        Returns:
            ``(result, task_ids)``, where ``task_ids[i]`` belongs to
            ``prompts[i]`` and the tokens of prompt ``i`` arrive in ``result``
            under index ``i``.

        If any ``add_task`` call raises, the tasks already submitted are
        removed before the exception propagates. Otherwise they would be
        orphaned, since the caller never receives their ids.
        """
        result = GenerateResult(count=len(prompts))
        task_ids = []
        try:
            for i, p in enumerate(prompts):
                task_id = self.scheduler.add_task(
                    prompt=p,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    stream_callback=self._make_callback(result, i),
                )
                task_ids.append(task_id)
        except BaseException:
            for tid in task_ids:
                self.scheduler.remove_task(tid)
            raise
        return result, task_ids

    @staticmethod
    def _make_callback(result: GenerateResult, idx: int):
        """Return a callback that forwards each token to ``result`` under ``idx``.

        A factory function is used so that every callback binds its own
        ``idx``, avoiding the late-binding closure pitfall in a loop.
        """

        def cb(token):
            result.append(token, idx)

        return cb

    def _generate_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Generator:
        """Submit ``prompts`` and return a generator over their tokens.

        The generator yields bare token strings when ``is_batch`` is false,
        otherwise ``(prompt_index, token)`` pairs. It ends once every prompt
        has delivered ``STOP``. Its ``finally`` removes the tasks from the
        scheduler on exhaustion, ``close()``, or an exception.

        NOTE(review): tasks are submitted eagerly, before the generator is
        returned. A generator that is never iterated never enters its body,
        so its ``finally`` does not run and the tasks are not removed here.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )
        n = len(prompts)
        remaining = n
        finished = [False] * n

        def gen():
            nonlocal remaining
            try:
                while remaining > 0:
                    items = result.pop_all()
                    for idx, token in items:
                        if token is STOP:
                            # Count each prompt's STOP at most once.
                            if not finished[idx]:
                                finished[idx] = True
                                remaining -= 1
                        else:
                            yield (idx, token) if is_batch else token
                    if remaining > 0:
                        # Block until new tokens are signalled, re-polling at
                        # least every 50 ms.
                        result.wait(timeout=0.05)
            finally:
                for tid in task_ids:
                    self.scheduler.remove_task(tid)

        return gen()

    def _generate_non_streaming(
        self,
        prompts: List[str],
        is_batch: bool,
        max_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> Union[str, List[str]]:
        """Submit ``prompts``, block until all finish, and return their text.

        Tasks are removed from the scheduler even if the wait is interrupted,
        for example by KeyboardInterrupt.
        """
        result, task_ids = self._submit_tasks(
            prompts, max_tokens, temperature, top_p, top_k
        )
        try:
            result.wait_completion()
        finally:
            for tid in task_ids:
                self.scheduler.remove_task(tid)
        res = result.get_results()
        return res if is_batch else res[0]

    def get_stats(self) -> Dict[str, Any]:
        """Return the scheduler's statistics unchanged."""
        return self.scheduler.get_stats()

    def shutdown(self) -> None:
        """Stop the scheduler, then release host and CUDA memory.

        ``gc.collect()`` runs before ``torch.cuda.empty_cache()``. Tensors
        kept alive only by reference cycles are freed first, so their cached
        CUDA blocks are unoccupied by the time the cache is emptied.
        """
        self.scheduler.stop()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()