chore: 解耦 Executor/Scheduler/TaskManager,修复 stop 页泄漏,移除 ServerState 全局单例

This commit is contained in:
ViperEkura 2026-05-12 13:44:55 +08:00
parent 7440e9c809
commit df0845e916
10 changed files with 336 additions and 509 deletions

View File

@ -104,8 +104,7 @@ class PrefixCache:
class TaskTable: class TaskTable:
"""Maps task_ids to page tables and cached token counts.""" """Maps task_ids to page tables and cached token counts."""
def __init__(self, pool: PagePool, page_size: int): def __init__(self, page_size: int):
self._pool = pool
self._page_size = page_size self._page_size = page_size
self._pages: Dict[str, List[int]] = {} self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {} self._cached: Dict[str, int] = {}
@ -125,15 +124,8 @@ class TaskTable:
cached = self._cached.pop(task_id, 0) cached = self._cached.pop(task_id, 0)
return pages, cached return pages, cached
def extend(self, task_id: str, pos: int) -> bool: def get_ref(self, task_id: str) -> List[int]:
page_table = self._pages[task_id] return self._pages.setdefault(task_id, [])
needed = (pos + 1 + self._page_size - 1) // self._page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._pages.get(tid, []) for tid in task_ids] states = [self._pages.get(tid, []) for tid in task_ids]
@ -158,7 +150,7 @@ class PagedCache:
self.page_size = page_size self.page_size = page_size
self._prefix = PrefixCache(page_size) self._prefix = PrefixCache(page_size)
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict) self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(self._pool, page_size) self._table = TaskTable(page_size)
self.k_cache = torch.empty( self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim), (n_layers, n_pages, page_size, n_kv_heads, head_dim),
@ -219,7 +211,14 @@ class PagedCache:
self.free(idx) self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool: def task_extend(self, task_id: str, pos: int) -> bool:
return self._table.extend(task_id, pos) page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int: def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id) return self._table.get_cached(task_id)

View File

