chore: 解耦 Executor/Scheduler/TaskManager,修复 stop 页泄漏,移除 ServerState 全局单例

This commit is contained in:
ViperEkura 2026-05-12 13:44:55 +08:00
parent 7440e9c809
commit df0845e916
10 changed files with 336 additions and 509 deletions

View File

@ -104,8 +104,7 @@ class PrefixCache:
class TaskTable:
"""Maps task_ids to page tables and cached token counts."""
def __init__(self, pool: PagePool, page_size: int):
self._pool = pool
def __init__(self, page_size: int):
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
@ -125,15 +124,8 @@ class TaskTable:
cached = self._cached.pop(task_id, 0)
return pages, cached
def extend(self, task_id: str, pos: int) -> bool:
page_table = self._pages[task_id]
needed = (pos + 1 + self._page_size - 1) // self._page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def get_ref(self, task_id: str) -> List[int]:
    """Return the page list stored for ``task_id``, registering an empty one if absent.

    The returned object is the list held in the table itself, so appending
    to it updates the task's page table in place.
    """
    try:
        return self._pages[task_id]
    except KeyError:
        created: List[int] = []
        self._pages[task_id] = created
        return created
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._pages.get(tid, []) for tid in task_ids]
@ -158,7 +150,7 @@ class PagedCache:
self.page_size = page_size
self._prefix = PrefixCache(page_size)
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(self._pool, page_size)
self._table = TaskTable(page_size)
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
@ -219,7 +211,14 @@ class PagedCache:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
    """Grow the page table of ``task_id`` until position ``pos`` is addressable.

    Args:
        task_id: Task whose page table is extended.
        pos: Zero-based token position that must be covered afterwards.

    Returns:
        True when the table holds enough pages; False as soon as the pool
        returns a negative page index. Pages appended before the failure
        remain in the task's table.
    """
    # get_ref() returns the list stored in the table (creating it for an
    # unknown task), so every page allocated below is recorded in the table
    # rather than appended to a detached default list.
    page_table = self._table.get_ref(task_id)
    # Ceiling of (pos + 1) / page_size: pages needed to cover 0..pos.
    needed = (pos + self.page_size) // self.page_size
    while len(page_table) < needed:
        page = self._pool.alloc()
        if page < 0:
            return False
        page_table.append(page)
    return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)

View File

@ -1,11 +1,4 @@
"""Unified inference engine for continuous batching.
Layers:
- GenerationParams: Immutable value object for sampling parameters.
- GenerationRequest: User-facing request DTO with validation.
- _Result: Thread-safe token accumulator (Observer pattern).
- InferenceEngine: Facade over InferenceScheduler + async wrapper.
"""
"""Unified inference engine for continuous batching."""
import asyncio
import gc
@ -21,6 +14,59 @@ from astrai.inference.task import STOP
from astrai.tokenize import AutoTokenizer
class GenerateResult:
    """Thread-safe token accumulator for streaming and non-streaming modes.

    Producers call :meth:`append`; streaming consumers alternate
    :meth:`wait` / :meth:`pop_all`; non-streaming consumers call
    :meth:`wait_completion` followed by :meth:`get_results`.
    """

    def __init__(self, count: int = 1):
        """Create an accumulator tracking ``count`` generation slots."""
        self._cond = threading.Condition()
        self._event = threading.Event()
        # Pending (slot index, token) pairs not yet drained by pop_all().
        self.tokens: List[Tuple[int, str]] = []
        # Concatenated text per slot.
        self.results: List[str] = [""] * count
        self._done: List[bool] = [False] * count
        self._completed = 0
        self._total = count

    def append(self, token: str, idx: int = 0):
        """Record ``token`` for slot ``idx``; the STOP sentinel marks the slot done."""
        with self._cond:
            self.tokens.append((idx, token))
            if token is not STOP:
                self.results[idx] += token
            elif not self._done[idx]:
                # Count each slot once even if STOP is delivered repeatedly.
                self._done[idx] = True
                self._completed += 1
                self._cond.notify_all()
            self._event.set()

    def pop_all(self) -> List[Tuple[int, str]]:
        """Return and clear every pending (idx, token) pair."""
        with self._cond:
            drained = self.tokens.copy()
            self.tokens.clear()
            # The buffer is now empty, so the "data available" flag is reset
            # unconditionally. append() sets it again under the same lock, so
            # no wakeup can be lost, and the consumer no longer gets one
            # spurious wakeup after every non-empty drain.
            self._event.clear()
            return drained

    def wait(self, timeout: Optional[float] = None) -> bool:
        """Block until data is flagged as available; False means the timeout expired."""
        return self._event.wait(timeout=timeout)

    def wait_completion(self, timeout: Optional[float] = None) -> bool:
        """Block until every slot has received STOP.

        Args:
            timeout: Maximum seconds to wait; None (the default) waits
                indefinitely, as before.

        Returns:
            True when all slots completed, False if the timeout expired first.
        """
        with self._cond:
            return self._cond.wait_for(
                lambda: self._completed >= self._total, timeout=timeout
            )

    def get_results(self) -> List[str]:
        """Return a copy of the accumulated text, one entry per slot."""
        with self._cond:
            return self.results.copy()
