Compare commits

..

3 Commits

Author SHA1 Message Date
ViperEkura a3bde30fb1 feat: 服务化基础设施 - 有界队列/超时/优雅关闭/metrics
- astrai/inference/scheduler.py: 有界队列 (max_queue_size) 拒绝满时入队抛 RuntimeError
    -> 请求超时检测 (deadline + _abort_expired_tasks),超时任务 abort 释放页并通知回调
    -> stop() 改为 drain 模式:等待活跃任务自然结束再强制清理
    -> get_stats() 扩展 latency P50/P95/P99 + cache hit rate
- astrai/inference/engine.py: generate/generate_async 新增 timeout 参数
    -> _generate_streaming/_generate_non_streaming 捕获 add_task 异常并清理
- astrai/inference/server.py: 新增 /metrics 端点 (Prometheus 格式)
    -> chat completions 端点捕获 RuntimeError 返回 503
    -> configure_server 传递 max_queue_size/request_timeout
- astrai/inference/cache.py: 新增 lookup_hits/lookup_misses 计数器
- tests/: fix stats key total_tasks -> total_requests
2026-05-10 18:16:51 +08:00
ViperEkura 3da428e0e4 perf: PagedCache 持久前缀缓存 + LRU 逐出
- astrai/inference/cache.py: refcount 归零时保留 hash 映射,页加入 LRU evictable 池
- alloc() 无空闲页时从 LRU 逐出,优先释放 _free_mask
- lookup_prefix/inc_ref 触发 _touch 更新 LRU 序
- record_page 设置 pin 标记并从 LRU 移除
2026-05-10 18:05:11 +08:00
ViperEkura 133a9de98f feat: _generate_streaming 支持 batch 模式
- _Result.append 存储 (idx, token) 元组,pop_all 返回对应列表
- 单 prompt: Generator[str](向后兼容)
- 多 prompt: Generator[Tuple[int, str]],token 交错到达,调用方自行分流
- 不使用 dispatch 线程 / Queue,避免同步开销和内存积压
2026-05-10 17:42:20 +08:00
5 changed files with 299 additions and 90 deletions

View File

@ -22,14 +22,16 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
class PagedCache: class PagedCache:
"""Paged KV cache with page-table-indirected read/write. """Paged KV cache with page-table-indirected read/write and persistent prefix caching.
Combines: Combines:
- Page pool (ref-counted alloc/free via bitmask) - Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache) - KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx) - Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Call :meth:`bind` to obtain a batch view for the attention layers. Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
""" """
def __init__( def __init__(
@ -57,6 +59,26 @@ class PagedCache:
) )
self._page_to_hash: Dict[int, int] = {} self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages
self.lookup_hits: int = 0
self.lookup_misses: int = 0
def _touch(self, idx: int) -> None:
    """Refresh the LRU position of an evictable page.

    A page is only affected when its refcount is 0 and it currently sits in
    ``self._lru``; it is then moved to the tail of that list. Pages that are
    referenced, or that are not in the LRU list, are left untouched.

    Args:
        idx: Physical page index.
    """
    if self._refs[idx] != 0:
        return
    try:
        self._lru.remove(idx)
    except ValueError:
        # Not in the evictable pool, so there is no LRU position to refresh.
        return
    self._lru.append(idx)
def _evict_one(self) -> int:
    """Reclaim the least recently used evictable page for a new owner.

    Pops the head of ``self._lru``, forgets the page's recorded prefix hash,
    clears its pin flag and hands the page out with a refcount of 1.

    Returns:
        The physical index of the reclaimed page, or -1 when the LRU list is
        empty and nothing can be evicted.
    """
    if not self._lru:
        return -1
    idx = self._lru.pop(0)
    h = self._page_to_hash.pop(idx, None)
    # Drop the reverse mapping only while it still points at this page. The
    # same hash may have been recorded for another page since; removing the
    # entry in that case would make the other, still valid page unreachable
    # through prefix lookup.
    if h is not None and self._hash_to_page.get(h) == idx:
        del self._hash_to_page[h]
    self._pin[idx] = False
    self._refs[idx] = 1
    return idx
