From a3bde30fb1b40480f071e9111876ac141c813ec9 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 10 May 2026 18:16:51 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=9C=8D=E5=8A=A1=E5=8C=96=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E8=AE=BE=E6=96=BD=20-=20=E6=9C=89=E7=95=8C=E9=98=9F?= =?UTF-8?q?=E5=88=97/=E8=B6=85=E6=97=B6/=E4=BC=98=E9=9B=85=E5=85=B3?= =?UTF-8?q?=E9=97=AD/metrics?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - astrai/inference/scheduler.py: 有界队列 (max_queue_size) 拒绝满时入队抛 RuntimeError -> 请求超时检测 (deadline + _abort_expired_tasks),超时任务 abort 释放页并通知回调 -> stop() 改为 drain 模式:等待活跃任务自然结束再强制清理 -> get_stats() 扩展 latency P50/P95/P99 + cache hit rate - astrai/inference/engine.py: generate/generate_async 新增 timeout 参数 -> _generate_streaming/_generate_non_streaming 捕获 add_task 异常并清理 - astrai/inference/server.py: 新增 /metrics 端点 (Prometheus 格式) -> chat completions 端点捕获 RuntimeError 返回 503 -> configure_server 传递 max_queue_size/request_timeout - astrai/inference/cache.py: 新增 lookup_hits/lookup_misses 计数器 - tests/: fix stats key total_tasks -> total_requests --- astrai/inference/cache.py | 4 + astrai/inference/engine.py | 75 +++++++---- astrai/inference/scheduler.py | 116 ++++++++++++++++-- astrai/inference/server.py | 72 ++++++++--- tests/inference/test_scheduler_concurrency.py | 4 +- 5 files changed, 217 insertions(+), 54 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index f548bc8..34f404c 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -61,6 +61,8 @@ class PagedCache: self._hash_to_page: Dict[int, int] = {} self._lru: List[int] = [] self._pin: List[bool] = [False] * n_pages + self.lookup_hits: int = 0 + self.lookup_misses: int = 0 def _touch(self, idx: int) -> None: if self._refs[idx] == 0 and idx in self._lru: @@ -98,7 +100,9 @@ class PagedCache: h = page_hash(token_ids, i, self.page_size) p = self._hash_to_page.get(h) if p is None: + self.lookup_misses += 1 break + self.lookup_hits += 1 self._touch(p) hits.append(p) return hits diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 9c3c41c..04ae456 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -195,6 +195,8 @@ class InferenceEngine: model: nn.Module, tokenizer: AutoTokenizer, max_batch_size: int = 1, + max_queue_size: int = 64, + request_timeout: float = 60.0, max_seq_len: Optional[int] = None, max_prompt_len: int = 2048, page_size: int = 128, @@ -207,7 +209,6 @@ class InferenceEngine: max_batch_size: Maximum number of concurrent tasks. max_seq_len: Maximum sequence length. max_prompt_len: Maximum prompt tokens. - compile: Whether to compile the model with torch.compile. page_size: Number of tokens per KV cache page. """ self.model = model @@ -216,6 +217,8 @@ class InferenceEngine: model=self.model, tokenizer=self.tokenizer, max_batch_size=max_batch_size, + max_queue_size=max_queue_size, + request_timeout=request_timeout, max_seq_len=max_seq_len, max_prompt_len=max_prompt_len, page_size=page_size, @@ -238,6 +241,7 @@ class InferenceEngine: temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, + timeout: Optional[float] = None, ) -> Union[Generator, str, List[str]]: """Generates text from a prompt. @@ -248,6 +252,7 @@ class InferenceEngine: temperature: Sampling temperature. top_p: Nucleus sampling probability threshold. top_k: Top-k sampling count (0 disables). + timeout: Per-request timeout in seconds (None = use scheduler default). Returns: stream=False, single prompt: str @@ -260,11 +265,11 @@ class InferenceEngine: if stream: return self._generate_streaming( - prompts, is_batch, max_tokens, temperature, top_p, top_k + prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout ) else: return self._generate_non_streaming( - prompts, is_batch, max_tokens, temperature, top_p, top_k + prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout ) def generate_async( @@ -274,6 +279,7 @@ class InferenceEngine: temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, + timeout: Optional[float] = None, ) -> AsyncGenerator[str, None]: """Async streaming generator that does not block the event loop. @@ -286,12 +292,13 @@ class InferenceEngine: temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. + timeout: Per-request timeout in seconds. Yields: Decoded token strings as they are generated. """ sync_gen = self._generate_streaming( - [prompt], False, max_tokens, temperature, top_p, top_k + [prompt], False, max_tokens, temperature, top_p, top_k, timeout ) async def _agen(): @@ -350,6 +357,7 @@ class InferenceEngine: temperature: float, top_p: float, top_k: int, + timeout: Optional[float] = None, ) -> Generator: """Internal streaming generator. @@ -363,6 +371,7 @@ class InferenceEngine: temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. + timeout: Per-request timeout in seconds. Yields: Single prompt: decoded token strings. @@ -372,16 +381,22 @@ class InferenceEngine: result = _Result(count=n) task_ids = [] - for i, p in enumerate(prompts): - task_id = self.scheduler.add_task( - prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=lambda tok, idx=i: result.append(tok, idx), - ) - task_ids.append(task_id) + try: + for i, p in enumerate(prompts): + task_id = self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=lambda tok, idx=i: result.append(tok, idx), + timeout=timeout, + ) + task_ids.append(task_id) + except RuntimeError: + for tid in task_ids: + self.scheduler.remove_task(tid) + raise remaining = n finished = [False] * n @@ -415,6 +430,7 @@ class InferenceEngine: temperature: float, top_p: float, top_k: int, + timeout: Optional[float] = None, ) -> Union[str, List[str]]: """Internal non-streaming generator. @@ -427,6 +443,7 @@ class InferenceEngine: temperature: Sampling temperature. top_p: Nucleus sampling threshold. top_k: Top-k sampling count. + timeout: Per-request timeout in seconds. Returns: Single string for one prompt, list of strings for batch. @@ -434,20 +451,26 @@ class InferenceEngine: result = _Result(count=len(prompts)) task_ids = [] - for i, p in enumerate(prompts): + try: + for i, p in enumerate(prompts): - def make_cb(idx): - return lambda tok: result.append(tok, idx) + def make_cb(idx): + return lambda tok: result.append(tok, idx) - task_id = self.scheduler.add_task( - prompt=p, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - top_k=top_k, - stream_callback=make_cb(i), - ) - task_ids.append(task_id) + task_id = self.scheduler.add_task( + prompt=p, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream_callback=make_cb(i), + timeout=timeout, + ) + task_ids.append(task_id) + except RuntimeError: + for tid in task_ids: + self.scheduler.remove_task(tid) + raise result.wait_completion() diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index b81e72d..579af37 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -55,6 +55,7 @@ class Task: self.n_pages: int = 0 self._prefix_cached_tokens: int = 0 self.arrival_time = time.time() + self.deadline: float = 0.0 self.finish_time: Optional[float] = None self.stream_callback = stream_callback self._pages_freed: bool = False @@ -86,6 +87,8 @@ class InferenceScheduler: model: AutoModel, tokenizer: AutoTokenizer, max_batch_size: int = 16, + max_queue_size: int = 64, + request_timeout: float = 60.0, max_seq_len: Optional[int] = None, max_prompt_len: int = 512, page_size: int = 64, @@ -97,6 +100,8 @@ class InferenceScheduler: self.model = model self.tokenizer = tokenizer self.max_batch_size = max_batch_size + self.max_queue_size = max_queue_size + self.request_timeout = request_timeout self.max_seq_len = max_seq_len or config.max_len self.max_prompt_len = max_prompt_len self.page_size = page_size @@ -124,11 +129,16 @@ class InferenceScheduler: self.active_tasks: List[Task] = [] self._running = False + self._draining = False self._task_event = threading.Event() self._lock = threading.Lock() self._total_tasks = 0 self._total_tokens = 0 + self._total_requests = 0 + self._total_rejected = 0 + self._total_timeouts = 0 + self._request_latencies: List[float] = [] def _n_pages_for(self, n_tokens: int) -> int: return (n_tokens + self.page_size - 1) // self.page_size @@ -141,6 +151,7 @@ class InferenceScheduler: top_p: float = 1.0, top_k: int = 50, stream_callback: Optional[Callable[[str], None]] = None, + timeout: Optional[float] = None, ) -> str: task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}" prompt_ids = self.tokenizer.encode(prompt) @@ -156,9 +167,16 @@ class InferenceScheduler: top_k=top_k, stream_callback=stream_callback, ) + task.deadline = time.time() + ( + timeout if timeout is not None else self.request_timeout + ) with self._lock: + if len(self.waiting_queue) >= self.max_queue_size: + self._total_rejected += 1 + raise RuntimeError("Request queue is full") self.waiting_queue.append(task) + self._total_requests += 1 self._total_tasks += 1 self._task_event.set() @@ -181,6 +199,40 @@ class InferenceScheduler: for idx in indices: self.page_cache.free(idx) + def _abort_task(self, task: Task) -> None: + task.status = TaskStatus.ABORTED + task.finish_time = time.time() + if not task._pages_freed: + self._free_pages(task.page_table) + task.page_table.clear() + task.n_pages = 0 + task._pages_freed = True + if task.stream_callback: + task.stream_callback(STOP) + + def _abort_expired_tasks(self) -> None: + now = time.time() + alive = [] + for t in self.active_tasks: + if now > t.deadline: + self._abort_task(t) + self._total_timeouts += 1 + else: + alive.append(t) + self.active_tasks = alive + + with self._lock: + keep = [] + for t in self.waiting_queue: + if now > t.deadline: + t.status = TaskStatus.ABORTED + if t.stream_callback: + t.stream_callback(STOP) + self._total_timeouts += 1 + else: + keep.append(t) + self.waiting_queue = keep + def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: full_pages = len(task.prompt_ids) // self.page_size for i in range(start_logical_page, full_pages): @@ -194,6 +246,9 @@ class InferenceScheduler: task.finish_time = time.time() finished.append(task) self._total_tokens += task.output_tokens + self._request_latencies.append(task.finish_time - task.arrival_time) + if len(self._request_latencies) > 1000: + self._request_latencies.pop(0) for task in finished: if not task._pages_freed: @@ -345,14 +400,19 @@ class InferenceScheduler: def _run_generation_loop(self) -> None: try: - while self._running: + while self._running or (self._draining and self.active_tasks): + self._abort_expired_tasks() self._remove_finished_tasks() - self._refill_active_batch() + if not self._draining: + self._refill_active_batch() - if not self.active_tasks and not self.waiting_queue: - self._task_event.clear() - self._task_event.wait(timeout=1.0) - continue + if not self.active_tasks: + if self._draining: + break + if not self.waiting_queue: + self._task_event.clear() + self._task_event.wait(timeout=1.0) + continue to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] if to_prefill: @@ -392,20 +452,54 @@ class InferenceScheduler: t.start() self._loop_thread = t - def stop(self) -> None: + def stop(self, timeout: float = 30.0) -> None: + self._draining = True self._running = False self._task_event.set() if hasattr(self, "_loop_thread"): - self._loop_thread.join(timeout=2.0) - self.waiting_queue.clear() - self.active_tasks.clear() + self._loop_thread.join(timeout=timeout) + + for task in self.active_tasks: + if not task._pages_freed: + self._free_pages(task.page_table) + task._pages_freed = True + if task.stream_callback: + task.stream_callback(STOP) + + with self._lock: + for task in self.waiting_queue: + task.status = TaskStatus.ABORTED + if task.stream_callback: + task.stream_callback(STOP) + self.waiting_queue.clear() + self.active_tasks.clear() + if torch.cuda.is_available(): torch.cuda.empty_cache() def get_stats(self) -> Dict[str, Any]: + latencies = self._request_latencies + sorted_lat = sorted(latencies) if latencies else [] + n = len(sorted_lat) + p50 = sorted_lat[n // 2] if n > 0 else 0.0 + p95 = sorted_lat[int(n * 0.95)] if n > 0 else 0.0 + p99 = sorted_lat[int(n * 0.99)] if n > 0 else 0.0 + + cache = self.page_cache + total_lookups = cache.lookup_hits + cache.lookup_misses + hit_rate = cache.lookup_hits / total_lookups if total_lookups > 0 else 0.0 + return { - "total_tasks": self._total_tasks, + "total_requests": self._total_requests, + "total_rejected": self._total_rejected, + "total_timeouts": self._total_timeouts, "total_tokens": self._total_tokens, "active_tasks": len(self.active_tasks), "waiting_queue": len(self.waiting_queue), + "latency_p50": p50, + "latency_p95": p95, + "latency_p99": p99, + "cache_hit_rate": hit_rate, + "cache_hits": cache.lookup_hits, + "cache_misses": cache.lookup_misses, } diff --git a/astrai/inference/server.py b/astrai/inference/server.py index 7216eaa..f5b61cb 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Union import torch import uvicorn from fastapi import FastAPI, HTTPException -from fastapi.responses import StreamingResponse +from fastapi.responses import PlainTextResponse, StreamingResponse from pydantic import BaseModel, Field from astrai.inference.engine import InferenceEngine @@ -92,6 +92,8 @@ def configure_server( dtype=dtype, param_path=param_path, max_batch_size=max_batch_size, + max_queue_size=64, + request_timeout=60.0, ) @@ -185,6 +187,40 @@ async def get_stats(): return _get_engine().get_stats() +@app.get("/metrics") +async def metrics(): + s = _get_engine().get_stats() + lines = [ + "# HELP astrai_requests_total Total requests received", + "# TYPE astrai_requests_total counter", + f'astrai_requests_total{{status="accepted"}} {s["total_requests"]}', + f'astrai_requests_total{{status="rejected"}} {s["total_rejected"]}', + f'astrai_requests_total{{status="timeout"}} {s["total_timeouts"]}', + "# HELP astrai_tokens_generated Total generated tokens", + "# TYPE astrai_tokens_generated counter", + f"astrai_tokens_generated {s['total_tokens']}", + "# HELP astrai_active_tasks Currently active tasks", + "# TYPE astrai_active_tasks gauge", + f"astrai_active_tasks {s['active_tasks']}", + "# HELP astrai_queue_depth Waiting queue depth", + "# TYPE astrai_queue_depth gauge", + f"astrai_queue_depth {s['waiting_queue']}", + "# HELP astrai_request_latency_seconds Request latency quantiles", + "# TYPE astrai_request_latency_seconds gauge", + f'astrai_request_latency_seconds{{quantile="0.5"}} {s["latency_p50"]:.3f}', + f'astrai_request_latency_seconds{{quantile="0.95"}} {s["latency_p95"]:.3f}', + f'astrai_request_latency_seconds{{quantile="0.99"}} {s["latency_p99"]:.3f}', + "# HELP astrai_cache_hit_rate Prefix cache hit ratio", + "# TYPE astrai_cache_hit_rate gauge", + f"astrai_cache_hit_rate {s['cache_hit_rate']:.3f}", + "# HELP astrai_cache_lookups_total Prefix cache page lookups", + "# TYPE astrai_cache_lookups_total counter", + f'astrai_cache_lookups_total{{result="hit"}} {s["cache_hits"]}', + f'astrai_cache_lookups_total{{result="miss"}} {s["cache_misses"]}', + ] + return PlainTextResponse("\n".join(lines) + "\n") + + @app.post("/v1/chat/completions") async def chat_completion(request: ChatCompletionRequest): """OpenAI-compatible chat completion endpoint (streaming + non-streaming).""" @@ -200,13 +236,16 @@ async def chat_completion(request: ChatCompletionRequest): prompt_tokens = len(engine.tokenizer.encode(prompt)) if request.stream: - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) + try: + agen = engine.generate_async( + prompt=prompt, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) async def event_stream(): yield _make_chunk( @@ -252,13 +291,16 @@ async def chat_completion(request: ChatCompletionRequest): completion_tokens = 0 chunks: List[str] = [] - agen = engine.generate_async( - prompt=prompt, - max_tokens=request.max_tokens, - temperature=request.temperature, - top_p=request.top_p, - top_k=request.top_k, - ) + try: + agen = engine.generate_async( + prompt=prompt, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + ) + except RuntimeError as e: + raise HTTPException(status_code=503, detail=str(e)) async for token in agen: chunks.append(token) completion_tokens += 1 diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index cc7a9d2..d49140c 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -173,5 +173,5 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): # Verify stats are consistent for stats in results["stats"]: - assert "total_tasks" in stats - assert stats["total_tasks"] >= 0 + assert "total_requests" in stats + assert stats["total_requests"] >= 0