def _validate_params(top_k: int, top_p: float, temperature: float) -> None:
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
@ -32,11 +78,7 @@ class GenerationParams:
class GenerationRequest:
"""Request parameters for text generation.
Encapsulates messages, sampling parameters (via GenerationParams),
and streaming preference for a single generation request.
"""
"""Request parameters for text generation."""
def __init__(
self,
@ -47,16 +89,6 @@ class GenerationRequest:
max_len: int = 1024,
stream: bool = False,
):
"""Initializes a generation request.
Args:
messages: Conversation history as list of {"role": ..., "content": ...}.
top_k: Top-k sampling count (0 disables).
top_p: Nucleus sampling probability threshold.
temperature: Sampling temperature.
max_len: Maximum tokens to generate.
stream: Whether to return output as a token stream.
"""
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
@ -65,7 +97,7 @@ class GenerationRequest:
max_tokens=max_len,
)
self.stream = stream
self._validate()
_validate_params(top_k, top_p, temperature)
@property
def top_k(self) -> int:
@ -83,112 +115,9 @@ class GenerationRequest:
def max_len(self) -> int:
return self.params.max_tokens
def _validate(self):
"""Validates sampling parameter ranges."""
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class _Result:
"""Thread-safe token accumulator for streaming and non-streaming modes.
Supports multiple concurrent generation tasks with per-index result tracking.
Uses a threading.Condition for efficient completion notification
and a threading.Event for streaming wakeup.
"""
def __init__(self, count: int = 1):
"""Initializes the accumulator.
Args:
count: Number of concurrent generation tasks to track.
"""
self._cond = threading.Condition()
self._event = threading.Event()
self.tokens: List[str] = []
self.results: List[str] = [""] * count
self._done: List[bool] = [False] * count
self._completed = 0
self._total = count
def append(self, token: str, idx: int = 0):
"""Appends a token to the result buffer.
In non-streaming mode, tokens are concatenated into results[idx].
The sentinel STOP marks a task as complete.
Args:
token: The decoded token string, or STOP sentinel.
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append((idx, token))
if token is not STOP:
self.results[idx] += token
else:
if not self._done[idx]:
self._done[idx] = True
self._completed += 1
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
Returns:
List of (index, token_string) tuples since the last call.
"""
with self._cond:
out = self.tokens.copy()
self.tokens.clear()
if not out:
self._event.clear()
return out
def wait(self, timeout: Optional[float] = None) -> bool:
"""Blocks until new tokens arrive or the timeout expires.
Args:
timeout: Maximum wait time in seconds (None = infinite).
Returns:
True if the event was set (new data available), False on timeout.
"""
return self._event.wait(timeout=timeout)
def wait_completion(self) -> None:
"""Blocks until all tasks complete (non-streaming).
Uses a Condition to sleep efficiently instead of busy-waiting.
The calling thread is parked until a STOP signal arrives.
"""
with self._cond:
self._cond.wait_for(lambda: self._completed >= self._total)
def get_results(self) -> List[str]:
"""Returns all accumulated results for non-streaming mode.
Returns:
List of complete generated strings, one per task index.
"""
with self._cond:
return self.results.copy()
class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.
Usage:
with InferenceEngine(model, tokenizer) as engine:
for token in engine.generate("hello", stream=True):
print(token, end="")
text = engine.generate("hello")
"""
"""Unified inference engine backed by continuous-batching scheduler."""
def __init__(
self,
@ -199,17 +128,6 @@ class InferenceEngine:
max_prompt_len: int = 2048,
page_size: int = 128,
):
"""Initializes the inference engine.
Args:
model: The model instance.
tokenizer: The tokenizer instance.
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model
self.tokenizer = tokenizer
self.scheduler = InferenceScheduler(
@ -239,22 +157,8 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt.
_validate_params(top_k, top_p, temperature)
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
Returns:
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
@ -275,21 +179,6 @@ class InferenceEngine:
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
Runs the synchronous generator in a background thread pool executor,
yielding tokens to the async consumer as they arrive.
Args:
prompt: Input text to generate from.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
@ -306,14 +195,6 @@ class InferenceEngine:
@staticmethod
def _next_token(gen: Generator) -> Optional[str]:
"""Retrieves the next token from a synchronous generator.
Args:
gen: A synchronous generator yielding token strings.
Returns:
The next token, or None if the generator is exhausted.
"""
try:
return next(gen)
except StopIteration:
@ -322,16 +203,6 @@ class InferenceEngine:
def generate_with_request(
self, request: GenerationRequest
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a structured GenerationRequest.
Applies the chat template to the request's messages before generation.
Args:
request: A GenerationRequest with messages and parameters.
Returns:
Generator, string, or list of strings (see generate()).
"""
prompt = self.tokenizer.apply_chat_template(request.messages, tokenize=False)
return self.generate(
prompt=prompt,
@ -342,6 +213,37 @@ class InferenceEngine:
top_k=request.params.top_k,
)
def _submit_tasks(
    self,
    prompts: List[str],
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Tuple[GenerateResult, List[str]]:
    """Queue one scheduler task per prompt, all feeding a shared accumulator.

    Args:
        prompts: Prompt strings; prompt ``i`` reports into result slot ``i``.
        max_tokens: Forwarded to the scheduler for every task.
        temperature: Forwarded to the scheduler for every task.
        top_p: Forwarded to the scheduler for every task.
        top_k: Forwarded to the scheduler for every task.

    Returns:
        The shared GenerateResult and the task ids in prompt order.

    Raises:
        Exception: Whatever ``scheduler.add_task`` raises; tasks queued
            before the failure are removed first.
    """
    result = GenerateResult(count=len(prompts))
    task_ids: List[str] = []
    try:
        for slot, prompt in enumerate(prompts):
            task_ids.append(
                self.scheduler.add_task(
                    prompt=prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    stream_callback=self._make_callback(result, slot),
                )
            )
    except Exception:
        # The caller never receives the ids on failure, so it cannot clean
        # up; withdraw the tasks that were already queued before re-raising.
        for tid in task_ids:
            self.scheduler.remove_task(tid)
        raise
    return result, task_ids
@staticmethod
def _make_callback(result: GenerateResult, idx: int):
def cb(token):
result.append(token, idx)
return cb
def _generate_streaming(
self,
prompts: List[str],
@ -351,38 +253,10 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Generator:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args:
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Yields:
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
"""
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts)
result = _Result(count=n)
task_ids = []
for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
)
task_ids.append(task_id)
remaining = n
finished = [False] * n
@ -399,8 +273,7 @@ class InferenceEngine:
else:
yield (idx, token) if is_batch else token
if remaining > 0:
if not result.wait(timeout=0.05):
pass
result.wait(timeout=0.05)
finally:
for tid in task_ids:
self.scheduler.remove_task(tid)
@ -416,57 +289,22 @@ class InferenceEngine:
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
"""Internal non-streaming generator.
Submits all prompts to the scheduler and waits for all to complete.
Args:
prompts: List of prompt strings.
is_batch: Whether multiple prompts were provided.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
Returns:
Single string for one prompt, list of strings for batch.
"""
result = _Result(count=len(prompts))
task_ids = []
for i, p in enumerate(prompts):
def make_cb(idx):
return lambda tok: result.append(tok, idx)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
)
task_ids.append(task_id)
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
result.wait_completion()
for task_id in task_ids:
self.scheduler.remove_task(task_id)
for tid in task_ids:
self.scheduler.remove_task(tid)
res = result.get_results()
return res if is_batch else res[0]
def get_stats(self) -> Dict[str, Any]:
"""Returns current engine statistics.
Returns:
Dict with total_tasks, total_tokens, active_tasks, waiting_queue.
"""
return self.scheduler.get_stats()
def shutdown(self) -> None:
"""Shuts down the engine, stops the scheduler, and frees GPU memory."""
self.scheduler.stop()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -5,7 +5,7 @@ import torch
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
from astrai.inference.task import STOP, Task, TaskStatus
from astrai.inference.task import Task
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -60,31 +60,10 @@ class Executor:
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_cache.page_size
for t in tasks:
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
def execute_decode(self, tasks: List[Task], start_pos: int) -> List[int]:
if not tasks:
return
return []
tasks = sorted(tasks, key=lambda t: t.task_id)
valid: List[Task] = []
for t in tasks:
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if not valid:
return
tasks = valid
batch_sz = len(tasks)
input_ids = torch.tensor(
@ -112,22 +91,9 @@ class Executor:
)
logits = outputs["logits"][:, -1, :]
next_tokens = sample(
return sample(
logits,
temperature=temperatures,
top_k=top_ks,
top_p=top_ps,
).tolist()
for t, ntok in zip(tasks, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
for t in tasks:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)

View File

@ -6,7 +6,7 @@ import torch
from astrai.inference.cache import PagedCache
from astrai.inference.executor import Executor
from astrai.inference.task import STOP, Task, TaskManager
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
from astrai.tokenize.tokenizer import AutoTokenizer
@ -75,17 +75,15 @@ class InferenceScheduler:
return self._task_mgr.get_stats()
def _run_generation_loop(self) -> None:
stop_ids = self._task_mgr.tokenizer.stop_ids
try:
while self._running:
finished = self._task_mgr.remove_finished_tasks(
self._task_mgr.tokenizer.stop_ids
)
finished = self._task_mgr.remove_finished_tasks(stop_ids)
for task in finished:
self._page_cache.task_free(task.task_id)
available = self._task_mgr.max_batch_size - len(
self._task_mgr.active_tasks
)
active = self._task_mgr.get_active_tasks()
available = self._task_mgr.max_batch_size - len(active)
if available > 0:
candidates = self._task_mgr.pull_candidates(available)
failed = []
@ -102,7 +100,7 @@ class InferenceScheduler:
continue
to_prefill = [
t for t in self._task_mgr.active_tasks if t.output_tokens == 0
t for t in self._task_mgr.get_active_tasks() if t.output_tokens == 0
]
if to_prefill:
for t in to_prefill:
@ -118,23 +116,56 @@ class InferenceScheduler:
for (prompt_len, start_pos), group in groups.items():
self._executor.execute_prefill(group, prompt_len, start_pos)
start_logical_page = start_pos // self._page_cache.page_size
for t in group:
self._page_cache.task_record_hashes(
t.task_id,
t.prompt_ids,
start_logical_page=start_logical_page,
)
pos_groups: Dict[int, List[Task]] = {}
for t in self._task_mgr.active_tasks:
for t in self._task_mgr.get_active_tasks():
pos_groups.setdefault(t.next_pos, []).append(t)
if pos_groups:
best_pos = max(pos_groups, key=lambda p: len(pos_groups[p]))
self._executor.execute_decode(pos_groups[best_pos], best_pos)
group = sorted(pos_groups[best_pos], key=lambda t: t.task_id)
valid: List[Task] = []
for t in group:
if self._page_cache.task_extend(t.task_id, best_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
if valid:
next_tokens = self._executor.execute_decode(valid, best_pos)
for t, ntok in zip(valid, next_tokens):
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(
self._task_mgr.tokenizer.decode([ntok])
)
for t in valid:
if t.is_finished(stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
except Exception as e:
logger.error(f"Scheduler loop crashed: {e}", exc_info=True)
for task in self._task_mgr.active_tasks:
if task.stream_callback:
task.stream_callback(STOP)
for task in self._task_mgr.waiting_queue:
for task in self._task_mgr.get_active_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
raise
def start(self) -> None:
@ -149,7 +180,8 @@ class InferenceScheduler:
self._task_mgr.wake()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0)
self._task_mgr.waiting_queue.clear()
self._task_mgr.active_tasks.clear()
for task in self._task_mgr.get_active_tasks():
self._page_cache.task_free(task.task_id)
self._task_mgr.clear_queues()
if torch.cuda.is_available():
torch.cuda.empty_cache()

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
@ -25,20 +25,6 @@ logger = logging.getLogger(__name__)
_project_root = Path(__file__).parent.parent.parent
class ServerState:
def __init__(self):
self.engine: Optional[InferenceEngine] = None
self.config: Dict[str, Any] = {
"device": "cuda",
"dtype": torch.bfloat16,
"param_path": None,
"max_batch_size": 16,
}
_state = ServerState()
class ChatMessage(BaseModel):
role: str
content: str
@ -81,47 +67,12 @@ class MessagesRequest(BaseModel):
stop_sequences: Optional[List[str]] = None
def configure_server(
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
_state.config.update(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
load_model(
param_path=_state.config["param_path"],
device=_state.config["device"],
dtype=_state.config["dtype"],
max_batch_size=_state.config["max_batch_size"],
)
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise
yield
if _state.engine:
_state.engine.shutdown()
logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def load_model(
def _create_engine(
param_path: Optional[Path] = None,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
max_batch_size: int = 16,
):
) -> InferenceEngine:
if param_path is None:
param_path = _project_root / "params"
if not param_path.exists():
@ -132,18 +83,38 @@ def load_model(
model.to(device=device, dtype=dtype)
logger.info(f"Model loaded on {device} with dtype {dtype}")
_state.engine = InferenceEngine(
engine = InferenceEngine(
model=model,
tokenizer=tokenizer,
max_batch_size=max_batch_size,
)
logger.info(f"Inference engine initialized with max_batch_size={max_batch_size}")
return engine
def _get_engine() -> InferenceEngine:
if _state.engine is None:
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the inference engine on startup and shut it down on exit.

    Reads ``app.state.server_config`` (a dict of ``_create_engine`` keyword
    arguments). A truthy ``"_test"`` entry skips engine creation. When no
    config was stored on the app, ``_create_engine`` is called with its own
    defaults.
    """
    # Copy so the stored config is not mutated, and tolerate an app whose
    # state was never populated (e.g. the module loaded without run_server).
    config = dict(getattr(app.state, "server_config", None) or {})
    # "_test" is a lifespan-only switch; it must never reach _create_engine,
    # which does not accept it.
    skip_engine = config.pop("_test", False)
    if not skip_engine:
        try:
            app.state.engine = _create_engine(**config)
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise
    elif not hasattr(app.state, "engine"):
        # Handlers read app.state.engine unconditionally; make sure it exists
        # without overwriting an engine that was already installed.
        app.state.engine = None
    yield
    engine = getattr(app.state, "engine", None)
    if engine:
        engine.shutdown()
        logger.info("Inference engine shutdown complete")
app = FastAPI(title="AstrAI Inference Server", version="0.2.0", lifespan=lifespan)
def _get_engine(request: Request) -> InferenceEngine:
engine = request.app.state.engine
if engine is None:
raise HTTPException(status_code=503, detail="Engine not initialized")
return _state.engine
return engine
def _make_chunk(
@ -155,7 +126,6 @@ def _make_chunk(
model: str,
index: int = 0,
) -> str:
"""Build a single SSE ``data:`` chunk matching OpenAI streaming format."""
data = {
"id": resp_id,
"object": "chat.completion.chunk",
@ -172,23 +142,56 @@ def _make_chunk(
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
    messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
    """Convert request messages into a list of ``{"role", "content"}`` dicts.

    A non-empty ``system`` string becomes a leading ``"system"`` entry.
    Messages whose extracted text is empty are left out.
    """
    converted: List[Dict[str, str]] = (
        [{"role": "system", "content": system}] if system else []
    )
    for message in messages:
        text = _extract_text_content(message.content)
        if not text:
            continue
        converted.append({"role": message.role, "content": text})
    return converted
@app.get("/health")
async def health():
async def health(request: Request):
return {
"status": "ok",
"model_loaded": _state.engine is not None,
"model_loaded": request.app.state.engine is not None,
}
@app.get("/stats")
async def get_stats():
return _get_engine().get_stats()
async def get_stats(request: Request):
return _get_engine(request).get_stats()
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
engine = _get_engine()
async def chat_completion(request: ChatCompletionRequest, req: Request):
engine = _get_engine(req)
resp_id = f"chatcmpl-{uuid.uuid4().hex[:12]}"
created = int(time.time())
model = request.model
@ -284,44 +287,9 @@ async def chat_completion(request: ChatCompletionRequest):
}
def _make_anthropic_sse(event: str, data: Dict[str, Any]) -> str:
return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"
def _check_stop_sequence(text: str, stop_sequences: List[str]) -> Optional[str]:
for seq in stop_sequences:
if seq and seq in text:
return seq
return None
def _extract_text_content(content: Union[str, List[Dict[str, Any]]]) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
return block.get("text", "")
return ""
def _build_anthropic_messages(
messages: List[AnthropicMessage], system: Optional[str]
) -> List[Dict[str, str]]:
result: List[Dict[str, str]] = []
if system:
result.append({"role": "system", "content": system})
for m in messages:
content = _extract_text_content(m.content)
if content:
result.append({"role": m.role, "content": content})
return result
@app.post("/v1/messages")
async def create_message(request: MessagesRequest):
"""Anthropic-compatible Messages API endpoint (streaming + non-streaming)."""
engine = _get_engine()
async def create_message(request: MessagesRequest, req: Request):
engine = _get_engine(req)
resp_id = f"msg_{uuid.uuid4().hex[:24]}"
model = request.model
@ -472,12 +440,12 @@ def run_server(
param_path: Optional[Path] = None,
max_batch_size: int = 16,
):
configure_server(
device=device,
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
)
app.state.server_config = {
"device": device,
"dtype": dtype,
"param_path": param_path,
"max_batch_size": max_batch_size,
}
uvicorn.run(
"astrai.inference.server:app",
host=host,

View File

@ -139,23 +139,24 @@ class TaskManager:
}
def remove_finished_tasks(self, stop_ids: List[int]) -> List[Task]:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
with self._lock:
finished = []
for task in self.active_tasks:
if task.status == TaskStatus.ABORTED:
task.finish_time = time.time()
finished.append(task)
elif task.is_finished(stop_ids):
task.status = TaskStatus.FINISHED
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
self.active_tasks = [
t
for t in self.active_tasks
if t.status not in (TaskStatus.FINISHED, TaskStatus.ABORTED)
]
return finished
def pull_candidates(self, n: int) -> List[Task]:
to_add: List[Task] = []
@ -180,5 +181,14 @@ class TaskManager:
self._task_event.clear()
self._task_event.wait(timeout=timeout)
def get_active_tasks(self) -> List[Task]:
    """Return a snapshot copy of the active task list, taken under the manager lock."""
    with self._lock:
        snapshot = [*self.active_tasks]
    return snapshot
def clear_queues(self) -> None:
    """Drop every waiting and active task.

    Tasks still in the waiting queue are sent the STOP sentinel through
    their ``stream_callback`` before being discarded: once the queue is
    emptied nothing else will ever signal them, and a consumer waiting on
    such a task would otherwise block forever. Active tasks are only
    removed; signalling them is left to the caller.
    """
    with self._lock:
        orphaned = list(self.waiting_queue)
        self.waiting_queue.clear()
        self.active_tasks.clear()
    # Callbacks are invoked after the lock is released so consumer code
    # never runs while the manager lock is held.
    for task in orphaned:
        if task.stream_callback:
            task.stream_callback(STOP)
def wake(self) -> None:
    """Set the task event so a thread waiting on it stops waiting."""
    signal = self._task_event
    signal.set()

View File

@ -11,6 +11,14 @@ from astrai.inference.server import app
@pytest.fixture
def client():
"""Provide a test client for the FastAPI app."""
app.state.server_config = {
"device": "cpu",
"dtype": "bfloat16",
"param_path": None,
"max_batch_size": 1,
"_test": True,
}
app.state.engine = None
return TestClient(app)
@ -39,7 +47,7 @@ def mock_engine():
@pytest.fixture
def loaded_model(mock_engine, monkeypatch):
def loaded_model(client, mock_engine):
"""Simulate that the engine is loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = mock_engine
return mock_engine

View File

@ -162,23 +162,20 @@ def test_prefix_cache_has_page():
def test_task_table_set_get():
pool = PagePool(n_pages=8)
table = TaskTable(pool, page_size=64)
table = TaskTable(page_size=64)
table.set("task1", [0, 1, 2], 128)
assert table.get("task1") == [0, 1, 2]
assert table.get_cached("task1") == 128
def test_task_table_get_missing():
pool = PagePool(n_pages=8)
table = TaskTable(pool, page_size=64)
table = TaskTable(page_size=64)
assert table.get("nonexistent") == []
assert table.get_cached("nonexistent") == 0
def test_task_table_pop():
pool = PagePool(n_pages=8)
table = TaskTable(pool, page_size=64)
table = TaskTable(page_size=64)
table.set("task1", [0, 1], 64)
pages, cached = table.pop("task1")
assert pages == [0, 1]
@ -186,26 +183,39 @@ def test_task_table_pop():
assert table.get("task1") == []
def test_task_table_extend_allocates_pages():
pool = PagePool(n_pages=8)
table = TaskTable(pool, page_size=64)
table.set("task1", [], 0)
ok = table.extend("task1", 200)
def test_paged_cache_task_extend_allocates():
cache = PagedCache(
n_layers=1,
n_pages=8,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [], 0)
ok = cache.task_extend("task1", 200)
assert ok
assert len(table.get("task1")) == 4
assert len(cache._table.get("task1")) == 4
def test_task_table_extend_fails_when_pool_full():
pool = PagePool(n_pages=2)
table = TaskTable(pool, page_size=64)
table.set("task1", [pool.alloc(), pool.alloc()], 0)
ok = table.extend("task1", 300)
def test_paged_cache_task_extend_fails_when_pool_full():
cache = PagedCache(
n_layers=1,
n_pages=2,
page_size=64,
n_kv_heads=2,
head_dim=8,
device=torch.device("cpu"),
dtype=torch.float32,
)
cache._table.set("task1", [0, 1], 0)
ok = cache.task_extend("task1", 300)
assert not ok
def test_task_table_table_tensor():
pool = PagePool(n_pages=16)
table = TaskTable(pool, page_size=64)
table = TaskTable(page_size=64)
table.set("a", [0, 1], 0)
table.set("b", [2, 3, 4], 0)
t = table.table_tensor(["a", "b"], torch.device("cpu"))
@ -215,8 +225,7 @@ def test_task_table_table_tensor():
def test_task_table_table_tensor_empty_input():
pool = PagePool(n_pages=4)
table = TaskTable(pool, page_size=64)
table = TaskTable(page_size=64)
t = table.table_tensor([], torch.device("cpu"))
assert t.numel() == 0

View File

@ -1,20 +1,20 @@
"""Unit tests for _Result accumulator and InferenceEngine.generate()."""
"""Unit tests for GenerateResult accumulator and InferenceEngine.generate()."""
import threading
from unittest.mock import MagicMock, patch
from astrai.inference.engine import _Result
from astrai.inference.engine import GenerateResult
from astrai.inference.task import STOP
def test_result_append_single():
r = _Result(count=1)
r = GenerateResult(count=1)
r.append("hello", 0)
assert r.results[0] == "hello"
def test_result_append_multiple_tasks():
r = _Result(count=3)
r = GenerateResult(count=3)
r.append("a", 0)
r.append("b", 1)
r.append("c", 2)
@ -24,7 +24,7 @@ def test_result_append_multiple_tasks():
def test_result_stop_marks_complete():
r = _Result(count=2)
r = GenerateResult(count=2)
r.append("text", 0)
r.append(STOP, 0)
r.append("more", 1)
@ -34,14 +34,14 @@ def test_result_stop_marks_complete():
def test_result_stop_does_not_double_count():
r = _Result(count=1)
r = GenerateResult(count=1)
r.append(STOP, 0)
r.append(STOP, 0)
assert r._completed == 1
def test_result_pop_all_returns_and_clears():
r = _Result(count=2)
r = GenerateResult(count=2)
r.append("a", 0)
r.append("b", 1)
out = r.pop_all()
@ -52,7 +52,7 @@ def test_result_pop_all_returns_and_clears():
def test_result_wait_blocks_until_data():
r = _Result(count=1)
r = GenerateResult(count=1)
def delayed_append():
import time
@ -69,13 +69,13 @@ def test_result_wait_blocks_until_data():
def test_result_wait_timeout():
r = _Result(count=1)
r = GenerateResult(count=1)
ok = r.wait(timeout=0.01)
assert not ok
def test_result_wait_completion_non_streaming():
r = _Result(count=2)
r = GenerateResult(count=2)
def finish_later():
import time
@ -93,7 +93,7 @@ def test_result_wait_completion_non_streaming():
def test_result_get_results():
r = _Result(count=2)
r = GenerateResult(count=2)
r.append("hello", 0)
r.append("world", 1)
results = r.get_results()
@ -148,9 +148,9 @@ def test_engine_generate_streaming_yields_tokens():
gen = eng.generate("hello", stream=True)
cb = callbacks_saved[0]
cb("t1", 0)
cb("t2", 0)
cb(STOP, 0)
cb("t1")
cb("t2")
cb(STOP)
tokens = list(gen)
assert tokens == ["t1", "t2"]

View File

@ -2,10 +2,12 @@
import pytest
from astrai.inference.server import app
def test_health_no_model(client, monkeypatch):
def test_health_no_model(client):
"""GET /health should return 200 even when engine not loaded."""
monkeypatch.setattr("astrai.inference.server._state.engine", None)
app.state.engine = None
response = client.get("/health")
assert response.status_code == 200
data = response.json()
@ -22,15 +24,14 @@ def test_health_with_model(client, loaded_model):
assert data["model_loaded"] is True
def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
def test_chat_completions_non_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=false returns OpenAI-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
@ -48,16 +49,15 @@ def test_chat_completions_non_stream(client, loaded_model, monkeypatch):
assert "prompt_tokens" in data["usage"]
def test_chat_completions_stream(client, loaded_model, monkeypatch):
def test_chat_completions_stream(client, loaded_model):
"""POST /v1/chat/completions with stream=true returns SSE stream."""
async def async_gen():
yield "cumulative1"
yield "cumulative2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/chat/completions",
json={
@ -77,15 +77,14 @@ def test_chat_completions_stream(client, loaded_model, monkeypatch):
assert any("[DONE]" in line for line in lines)
def test_messages_non_stream(client, loaded_model, monkeypatch):
def test_messages_non_stream(client, loaded_model):
"""POST /v1/messages with stream=false returns Anthropic-style JSON."""
async def async_gen():
yield "Assistant reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={
@ -105,16 +104,15 @@ def test_messages_non_stream(client, loaded_model, monkeypatch):
assert "input_tokens" in data["usage"]
def test_messages_stream(client, loaded_model, monkeypatch):
def test_messages_stream(client, loaded_model):
"""POST /v1/messages with stream=true returns Anthropic SSE stream."""
async def async_gen():
yield "cumulative1"
yield "cumulative2"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={
@ -137,15 +135,14 @@ def test_messages_stream(client, loaded_model, monkeypatch):
assert "message_stop" in content
def test_messages_with_system(client, loaded_model, monkeypatch):
def test_messages_with_system(client, loaded_model):
"""POST /v1/messages with system prompt."""
async def async_gen():
yield "Reply"
mock_engine = loaded_model
mock_engine.generate_async.return_value = async_gen()
monkeypatch.setattr("astrai.inference.server._state.engine", mock_engine)
app.state.engine = loaded_model
loaded_model.generate_async.return_value = async_gen()
response = client.post(
"/v1/messages",
json={