@ -1,11 +1,4 @@
"""Unified inference engine for continuous batching. """Unified inference engine for continuous batching."""
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
import asyncio import asyncio
import gc import gc
@ -21,6 +14,59 @@ from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
def __init__(self, count: int = 1):
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[Tuple[int, str]] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]:
with self._cond:
return self.results.copy()
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
@dataclass(frozen=True) @dataclass(frozen=True)
class GenerationParams: class GenerationParams:
"""Immutable value object for sampling hyperparameters.""" """Immutable value object for sampling hyperparameters."""
@ -32,11 +78,7 @@ class GenerationParams:
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation. """Request parameters for text generation."""
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
def __init__( def __init__(
self, self,
@ -47,16 +89,6 @@ class GenerationRequest:
max_len: int = 1024, max_len: int = 1024,
stream: bool = False, stream: bool = False,
): ):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages self.messages = messages
self.params = GenerationParams( self.params = GenerationParams(
top_k=top_k, top_k=top_k,
@ -65,7 +97,7 @@ class GenerationRequest:
max_tokens=max_len, max_tokens=max_len,
) )
self.stream = stream self.stream = stream
self._validate() _validate_params(top_k, top_p, temperature)
@property @property
def top_k(self) -> int: def top_k(self) -> int:
@ -83,112 +115,9 @@ class GenerationRequest:
def max_len(self) -> int: def max_len(self) -> int:
return self.params.max_tokens return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of (index, token_string) tuples since the last call.
"""
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until a STOP signal arrives.
"""
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._cond:
return self.results.copy()
class InferenceEngine: class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler. """Unified inference engine backed by continuous-batching scheduler."""
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
def __init__( def __init__(
self, self,
@ -199,17 +128,6 @@ class InferenceEngine:
max_prompt_len: int = 2048, max_prompt_len: int = 2048,
page_size: int = 128, page_size: int = 128,
): ):
"""Initializes the inference engine.
Args:
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.scheduler = InferenceScheduler( self.scheduler = InferenceScheduler(
@ -239,22 +157,8 @@ class InferenceEngine:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: ) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt. _validate_params(top_k, top_p, temperature)
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
"""
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
@ -275,21 +179,6 @@ class InferenceEngine:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming( sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k [prompt], False, max_tokens, temperature, top_p, top_k
) )
@ -306,14 +195,6 @@ class InferenceEngine:
@staticmethod @staticmethod
def _next_token(gen: Generator) -> Optional[str]: def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try: try:
return next(gen) return next(gen)
except StopIteration: except StopIteration:
@ -322,16 +203,6 @@ class InferenceEngine:
def generate_with_request( def generate_with_request(
self, request: GenerationRequest self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]: ) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False) prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
@ -342,6 +213,37 @@ class InferenceEngine:
top_k=request.params.top_k, top_k=request.params.top_k,
) )
def _submit_tasks(
self,
prompts: List[str],
max_tokens: int,
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = GenerateResult(count=n)
task_ids = []
for i, p in enumerate(prompts):
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=cb,
)
task_ids.append(task_id)
return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming( def _generate_streaming(
self, self,
prompts: List[str], prompts: List[str],
@ -351,38 +253,10 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Generator: ) -> Generator:
"""Internal streaming generator. result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args:
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
"""
n = len(prompts)
result = _Result(count=n)
task_ids = []
for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
) )
task_ids.append(task_id) n = len(prompts)
remaining = n remaining = n
finished = [False] * n finished = [False] * n
@ -399,8 +273,7 @@ class InferenceEngine:
else: else:
yield (idx, token) if is_batch else token yield (idx, token) if is_batch else token
if remaining > 0: if remaining > 0:
if not result.wait(timeout=0.05): result.wait(timeout=0.05)
pass
finally: finally:
for tid in task_ids: for tid in task_ids:
self.scheduler.remove_task(tid) self.scheduler.remove_task(tid)
@ -416,57 +289,22 @@ class InferenceEngine:
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Internal non-streaming generator. result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts):
def make_cb(idx):
return lambda tok: result.append(tok, idx)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
) )
task_ids.append(task_id)
result.wait_completion() result.wait_completion()
for task_id in task_ids: for tid in task_ids:
self.scheduler.remove_task(task_id) self.scheduler.remove_task(tid)
res = result.get_results() res = result.get_results()
return res if is_batch else res[0] return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats() return self.scheduler.get_stats()
def shutdown(self) -> None: def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop() self.scheduler.stop()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -5,7 +5,7 @@ import torch
from astrai.inference.cache import PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus from astrai.inference.task import Task
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -60,31 +60,10 @@ class Executor:
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
start_logical_page = start_pos // self.page_cache.page_size def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]:
for t in tasks:
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks: if not tasks:
return return []
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks) batch_sz = len(tasks)
input_ids = torch.tensor( input_ids = torch.tensor(
@ -112,22 +91,9 @@ class Executor:
) )
logits = outputs["logits"][:, -1, :] logits = outputs["logits"][:, -1, :]
next_tokens = sample( return sample(
logits, logits,
temperature=temperatures, temperature=temperatures,
top_k=top_ks, top_k=top_ks,
top_p=top_ps, top_p=top_ps,
).tolist() ).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)

View File

@ -6,7 +6,7 @@ import torch
from astrai.inference.cache import PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer from astrai.tokenize.tokenizer import AutoTokenizer
@ -75,17 +75,15 @@ class InferenceScheduler:
return self._task_mgr.get_stats() return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try: try:
while self._running: while self._running:
finished = self._task_mgr.remove_finished_tasks( finished = self._task_mgr.remove_finished_tasks(stop_ids)
self._task_mgr.tokenizer.stop_ids
)
for task in finished: for task in finished:
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
available = self._task_mgr.max_batch_size - len( active = self._task_mgr.get_active_tasks()
self._task_mgr.active_tasks available = self._task_mgr.max_batch_size - len(active)
)
if available > 0: if available > 0:
candidates = self._task_mgr.pull_candidates(available) candidates = self._task_mgr.pull_candidates(available)
failed = [] failed = []
@ -102,7 +100,7 @@ class InferenceScheduler:
continue continue
to_prefill = [ to_prefill = [
t for t in self._task_mgr.active_tasks if t.output_tokens == 0 t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
] ]
if to_prefill: if to_prefill:
for t in to_prefill: for t in to_prefill:
@ -118,23 +116,56 @@ class InferenceScheduler:
for (prompt_len, start_pos), group in groups.items(): for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos) self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.active_tasks: for t in self._task_mgr.get_active_tasks():
pos_groups.setdefault(t.next_pos, []).append(t) pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups: if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p])) best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._executor.execute_decode(pos_groups[best_pos], best_pos) group = sorted(pos_groups[best_pos], key=lambda t: t.task_id)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, best_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid, best_pos)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e: except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True) logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.active_tasks: for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
for task in self._task_mgr.waiting_queue:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
raise raise
def start(self) -> None: def start(self) -> None:
@ -149,7 +180,8 @@ class InferenceScheduler:
self._task_mgr.wake() self._task_mgr.wake()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=2.0)
self._task_mgr.waiting_queue.clear() for task in self._task_mgr.get_active_tasks():
self._task_mgr.active_tasks.clear() self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -25,20 +25,6 @@ logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent _project_root = Path(__file__).parent.parent.parent
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel): class ChatMessage(BaseModel):
role: str role: str
content: str content: str
@ -81,47 +67,12 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None stop_sequences: Optional[List[str]] = None
def configure_server( def _create_engine(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
load_model(
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
device: str = "cuda", device: str = "cuda",
dtype: torch.dtype = torch.bfloat16, dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16, max_batch_size: int = 16,
): ) -> InferenceEngine:
if param_path is None: if param_path is None:
param_path = _project_root / "params" param_path = _project_root / "params"
if not param_path.exists(): if not param_path.exists():
@ -132,18 +83,38 @@ def load_model(
model.to(device=device, dtype=dtype) model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}") logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine( engine = InferenceEngine(
model=model, model=model,
tokenizer=tokenizer, tokenizer=tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
) )
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}") logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
def _get_engine() -> InferenceEngine: @asynccontextmanager
if _state.engine is None: async def lifespan(app: FastAPI):
config = app.state.server_config
if not config.get("_test", False):
try:
app.state.engine = _create_engine(**config)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if app.state.engine:
app.state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized") raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine return engine
def _make_chunk( def _make_chunk(
@ -155,7 +126,6 @@ def _make_chunk(
model: str, model: str,
index: int = 0, index: int = 0,
) -> str: ) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = { data = {
"id": resp_id, "id": resp_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@ -172,23 +142,56 @@ def _make_chunk(
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n" return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.get("/health") @app.get("/health")
async def health(): async def health(request: Request):
return { return {
"status": "ok", "status": "ok",
"model_loaded": _state.engine is not None, "model_loaded": request.app.state.engine is not None,
} }
@app.get("/stats") @app.get("/stats")
async def get_stats(): async def get_stats(request: Request):
return _get_engine().get_stats() return _get_engine(request).get_stats()
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest, req: Request):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming).""" engine = _get_engine(req)
engine = _get_engine()
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}" resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time()) created = int(time.time())
model = request.model model = request.model
@ -284,44 +287,9 @@ async def chat_completion(request: ChatCompletionRequest):
} }
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.post("/v1/messages") @app.post("/v1/messages")
async def create_message(request: MessagesRequest): async def create_message(request: MessagesRequest, req: Request):
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming).""" engine = _get_engine(req)
engine = _get_engine()
resp_id = f"msg_{uuid.uuid4().hex[:24]}" resp_id = f"msg_{uuid.uuid4().hex[:24]}"
model = request.model model = request.model
@ -472,12 +440,12 @@ def run_server(
param_path: Optional[Path] = None, param_path: Optional[Path] = None,
max_batch_size: int = 16, max_batch_size: int = 16,
): ):
configure_server( app.state.server_config = {
device=device, "device": device,
dtype=dtype, "dtype": dtype,
param_path=param_path, "param_path": param_path,
max_batch_size=max_batch_size, "max_batch_size": max_batch_size,
) }
uvicorn.run( uvicorn.run(
"astrai.inference.server:app", "astrai.inference.server:app",
host=host, host=host,

View File

@ -139,6 +139,7 @@ class TaskManager:
} }
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]: def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
with self._lock:
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
if task.status == TaskStatus.ABORTED: if task.status == TaskStatus.ABORTED:
@ -180,5 +181,14 @@ class TaskManager:
self._task_event.clear() self._task_event.clear()
self._task_event.wait(timeout=timeout) self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
with self._lock:
return list(self.active_tasks)
def clear_queues(self) -> None:
    """Empty ``waiting_queue`` and ``active_tasks`` in place while holding ``_lock``.

    Both containers are cleared via their own ``clear()`` method, so the
    container objects themselves are kept and only their contents are dropped.
    """
    with self._lock:
        # Same order as before: waiting queue first, then active tasks.
        for container in (self.waiting_queue, self.active_tasks):
            container.clear()
def wake(self) -> None: def wake(self) -> None:
self._task_event.set() self._task_event.set()

View File

@ -11,6 +11,14 @@ from astrai.inference.server import app
@pytest.fixture @pytest.fixture
def client(): def client():
"""Provide a test client for the FastAPI app.""" """Provide a test client for the FastAPI app."""
app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app) return TestClient(app)
@ -39,7 +47,7 @@ def mock_engine():
@pytest.fixture @pytest.fixture
def loaded_model(mock_engine, monkeypatch): def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded.""" """Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine) app.state.engine = mock_engine
return mock_engine return mock_engine

View File

@ -162,23 +162,20 @@ def test_prefix_cache_has_page():
def test_task_table_set_get(): def test_task_table_set_get():
pool = PagePool(n_pages=8) table = TaskTable(page_size=64)
table = TaskTable(pool, page_size=64)
table.set("task1", [0, 1, 2], 128) table.set("task1", [0, 1, 2], 128)
assert table.get("task1") == [0, 1, 2] assert table.get("task1") == [0, 1, 2]
assert table.get_cached("task1") == 128 assert table.get_cached("task1") == 128
def test_task_table_get_missing(): def test_task_table_get_missing():
pool = PagePool(n_pages=8) table = TaskTable(page_size=64)
table = TaskTable(pool, page_size=64)
assert table.get("nonexistent") == [] assert table.get("nonexistent") == []
assert table.get_cached("nonexistent") == 0 assert table.get_cached("nonexistent") == 0
def test_task_table_pop(): def test_task_table_pop():
pool = PagePool(n_pages=8) table = TaskTable(page_size=64)
table = TaskTable(pool, page_size=64)
table.set("task1", [0, 1], 64) table.set("task1", [0, 1], 64)
pages, cached = table.pop("task1") pages, cached = table.pop("task1")
assert pages == [0, 1] assert pages == [0, 1]
@ -186,26 +183,39 @@ def test_task_table_pop():
assert table.get("task1") == [] assert table.get("task1") == []
def test_task_table_extend_allocates_pages(): def test_paged_cache_task_extend_allocates():
pool = PagePool(n_pages=8) cache = PagedCache(
table = TaskTable(pool, page_size=64) n_layers=1,
table.set("task1", [], 0) n_pages=8,
ok = table.extend("task1", 200) page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [], 0)
ok = cache.task_extend("task1", 200)
assert ok assert ok
assert len(table.get("task1")) == 4 assert len(cache._table.get("task1")) == 4
def test_task_table_extend_fails_when_pool_full(): def test_paged_cache_task_extend_fails_when_pool_full():
pool = PagePool(n_pages=2) cache = PagedCache(
table = TaskTable(pool, page_size=64) n_layers=1,
table.set("task1", [pool.alloc(), pool.alloc()], 0) n_pages=2,
ok = table.extend("task1", 300) page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [0, 1], 0)
ok = cache.task_extend("task1", 300)
assert not ok assert not ok
def test_task_table_table_tensor(): def test_task_table_table_tensor():
pool = PagePool(n_pages=16) table = TaskTable(page_size=64)
table = TaskTable(pool, page_size=64)
table.set("a", [0, 1], 0) table.set("a", [0, 1], 0)
table.set("b", [2, 3, 4], 0) table.set("b", [2, 3, 4], 0)
t = table.table_tensor(["a", "b"], torch.device("cpu")) t = table.table_tensor(["a", "b"], torch.device("cpu"))
@ -215,8 +225,7 @@ def test_task_table_table_tensor():
def test_task_table_table_tensor_empty_input(): def test_task_table_table_tensor_empty_input():
pool = PagePool(n_pages=4) table = TaskTable(page_size=64)
table = TaskTable(pool, page_size=64)
t = table.table_tensor([], torch.device("cpu")) t = table.table_tensor([], torch.device("cpu"))
assert t.numel() == 0 assert t.numel() == 0

View File

@ -1,20 +1,20 @@
"""Unit tests for _Result accumulator and InferenceEngine.generate().""" """Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
import threading import threading
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from astrai.inference.engine import _Result from astrai.inference.engine import GenerateResult
from astrai.inference.task import STOP from astrai.inference.task import STOP
def test_result_append_single(): def test_result_append_single():
r = _Result(count=1) r = GenerateResult(count=1)
r.append("hello", 0) r.append("hello", 0)
assert r.results[0] == "hello" assert r.results[0] == "hello"
def test_result_append_multiple_tasks(): def test_result_append_multiple_tasks():
r = _Result(count=3) r = GenerateResult(count=3)
r.append("a", 0) r.append("a", 0)
r.append("b", 1) r.append("b", 1)
r.append("c", 2) r.append("c", 2)
@ -24,7 +24,7 @@ def test_result_append_multiple_tasks():
def test_result_stop_marks_complete(): def test_result_stop_marks_complete():
r = _Result(count=2) r = GenerateResult(count=2)
r.append("text", 0) r.append("text", 0)
r.append(STOP, 0) r.append(STOP, 0)
r.append("more", 1) r.append("more", 1)
@ -34,14 +34,14 @@ def test_result_stop_marks_complete():
def test_result_stop_does_not_double_count(): def test_result_stop_does_not_double_count():
r = _Result(count=1) r = GenerateResult(count=1)
r.append(STOP, 0) r.append(STOP, 0)
r.append(STOP, 0) r.append(STOP, 0)
assert r._completed == 1 assert r._completed == 1
def test_result_pop_all_returns_and_clears(): def test_result_pop_all_returns_and_clears():
r = _Result(count=2) r = GenerateResult(count=2)
r.append("a", 0) r.append("a", 0)
r.append("b", 1) r.append("b", 1)
out = r.pop_all() out = r.pop_all()
@ -52,7 +52,7 @@ def test_result_pop_all_returns_and_clears():
def test_result_wait_blocks_until_data(): def test_result_wait_blocks_until_data():
r = _Result(count=1) r = GenerateResult(count=1)
def delayed_append(): def delayed_append():
import time import time
@ -69,13 +69,13 @@ def test_result_wait_blocks_until_data():
def test_result_wait_timeout(): def test_result_wait_timeout():
r = _Result(count=1) r = GenerateResult(count=1)
ok = r.wait(timeout=0.01) ok = r.wait(timeout=0.01)
assert not ok assert not ok
def test_result_wait_completion_non_streaming(): def test_result_wait_completion_non_streaming():
r = _Result(count=2) r = GenerateResult(count=2)
def finish_later(): def finish_later():
import time import time
@ -93,7 +93,7 @@ def test_result_wait_completion_non_streaming():
def test_result_get_results(): def test_result_get_results():
r = _Result(count=2) r = GenerateResult(count=2)
r.append("hello", 0) r.append("hello", 0)
r.append("world", 1) r.append("world", 1)
results = r.get_results() results = r.get_results()
@ -148,9 +148,9 @@ def test_engine_generate_streaming_yields_tokens():
gen = eng.generate("hello", stream=True) gen = eng.generate("hello", stream=True)
cb = callbacks_saved[0] cb = callbacks_saved[0]
cb("t1", 0) cb("t1")
cb("t2", 0) cb("t2")
cb(STOP, 0) cb(STOP)
tokens = list(gen) tokens = list(gen)
assert tokens == ["t1", "t2"] assert tokens == ["t1", "t2"]