def record_page( def record_page(
self, page_idx: int, token_ids: List[int], logical_page_idx: int self, page_idx: int, token_ids: List[int], logical_page_idx: int
@ -67,6 +89,9 @@ class PagedCache:
self._hash_to_page.pop(old_h, None) self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx self._hash_to_page[h] = page_idx
self._pin[page_idx] = True
if page_idx in self._lru:
self._lru.remove(page_idx)
def lookup_prefix(self, token_ids: List[int]) -> List[int]: def lookup_prefix(self, token_ids: List[int]) -> List[int]:
full_pages = len(token_ids) // self.page_size full_pages = len(token_ids) // self.page_size
@ -75,21 +100,28 @@ class PagedCache:
h = page_hash(token_ids, i, self.page_size) h = page_hash(token_ids, i, self.page_size)
p = self._hash_to_page.get(h) p = self._hash_to_page.get(h)
if p is None: if p is None:
self.lookup_misses += 1
break break
self.lookup_hits += 1
self._touch(p)
hits.append(p) hits.append(p)
return hits return hits
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1 self._refs[idx] += 1
if self._refs[idx] == 1 and idx in self._lru:
self._lru.remove(idx)
def alloc(self) -> int: def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask if self._free_mask:
if lsb == 0: lsb = self._free_mask & -self._free_mask
return -1 idx = lsb.bit_length() - 1
idx = lsb.bit_length() - 1 self._free_mask ^= lsb
self._free_mask ^= lsb self._refs[idx] = 1
self._refs[idx] = 1 if idx in self._lru:
return idx self._lru.remove(idx)
return idx
return self._evict_one()
def alloc_n(self, n: int) -> List[int]: def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)] pages = [self.alloc() for _ in range(n)]
@ -103,10 +135,15 @@ class PagedCache:
def free(self, idx: int) -> None: def free(self, idx: int) -> None:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
self._free_mask |= 1 << idx h = self._page_to_hash.get(idx)
h = self._page_to_hash.pop(idx, None) if h is not None and self._pin[idx]:
if h is not None: self._lru.append(idx)
self._hash_to_page.pop(h, None) else:
self._free_mask |= 1 << idx
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
self._pin[idx] = False
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len) return CacheView(self, page_table, total_len)

View File