View File

@ -2,10 +2,12 @@
import pytest import pytest
from astrai.inference.server import app
def test_health_no_model(client, monkeypatch):
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded.""" """GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None) app.state.engine = None
response = client.get("/health") response = client.get("/health")
assert response.status_code == 200 assert response.status_code == 200
data = response.json() data = response.json()
@ -22,15 +24,14 @@ def test_health_with_model(client, loaded_model):
assert data["model_loaded"] is True assert data["model_loaded"] is True
def test_chat_completions_non_stream(client, loaded_model, monkeypatch): def test_chat_completions_non_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON.""" """POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen(): async def async_gen():
yield "Assistant reply" yield "Assistant reply"
mock_engine = loaded_model app.state.engine = loaded_model
mock_engine.generate_async.return_value = async_gen() loaded_model.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -48,16 +49,15 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
assert "prompt_tokens" in data["usage"] assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, monkeypatch): def test_chat_completions_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=true returns SSE stream.""" """POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen(): async def async_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
mock_engine = loaded_model app.state.engine = loaded_model
mock_engine.generate_async.return_value = async_gen() loaded_model.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/chat/completions", "/v1/chat/completions",
json={ json={
@ -77,15 +77,14 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
assert any("[DONE]" in line for line in lines) assert any("[DONE]" in line for line in lines)
def test_messages_non_stream(client, loaded_model, monkeypatch): def test_messages_non_stream(client, loaded_model):
"""POST /v1/messages with stream=false returns Anthropic-style JSON.""" """POST /v1/messages with stream=false returns Anthropic-style JSON."""
async def async_gen(): async def async_gen():
yield "Assistant reply" yield "Assistant reply"
mock_engine = loaded_model app.state.engine = loaded_model
mock_engine.generate_async.return_value = async_gen() loaded_model.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={
@ -105,16 +104,15 @@ def test_messages_non_stream(client, loaded_model, monkeypatch):
assert "input_tokens" in data["usage"] assert "input_tokens" in data["usage"]
def test_messages_stream(client, loaded_model, monkeypatch): def test_messages_stream(client, loaded_model):
"""POST /v1/messages with stream=true returns Anthropic SSE stream.""" """POST /v1/messages with stream=true returns Anthropic SSE stream."""
async def async_gen(): async def async_gen():
yield "cumulative1" yield "cumulative1"
yield "cumulative2" yield "cumulative2"
mock_engine = loaded_model app.state.engine = loaded_model
mock_engine.generate_async.return_value = async_gen() loaded_model.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={
@ -137,15 +135,14 @@ def test_messages_stream(client, loaded_model, monkeypatch):
assert "message_stop" in content assert "message_stop" in content
def test_messages_with_system(client, loaded_model, monkeypatch): def test_messages_with_system(client, loaded_model):
"""POST /v1/messages with system prompt.""" """POST /v1/messages with system prompt."""
async def async_gen(): async def async_gen():
yield "Reply" yield "Reply"
mock_engine = loaded_model app.state.engine = loaded_model
mock_engine.generate_async.return_value = async_gen() loaded_model.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
response = client.post( response = client.post(
"/v1/messages", "/v1/messages",
json={ json={