@ -11,7 +11,7 @@ import asyncio
import gc import gc
import threading import threading
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -126,7 +126,7 @@ class _Result:
idx: Index of the generation task this token belongs to. idx: Index of the generation task this token belongs to.
""" """
with self._cond: with self._cond:
self.tokens.append(token) self.tokens.append((idx, token))
if token is not STOP: if token is not STOP:
self.results[idx] += token self.results[idx] += token
else: else:
@ -136,11 +136,11 @@ class _Result:
self._cond.notify_all() self._cond.notify_all()
self._event.set() self._event.set()
def pop_all(self) -> List[str]: def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated tokens. """Returns and clears all accumulated (idx, token) pairs.
Returns: Returns:
List of token strings since the last call. List of (index, token_string) tuples since the last call.
""" """
with self._cond: with self._cond:
out = self.tokens.copy() out = self.tokens.copy()
@ -195,6 +195,8 @@ class InferenceEngine:
model: nn.Module, model: nn.Module,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 1, max_batch_size: int = 1,
max_queue_size: int = 64,
request_timeout: float = 60.0,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048, max_prompt_len: int = 2048,
page_size: int = 128, page_size: int = 128,
@ -207,7 +209,6 @@ class InferenceEngine:
max_batch_size: Maximum number of concurrent tasks. max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length. max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens. max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page. page_size: Number of tokens per KV cache page.
""" """
self.model = model self.model = model
@ -216,6 +217,8 @@ class InferenceEngine:
model=self.model, model=self.model,
tokenizer=self.tokenizer, tokenizer=self.tokenizer,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_queue_size=max_queue_size,
request_timeout=request_timeout,
max_seq_len=max_seq_len, max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len, max_prompt_len=max_prompt_len,
page_size=page_size, page_size=page_size,
@ -238,31 +241,35 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator[str, None, None], str, List[str]]: timeout: Optional[float] = None,
) -> Union[Generator, str, List[str]]:
"""Generates text from a prompt. """Generates text from a prompt.
Args: Args:
prompt: Single string or list of strings for batch generation. prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens one by one. stream: If True, returns a generator yielding tokens.
max_tokens: Maximum number of tokens to generate. max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold. top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables). top_k: Top-k sampling count (0 disables).
timeout: Per-request timeout in seconds (None = use scheduler default).
Returns: Returns:
Generator (stream=True), single string (non-stream, single prompt), stream=False, single prompt: str
or list of strings (non-stream, batch prompts). stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
""" """
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming( return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
) )
else: else:
return self._generate_non_streaming( return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
) )
def generate_async( def generate_async(
@ -272,6 +279,7 @@ class InferenceEngine:
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
timeout: Optional[float] = None,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop. """Async streaming generator that does not block the event loop.
@ -284,12 +292,13 @@ class InferenceEngine:
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling threshold. top_p: Nucleus sampling threshold.
top_k: Top-k sampling count. top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Yields: Yields:
Decoded token strings as they are generated. Decoded token strings as they are generated.
""" """
sync_gen = self._generate_streaming( sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k [prompt], False, max_tokens, temperature, top_p, top_k, timeout
) )
async def _agen(): async def _agen():
@ -348,49 +357,68 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
) -> Generator[str, None, None]: timeout: Optional[float] = None,
) -> Generator:
"""Internal streaming generator. """Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive. Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Cleans up the scheduler task on GeneratorExit. Single prompt yields raw token strings; batch yields (idx, token) tuples.
Args: Args:
prompts: List of prompts (only first is used; batch not yet supported). prompts: List of prompts.
is_batch: If True, raises NotImplementedError. is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
max_tokens: Maximum tokens to generate. max_tokens: Maximum tokens to generate.
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling threshold. top_p: Nucleus sampling threshold.
top_k: Top-k sampling count. top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Yields: Yields:
Decoded token strings. Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
""" """
if is_batch: n = len(prompts)
raise NotImplementedError("Batch streaming not yet supported") result = _Result(count=n)
task_ids = []
result = _Result() try:
for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
timeout=timeout,
)
task_ids.append(task_id)
except RuntimeError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
task_id = self.scheduler.add_task( remaining = n
prompt=prompts[0], finished = [False] * n
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok: result.append(tok, 0),
)
def gen(): def gen():
nonlocal remaining
try: try:
while True: while remaining > 0:
tokens = result.pop_all() items = result.pop_all()
for token in tokens: for idx, token in items:
if token is STOP: if token is STOP:
return if not finished[idx]:
yield token finished[idx] = True
if not result.wait(timeout=0.05): remaining -= 1
pass else:
yield (idx, token) if is_batch else token
if remaining > 0:
if not result.wait(timeout=0.05):
pass
finally: finally:
self.scheduler.remove_task(task_id) for tid in task_ids:
self.scheduler.remove_task(tid)
return gen() return gen()
@ -402,6 +430,7 @@ class InferenceEngine:
temperature: float, temperature: float,
top_p: float, top_p: float,
top_k: int, top_k: int,
timeout: Optional[float] = None,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
"""Internal non-streaming generator. """Internal non-streaming generator.
@ -414,6 +443,7 @@ class InferenceEngine:
temperature: Sampling temperature. temperature: Sampling temperature.
top_p: Nucleus sampling threshold. top_p: Nucleus sampling threshold.
top_k: Top-k sampling count. top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Returns: Returns:
Single string for one prompt, list of strings for batch. Single string for one prompt, list of strings for batch.
@ -421,20 +451,26 @@ class InferenceEngine:
result = _Result(count=len(prompts)) result = _Result(count=len(prompts))
task_ids = [] task_ids = []
for i, p in enumerate(prompts): try:
for i, p in enumerate(prompts):
def make_cb(idx): def make_cb(idx):
return lambda tok: result.append(tok, idx) return lambda tok: result.append(tok, idx)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=p, prompt=p,
max_tokens=max_tokens, max_tokens=max_tokens,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
stream_callback=make_cb(i), stream_callback=make_cb(i),
) timeout=timeout,
task_ids.append(task_id) )
task_ids.append(task_id)
except RuntimeError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
result.wait_completion() result.wait_completion()

View File

@ -55,6 +55,7 @@ class Task:
self.n_pages: int = 0 self.n_pages: int = 0
self._prefix_cached_tokens: int = 0 self._prefix_cached_tokens: int = 0
self.arrival_time = time.time() self.arrival_time = time.time()
self.deadline: float = 0.0
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
self._pages_freed: bool = False self._pages_freed: bool = False
@ -86,6 +87,8 @@ class InferenceScheduler:
model: AutoModel, model: AutoModel,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
max_batch_size: int = 16, max_batch_size: int = 16,
max_queue_size: int = 64,
request_timeout: float = 60.0,
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 512, max_prompt_len: int = 512,
page_size: int = 64, page_size: int = 64,
@ -97,6 +100,8 @@ class InferenceScheduler:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.max_batch_size = max_batch_size self.max_batch_size = max_batch_size
self.max_queue_size = max_queue_size
self.request_timeout = request_timeout
self.max_seq_len = max_seq_len or config.max_len self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len self.max_prompt_len = max_prompt_len
self.page_size = page_size self.page_size = page_size
@ -124,11 +129,16 @@ class InferenceScheduler:
self.active_tasks: List[Task] = [] self.active_tasks: List[Task] = []
self._running = False self._running = False
self._draining = False
self._task_event = threading.Event() self._task_event = threading.Event()
self._lock = threading.Lock() self._lock = threading.Lock()
self._total_tasks = 0 self._total_tasks = 0
self._total_tokens = 0 self._total_tokens = 0
self._total_requests = 0
self._total_rejected = 0
self._total_timeouts = 0
self._request_latencies: List[float] = []
def _n_pages_for(self, n_tokens: int) -> int: def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size return (n_tokens + self.page_size - 1) // self.page_size
@ -141,6 +151,7 @@ class InferenceScheduler:
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None, stream_callback: Optional[Callable[[str], None]] = None,
timeout: Optional[float] = None,
) -> str: ) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt) prompt_ids = self.tokenizer.encode(prompt)
@ -156,9 +167,16 @@ class InferenceScheduler:
top_k=top_k, top_k=top_k,
stream_callback=stream_callback, stream_callback=stream_callback,
) )
task.deadline = time.time() + (
timeout if timeout is not None else self.request_timeout
)
with self._lock: with self._lock:
if len(self.waiting_queue) >= self.max_queue_size:
self._total_rejected += 1
raise RuntimeError("Request queue is full")
self.waiting_queue.append(task) self.waiting_queue.append(task)
self._total_requests += 1
self._total_tasks += 1 self._total_tasks += 1
self._task_event.set() self._task_event.set()
@ -181,6 +199,40 @@ class InferenceScheduler:
for idx in indices: for idx in indices:
self.page_cache.free(idx) self.page_cache.free(idx)
def _abort_task(self, task: Task) -> None:
    """Abort *task*, release its KV pages and signal end-of-stream.

    The task is marked ``TaskStatus.ABORTED`` and stamped with a finish time.
    If its pages have not been freed yet they are returned through
    ``_free_pages`` and the task's page bookkeeping is reset. Finally the
    stream callback, when one is set, receives ``STOP``.

    Args:
        task: The task to abort.
    """
    task.status = TaskStatus.ABORTED
    task.finish_time = time.time()

    already_released = task._pages_freed
    if not already_released:
        self._free_pages(task.page_table)
        task.page_table.clear()
        task.n_pages = 0
        task._pages_freed = True

    notify = task.stream_callback
    if notify:
        notify(STOP)
def _abort_expired_tasks(self) -> None:
    """Abort every active or waiting task whose deadline has passed.

    Expired active tasks go through ``_abort_task``. Expired waiting tasks
    are marked ``TaskStatus.ABORTED`` and dropped from the waiting queue.
    Each aborted task increments ``_total_timeouts`` and, if it has a stream
    callback, that callback receives ``STOP``.
    """
    now = time.time()

    alive = []
    for t in self.active_tasks:
        if now > t.deadline:
            self._abort_task(t)
            self._total_timeouts += 1
        else:
            alive.append(t)
    self.active_tasks = alive

    expired_waiting = []
    with self._lock:
        keep = []
        for t in self.waiting_queue:
            if now > t.deadline:
                t.status = TaskStatus.ABORTED
                expired_waiting.append(t)
            else:
                keep.append(t)
        self.waiting_queue = keep
        self._total_timeouts += len(expired_waiting)

    # Notify outside the lock: self._lock is a plain (non-reentrant) Lock that
    # add_task also acquires, so a callback that re-enters the scheduler would
    # deadlock, and a slow callback would stall every enqueue in the meantime.
    for t in expired_waiting:
        if t.stream_callback:
            t.stream_callback(STOP)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages): for i in range(start_logical_page, full_pages):
@ -194,6 +246,9 @@ class InferenceScheduler:
task.finish_time = time.time() task.finish_time = time.time()
finished.append(task) finished.append(task)
self._total_tokens += task.output_tokens self._total_tokens += task.output_tokens
self._request_latencies.append(task.finish_time - task.arrival_time)
if len(self._request_latencies) > 1000:
self._request_latencies.pop(0)
for task in finished: for task in finished:
if not task._pages_freed: if not task._pages_freed:
@ -345,14 +400,19 @@ class InferenceScheduler:
def _run_generation_loop(self) -> None: def _run_generation_loop(self) -> None:
try: try:
while self._running: while self._running or (self._draining and self.active_tasks):
self._abort_expired_tasks()
self._remove_finished_tasks() self._remove_finished_tasks()
self._refill_active_batch() if not self._draining:
self._refill_active_batch()
if not self.active_tasks and not self.waiting_queue: if not self.active_tasks:
self._task_event.clear() if self._draining:
self._task_event.wait(timeout=1.0) break
continue if not self.waiting_queue:
self._task_event.clear()
self._task_event.wait(timeout=1.0)
continue
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if to_prefill: if to_prefill:
@ -392,20 +452,54 @@ class InferenceScheduler:
t.start() t.start()
self._loop_thread = t self._loop_thread = t
def stop(self) -> None: def stop(self, timeout: float = 30.0) -> None:
self._draining = True
self._running = False self._running = False
self._task_event.set() self._task_event.set()
if hasattr(self, "_loop_thread"): if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=2.0) self._loop_thread.join(timeout=timeout)
self.waiting_queue.clear()
self.active_tasks.clear() for task in self.active_tasks:
if not task._pages_freed:
self._free_pages(task.page_table)
task._pages_freed = True
if task.stream_callback:
task.stream_callback(STOP)
with self._lock:
for task in self.waiting_queue:
task.status = TaskStatus.ABORTED
if task.stream_callback:
task.stream_callback(STOP)
self.waiting_queue.clear()
self.active_tasks.clear()
if torch.cuda.is_available(): if torch.cuda.is_available():
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
latencies = self._request_latencies
sorted_lat = sorted(latencies) if latencies else []
n = len(sorted_lat)
p50 = sorted_lat[n // 2] if n > 0 else 0.0
p95 = sorted_lat[int(n * 0.95)] if n > 0 else 0.0
p99 = sorted_lat[int(n * 0.99)] if n > 0 else 0.0
cache = self.page_cache
total_lookups = cache.lookup_hits + cache.lookup_misses
hit_rate = cache.lookup_hits / total_lookups if total_lookups > 0 else 0.0
return { return {
"total_tasks": self._total_tasks, "total_requests": self._total_requests,
"total_rejected": self._total_rejected,
"total_timeouts": self._total_timeouts,
"total_tokens": self._total_tokens, "total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks), "active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue), "waiting_queue": len(self.waiting_queue),
"latency_p50": p50,
"latency_p95": p95,
"latency_p99": p99,
"cache_hit_rate": hit_rate,
"cache_hits": cache.lookup_hits,
"cache_misses": cache.lookup_misses,
} }

View File

@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Union
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse from fastapi.responses import PlainTextResponse, StreamingResponse
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine from astrai.inference.engine import InferenceEngine
@ -92,6 +92,8 @@ def configure_server(
dtype=dtype, dtype=dtype,
param_path=param_path, param_path=param_path,
max_batch_size=max_batch_size, max_batch_size=max_batch_size,
max_queue_size=64,
request_timeout=60.0,
) )
@ -185,6 +187,40 @@ async def get_stats():
return _get_engine().get_stats() return _get_engine().get_stats()
@app.get("/metrics")
async def metrics():
    """Expose engine statistics in the Prometheus text exposition format.

    Reads the dictionary returned by the engine's ``get_stats()`` and renders
    request counters, generated-token count, active/queued task gauges,
    latency quantiles and prefix-cache hit statistics, one sample per line.

    Returns:
        A plain-text response whose Content-Type carries the exposition
        format version (``version=0.0.4``) so scrapers can identify it.
    """
    s = _get_engine().get_stats()
    lines = [
        "# HELP astrai_requests_total Total requests received",
        "# TYPE astrai_requests_total counter",
        f'astrai_requests_total{{status="accepted"}} {s["total_requests"]}',
        f'astrai_requests_total{{status="rejected"}} {s["total_rejected"]}',
        f'astrai_requests_total{{status="timeout"}} {s["total_timeouts"]}',
        "# HELP astrai_tokens_generated Total generated tokens",
        "# TYPE astrai_tokens_generated counter",
        f"astrai_tokens_generated {s['total_tokens']}",
        "# HELP astrai_active_tasks Currently active tasks",
        "# TYPE astrai_active_tasks gauge",
        f"astrai_active_tasks {s['active_tasks']}",
        "# HELP astrai_queue_depth Waiting queue depth",
        "# TYPE astrai_queue_depth gauge",
        f"astrai_queue_depth {s['waiting_queue']}",
        "# HELP astrai_request_latency_seconds Request latency quantiles",
        "# TYPE astrai_request_latency_seconds gauge",
        f'astrai_request_latency_seconds{{quantile="0.5"}} {s["latency_p50"]:.3f}',
        f'astrai_request_latency_seconds{{quantile="0.95"}} {s["latency_p95"]:.3f}',
        f'astrai_request_latency_seconds{{quantile="0.99"}} {s["latency_p99"]:.3f}',
        "# HELP astrai_cache_hit_rate Prefix cache hit ratio",
        "# TYPE astrai_cache_hit_rate gauge",
        f"astrai_cache_hit_rate {s['cache_hit_rate']:.3f}",
        "# HELP astrai_cache_lookups_total Prefix cache page lookups",
        "# TYPE astrai_cache_lookups_total counter",
        f'astrai_cache_lookups_total{{result="hit"}} {s["cache_hits"]}',
        f'astrai_cache_lookups_total{{result="miss"}} {s["cache_misses"]}',
    ]
    # The exposition format ends every sample line, including the last, with
    # a newline; the media type advertises the text format version.
    return PlainTextResponse(
        "\n".join(lines) + "\n",
        media_type="text/plain; version=0.0.4",
    )
@app.post("/v1/chat/completions") @app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest): async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming).""" """OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
@ -200,13 +236,16 @@ async def chat_completion(request: ChatCompletionRequest):
prompt_tokens = len(engine.tokenizer.encode(prompt)) prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream: if request.stream:
agen = engine.generate_async( try:
prompt=prompt, agen = engine.generate_async(
max_tokens=request.max_tokens, prompt=prompt,
temperature=request.temperature, max_tokens=request.max_tokens,
top_p=request.top_p, temperature=request.temperature,
top_k=request.top_k, top_p=request.top_p,
) top_k=request.top_k,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
async def event_stream(): async def event_stream():
yield _make_chunk( yield _make_chunk(
@ -252,13 +291,16 @@ async def chat_completion(request: ChatCompletionRequest):
completion_tokens = 0 completion_tokens = 0
chunks: List[str] = [] chunks: List[str] = []
agen = engine.generate_async( try:
prompt=prompt, agen = engine.generate_async(
max_tokens=request.max_tokens, prompt=prompt,
temperature=request.temperature, max_tokens=request.max_tokens,
top_p=request.top_p, temperature=request.temperature,
top_k=request.top_k, top_p=request.top_p,
) top_k=request.top_k,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
async for token in agen: async for token in agen:
chunks.append(token) chunks.append(token)
completion_tokens += 1 completion_tokens += 1

View File

@ -173,5 +173,5 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
# Verify stats are consistent # Verify stats are consistent
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_requests" in stats
assert stats["total_tasks"] >= 0 assert stats["total_requests"] >= 0