feat: 服务化基础设施 - 有界队列/超时/优雅关闭/metrics
- astrai/inference/scheduler.py: 有界队列 (max_queue_size) 拒绝满时入队抛 RuntimeError
-> 请求超时检测 (deadline + _abort_expired_tasks),超时任务 abort 释放页并通知回调
-> stop() 改为 drain 模式:等待活跃任务自然结束再强制清理
-> get_stats() 扩展 latency P50/P95/P99 + cache hit rate
- astrai/inference/engine.py: generate/generate_async 新增 timeout 参数
-> _generate_streaming/_generate_non_streaming 捕获 add_task 异常并清理
- astrai/inference/server.py: 新增 /metrics 端点 (Prometheus 格式)
-> chat completions 端点捕获 RuntimeError 返回 503
-> configure_server 传递 max_queue_size/request_timeout
- astrai/inference/cache.py: 新增 lookup_hits/lookup_misses 计数器
- tests/: fix stats key total_tasks -> total_requests
This commit is contained in:
parent
3da428e0e4
commit
a3bde30fb1
|
|
@ -61,6 +61,8 @@ class PagedCache:
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
self._lru: List[int] = []
|
self._lru: List[int] = []
|
||||||
self._pin: List[bool] = [False] * n_pages
|
self._pin: List[bool] = [False] * n_pages
|
||||||
|
self.lookup_hits: int = 0
|
||||||
|
self.lookup_misses: int = 0
|
||||||
|
|
||||||
def _touch(self, idx: int) -> None:
|
def _touch(self, idx: int) -> None:
|
||||||
if self._refs[idx] == 0 and idx in self._lru:
|
if self._refs[idx] == 0 and idx in self._lru:
|
||||||
|
|
@ -98,7 +100,9 @@ class PagedCache:
|
||||||
h = page_hash(token_ids, i, self.page_size)
|
h = page_hash(token_ids, i, self.page_size)
|
||||||
p = self._hash_to_page.get(h)
|
p = self._hash_to_page.get(h)
|
||||||
if p is None:
|
if p is None:
|
||||||
|
self.lookup_misses += 1
|
||||||
break
|
break
|
||||||
|
self.lookup_hits += 1
|
||||||
self._touch(p)
|
self._touch(p)
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
|
||||||
|
|
@ -195,6 +195,8 @@ class InferenceEngine:
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 1,
|
max_batch_size: int = 1,
|
||||||
|
max_queue_size: int = 64,
|
||||||
|
request_timeout: float = 60.0,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 2048,
|
max_prompt_len: int = 2048,
|
||||||
page_size: int = 128,
|
page_size: int = 128,
|
||||||
|
|
@ -207,7 +209,6 @@ class InferenceEngine:
|
||||||
max_batch_size: Maximum number of concurrent tasks.
|
max_batch_size: Maximum number of concurrent tasks.
|
||||||
max_seq_len: Maximum sequence length.
|
max_seq_len: Maximum sequence length.
|
||||||
max_prompt_len: Maximum prompt tokens.
|
max_prompt_len: Maximum prompt tokens.
|
||||||
compile: Whether to compile the model with torch.compile.
|
|
||||||
page_size: Number of tokens per KV cache page.
|
page_size: Number of tokens per KV cache page.
|
||||||
"""
|
"""
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
@ -216,6 +217,8 @@ class InferenceEngine:
|
||||||
model=self.model,
|
model=self.model,
|
||||||
tokenizer=self.tokenizer,
|
tokenizer=self.tokenizer,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
|
max_queue_size=max_queue_size,
|
||||||
|
request_timeout=request_timeout,
|
||||||
max_seq_len=max_seq_len,
|
max_seq_len=max_seq_len,
|
||||||
max_prompt_len=max_prompt_len,
|
max_prompt_len=max_prompt_len,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
|
|
@ -238,6 +241,7 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
"""Generates text from a prompt.
|
"""Generates text from a prompt.
|
||||||
|
|
||||||
|
|
@ -248,6 +252,7 @@ class InferenceEngine:
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling probability threshold.
|
top_p: Nucleus sampling probability threshold.
|
||||||
top_k: Top-k sampling count (0 disables).
|
top_k: Top-k sampling count (0 disables).
|
||||||
|
timeout: Per-request timeout in seconds (None = use scheduler default).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
stream=False, single prompt: str
|
stream=False, single prompt: str
|
||||||
|
|
@ -260,11 +265,11 @@ class InferenceEngine:
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(
|
return self._generate_streaming(
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(
|
return self._generate_non_streaming(
|
||||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_async(
|
def generate_async(
|
||||||
|
|
@ -274,6 +279,7 @@ class InferenceEngine:
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""Async streaming generator that does not block the event loop.
|
"""Async streaming generator that does not block the event loop.
|
||||||
|
|
||||||
|
|
@ -286,12 +292,13 @@ class InferenceEngine:
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling threshold.
|
top_p: Nucleus sampling threshold.
|
||||||
top_k: Top-k sampling count.
|
top_k: Top-k sampling count.
|
||||||
|
timeout: Per-request timeout in seconds.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Decoded token strings as they are generated.
|
Decoded token strings as they are generated.
|
||||||
"""
|
"""
|
||||||
sync_gen = self._generate_streaming(
|
sync_gen = self._generate_streaming(
|
||||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
[prompt], False, max_tokens, temperature, top_p, top_k, timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _agen():
|
async def _agen():
|
||||||
|
|
@ -350,6 +357,7 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
"""Internal streaming generator.
|
"""Internal streaming generator.
|
||||||
|
|
||||||
|
|
@ -363,6 +371,7 @@ class InferenceEngine:
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling threshold.
|
top_p: Nucleus sampling threshold.
|
||||||
top_k: Top-k sampling count.
|
top_k: Top-k sampling count.
|
||||||
|
timeout: Per-request timeout in seconds.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Single prompt: decoded token strings.
|
Single prompt: decoded token strings.
|
||||||
|
|
@ -372,16 +381,22 @@ class InferenceEngine:
|
||||||
result = _Result(count=n)
|
result = _Result(count=n)
|
||||||
task_ids = []
|
task_ids = []
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
try:
|
||||||
task_id = self.scheduler.add_task(
|
for i, p in enumerate(prompts):
|
||||||
prompt=p,
|
task_id = self.scheduler.add_task(
|
||||||
max_tokens=max_tokens,
|
prompt=p,
|
||||||
temperature=temperature,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
temperature=temperature,
|
||||||
top_k=top_k,
|
top_p=top_p,
|
||||||
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
top_k=top_k,
|
||||||
)
|
stream_callback=lambda tok, idx=i: result.append(tok, idx),
|
||||||
task_ids.append(task_id)
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
except RuntimeError:
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
raise
|
||||||
|
|
||||||
remaining = n
|
remaining = n
|
||||||
finished = [False] * n
|
finished = [False] * n
|
||||||
|
|
@ -415,6 +430,7 @@ class InferenceEngine:
|
||||||
temperature: float,
|
temperature: float,
|
||||||
top_p: float,
|
top_p: float,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
"""Internal non-streaming generator.
|
"""Internal non-streaming generator.
|
||||||
|
|
||||||
|
|
@ -427,6 +443,7 @@ class InferenceEngine:
|
||||||
temperature: Sampling temperature.
|
temperature: Sampling temperature.
|
||||||
top_p: Nucleus sampling threshold.
|
top_p: Nucleus sampling threshold.
|
||||||
top_k: Top-k sampling count.
|
top_k: Top-k sampling count.
|
||||||
|
timeout: Per-request timeout in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Single string for one prompt, list of strings for batch.
|
Single string for one prompt, list of strings for batch.
|
||||||
|
|
@ -434,20 +451,26 @@ class InferenceEngine:
|
||||||
result = _Result(count=len(prompts))
|
result = _Result(count=len(prompts))
|
||||||
task_ids = []
|
task_ids = []
|
||||||
|
|
||||||
for i, p in enumerate(prompts):
|
try:
|
||||||
|
for i, p in enumerate(prompts):
|
||||||
|
|
||||||
def make_cb(idx):
|
def make_cb(idx):
|
||||||
return lambda tok: result.append(tok, idx)
|
return lambda tok: result.append(tok, idx)
|
||||||
|
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=make_cb(i),
|
stream_callback=make_cb(i),
|
||||||
)
|
timeout=timeout,
|
||||||
task_ids.append(task_id)
|
)
|
||||||
|
task_ids.append(task_id)
|
||||||
|
except RuntimeError:
|
||||||
|
for tid in task_ids:
|
||||||
|
self.scheduler.remove_task(tid)
|
||||||
|
raise
|
||||||
|
|
||||||
result.wait_completion()
|
result.wait_completion()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,6 +55,7 @@ class Task:
|
||||||
self.n_pages: int = 0
|
self.n_pages: int = 0
|
||||||
self._prefix_cached_tokens: int = 0
|
self._prefix_cached_tokens: int = 0
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
|
self.deadline: float = 0.0
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
self.stream_callback = stream_callback
|
self.stream_callback = stream_callback
|
||||||
self._pages_freed: bool = False
|
self._pages_freed: bool = False
|
||||||
|
|
@ -86,6 +87,8 @@ class InferenceScheduler:
|
||||||
model: AutoModel,
|
model: AutoModel,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
max_batch_size: int = 16,
|
max_batch_size: int = 16,
|
||||||
|
max_queue_size: int = 64,
|
||||||
|
request_timeout: float = 60.0,
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 512,
|
||||||
page_size: int = 64,
|
page_size: int = 64,
|
||||||
|
|
@ -97,6 +100,8 @@ class InferenceScheduler:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.max_batch_size = max_batch_size
|
self.max_batch_size = max_batch_size
|
||||||
|
self.max_queue_size = max_queue_size
|
||||||
|
self.request_timeout = request_timeout
|
||||||
self.max_seq_len = max_seq_len or config.max_len
|
self.max_seq_len = max_seq_len or config.max_len
|
||||||
self.max_prompt_len = max_prompt_len
|
self.max_prompt_len = max_prompt_len
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
|
|
@ -124,11 +129,16 @@ class InferenceScheduler:
|
||||||
self.active_tasks: List[Task] = []
|
self.active_tasks: List[Task] = []
|
||||||
|
|
||||||
self._running = False
|
self._running = False
|
||||||
|
self._draining = False
|
||||||
self._task_event = threading.Event()
|
self._task_event = threading.Event()
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
self._total_tasks = 0
|
self._total_tasks = 0
|
||||||
self._total_tokens = 0
|
self._total_tokens = 0
|
||||||
|
self._total_requests = 0
|
||||||
|
self._total_rejected = 0
|
||||||
|
self._total_timeouts = 0
|
||||||
|
self._request_latencies: List[float] = []
|
||||||
|
|
||||||
def _n_pages_for(self, n_tokens: int) -> int:
|
def _n_pages_for(self, n_tokens: int) -> int:
|
||||||
return (n_tokens + self.page_size - 1) // self.page_size
|
return (n_tokens + self.page_size - 1) // self.page_size
|
||||||
|
|
@ -141,6 +151,7 @@ class InferenceScheduler:
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
stream_callback: Optional[Callable[[str], None]] = None,
|
stream_callback: Optional[Callable[[str], None]] = None,
|
||||||
|
timeout: Optional[float] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
prompt_ids = self.tokenizer.encode(prompt)
|
prompt_ids = self.tokenizer.encode(prompt)
|
||||||
|
|
@ -156,9 +167,16 @@ class InferenceScheduler:
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream_callback=stream_callback,
|
stream_callback=stream_callback,
|
||||||
)
|
)
|
||||||
|
task.deadline = time.time() + (
|
||||||
|
timeout if timeout is not None else self.request_timeout
|
||||||
|
)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
|
if len(self.waiting_queue) >= self.max_queue_size:
|
||||||
|
self._total_rejected += 1
|
||||||
|
raise RuntimeError("Request queue is full")
|
||||||
self.waiting_queue.append(task)
|
self.waiting_queue.append(task)
|
||||||
|
self._total_requests += 1
|
||||||
self._total_tasks += 1
|
self._total_tasks += 1
|
||||||
|
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
|
|
@ -181,6 +199,40 @@ class InferenceScheduler:
|
||||||
for idx in indices:
|
for idx in indices:
|
||||||
self.page_cache.free(idx)
|
self.page_cache.free(idx)
|
||||||
|
|
||||||
|
def _abort_task(self, task: Task) -> None:
|
||||||
|
task.status = TaskStatus.ABORTED
|
||||||
|
task.finish_time = time.time()
|
||||||
|
if not task._pages_freed:
|
||||||
|
self._free_pages(task.page_table)
|
||||||
|
task.page_table.clear()
|
||||||
|
task.n_pages = 0
|
||||||
|
task._pages_freed = True
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
|
||||||
|
def _abort_expired_tasks(self) -> None:
|
||||||
|
now = time.time()
|
||||||
|
alive = []
|
||||||
|
for t in self.active_tasks:
|
||||||
|
if now > t.deadline:
|
||||||
|
self._abort_task(t)
|
||||||
|
self._total_timeouts += 1
|
||||||
|
else:
|
||||||
|
alive.append(t)
|
||||||
|
self.active_tasks = alive
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
keep = []
|
||||||
|
for t in self.waiting_queue:
|
||||||
|
if now > t.deadline:
|
||||||
|
t.status = TaskStatus.ABORTED
|
||||||
|
if t.stream_callback:
|
||||||
|
t.stream_callback(STOP)
|
||||||
|
self._total_timeouts += 1
|
||||||
|
else:
|
||||||
|
keep.append(t)
|
||||||
|
self.waiting_queue = keep
|
||||||
|
|
||||||
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
|
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
|
||||||
full_pages = len(task.prompt_ids) // self.page_size
|
full_pages = len(task.prompt_ids) // self.page_size
|
||||||
for i in range(start_logical_page, full_pages):
|
for i in range(start_logical_page, full_pages):
|
||||||
|
|
@ -194,6 +246,9 @@ class InferenceScheduler:
|
||||||
task.finish_time = time.time()
|
task.finish_time = time.time()
|
||||||
finished.append(task)
|
finished.append(task)
|
||||||
self._total_tokens += task.output_tokens
|
self._total_tokens += task.output_tokens
|
||||||
|
self._request_latencies.append(task.finish_time - task.arrival_time)
|
||||||
|
if len(self._request_latencies) > 1000:
|
||||||
|
self._request_latencies.pop(0)
|
||||||
|
|
||||||
for task in finished:
|
for task in finished:
|
||||||
if not task._pages_freed:
|
if not task._pages_freed:
|
||||||
|
|
@ -345,14 +400,19 @@ class InferenceScheduler:
|
||||||
|
|
||||||
def _run_generation_loop(self) -> None:
|
def _run_generation_loop(self) -> None:
|
||||||
try:
|
try:
|
||||||
while self._running:
|
while self._running or (self._draining and self.active_tasks):
|
||||||
|
self._abort_expired_tasks()
|
||||||
self._remove_finished_tasks()
|
self._remove_finished_tasks()
|
||||||
self._refill_active_batch()
|
if not self._draining:
|
||||||
|
self._refill_active_batch()
|
||||||
|
|
||||||
if not self.active_tasks and not self.waiting_queue:
|
if not self.active_tasks:
|
||||||
self._task_event.clear()
|
if self._draining:
|
||||||
self._task_event.wait(timeout=1.0)
|
break
|
||||||
continue
|
if not self.waiting_queue:
|
||||||
|
self._task_event.clear()
|
||||||
|
self._task_event.wait(timeout=1.0)
|
||||||
|
continue
|
||||||
|
|
||||||
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
||||||
if to_prefill:
|
if to_prefill:
|
||||||
|
|
@ -392,20 +452,54 @@ class InferenceScheduler:
|
||||||
t.start()
|
t.start()
|
||||||
self._loop_thread = t
|
self._loop_thread = t
|
||||||
|
|
||||||
def stop(self) -> None:
|
def stop(self, timeout: float = 30.0) -> None:
|
||||||
|
self._draining = True
|
||||||
self._running = False
|
self._running = False
|
||||||
self._task_event.set()
|
self._task_event.set()
|
||||||
if hasattr(self, "_loop_thread"):
|
if hasattr(self, "_loop_thread"):
|
||||||
self._loop_thread.join(timeout=2.0)
|
self._loop_thread.join(timeout=timeout)
|
||||||
self.waiting_queue.clear()
|
|
||||||
self.active_tasks.clear()
|
for task in self.active_tasks:
|
||||||
|
if not task._pages_freed:
|
||||||
|
self._free_pages(task.page_table)
|
||||||
|
task._pages_freed = True
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
for task in self.waiting_queue:
|
||||||
|
task.status = TaskStatus.ABORTED
|
||||||
|
if task.stream_callback:
|
||||||
|
task.stream_callback(STOP)
|
||||||
|
self.waiting_queue.clear()
|
||||||
|
self.active_tasks.clear()
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
|
latencies = self._request_latencies
|
||||||
|
sorted_lat = sorted(latencies) if latencies else []
|
||||||
|
n = len(sorted_lat)
|
||||||
|
p50 = sorted_lat[n // 2] if n > 0 else 0.0
|
||||||
|
p95 = sorted_lat[int(n * 0.95)] if n > 0 else 0.0
|
||||||
|
p99 = sorted_lat[int(n * 0.99)] if n > 0 else 0.0
|
||||||
|
|
||||||
|
cache = self.page_cache
|
||||||
|
total_lookups = cache.lookup_hits + cache.lookup_misses
|
||||||
|
hit_rate = cache.lookup_hits / total_lookups if total_lookups > 0 else 0.0
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_tasks": self._total_tasks,
|
"total_requests": self._total_requests,
|
||||||
|
"total_rejected": self._total_rejected,
|
||||||
|
"total_timeouts": self._total_timeouts,
|
||||||
"total_tokens": self._total_tokens,
|
"total_tokens": self._total_tokens,
|
||||||
"active_tasks": len(self.active_tasks),
|
"active_tasks": len(self.active_tasks),
|
||||||
"waiting_queue": len(self.waiting_queue),
|
"waiting_queue": len(self.waiting_queue),
|
||||||
|
"latency_p50": p50,
|
||||||
|
"latency_p95": p95,
|
||||||
|
"latency_p99": p99,
|
||||||
|
"cache_hit_rate": hit_rate,
|
||||||
|
"cache_hits": cache.lookup_hits,
|
||||||
|
"cache_misses": cache.lookup_misses,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
import torch
|
import torch
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, HTTPException
|
from fastapi import FastAPI, HTTPException
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import PlainTextResponse, StreamingResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from astrai.inference.engine import InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
@ -92,6 +92,8 @@ def configure_server(
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
param_path=param_path,
|
param_path=param_path,
|
||||||
max_batch_size=max_batch_size,
|
max_batch_size=max_batch_size,
|
||||||
|
max_queue_size=64,
|
||||||
|
request_timeout=60.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -185,6 +187,40 @@ async def get_stats():
|
||||||
return _get_engine().get_stats()
|
return _get_engine().get_stats()
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/metrics")
|
||||||
|
async def metrics():
|
||||||
|
s = _get_engine().get_stats()
|
||||||
|
lines = [
|
||||||
|
"# HELP astrai_requests_total Total requests received",
|
||||||
|
"# TYPE astrai_requests_total counter",
|
||||||
|
f'astrai_requests_total{{status="accepted"}} {s["total_requests"]}',
|
||||||
|
f'astrai_requests_total{{status="rejected"}} {s["total_rejected"]}',
|
||||||
|
f'astrai_requests_total{{status="timeout"}} {s["total_timeouts"]}',
|
||||||
|
"# HELP astrai_tokens_generated Total generated tokens",
|
||||||
|
"# TYPE astrai_tokens_generated counter",
|
||||||
|
f"astrai_tokens_generated {s['total_tokens']}",
|
||||||
|
"# HELP astrai_active_tasks Currently active tasks",
|
||||||
|
"# TYPE astrai_active_tasks gauge",
|
||||||
|
f"astrai_active_tasks {s['active_tasks']}",
|
||||||
|
"# HELP astrai_queue_depth Waiting queue depth",
|
||||||
|
"# TYPE astrai_queue_depth gauge",
|
||||||
|
f"astrai_queue_depth {s['waiting_queue']}",
|
||||||
|
"# HELP astrai_request_latency_seconds Request latency quantiles",
|
||||||
|
"# TYPE astrai_request_latency_seconds gauge",
|
||||||
|
f'astrai_request_latency_seconds{{quantile="0.5"}} {s["latency_p50"]:.3f}',
|
||||||
|
f'astrai_request_latency_seconds{{quantile="0.95"}} {s["latency_p95"]:.3f}',
|
||||||
|
f'astrai_request_latency_seconds{{quantile="0.99"}} {s["latency_p99"]:.3f}',
|
||||||
|
"# HELP astrai_cache_hit_rate Prefix cache hit ratio",
|
||||||
|
"# TYPE astrai_cache_hit_rate gauge",
|
||||||
|
f"astrai_cache_hit_rate {s['cache_hit_rate']:.3f}",
|
||||||
|
"# HELP astrai_cache_lookups_total Prefix cache page lookups",
|
||||||
|
"# TYPE astrai_cache_lookups_total counter",
|
||||||
|
f'astrai_cache_lookups_total{{result="hit"}} {s["cache_hits"]}',
|
||||||
|
f'astrai_cache_lookups_total{{result="miss"}} {s["cache_misses"]}',
|
||||||
|
]
|
||||||
|
return PlainTextResponse("\n".join(lines) + "\n")
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
@app.post("/v1/chat/completions")
|
||||||
async def chat_completion(request: ChatCompletionRequest):
|
async def chat_completion(request: ChatCompletionRequest):
|
||||||
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
|
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
|
||||||
|
|
@ -200,13 +236,16 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
prompt_tokens = len(engine.tokenizer.encode(prompt))
|
||||||
|
|
||||||
if request.stream:
|
if request.stream:
|
||||||
agen = engine.generate_async(
|
try:
|
||||||
prompt=prompt,
|
agen = engine.generate_async(
|
||||||
max_tokens=request.max_tokens,
|
prompt=prompt,
|
||||||
temperature=request.temperature,
|
max_tokens=request.max_tokens,
|
||||||
top_p=request.top_p,
|
temperature=request.temperature,
|
||||||
top_k=request.top_k,
|
top_p=request.top_p,
|
||||||
)
|
top_k=request.top_k,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=str(e))
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
yield _make_chunk(
|
yield _make_chunk(
|
||||||
|
|
@ -252,13 +291,16 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
|
|
||||||
completion_tokens = 0
|
completion_tokens = 0
|
||||||
chunks: List[str] = []
|
chunks: List[str] = []
|
||||||
agen = engine.generate_async(
|
try:
|
||||||
prompt=prompt,
|
agen = engine.generate_async(
|
||||||
max_tokens=request.max_tokens,
|
prompt=prompt,
|
||||||
temperature=request.temperature,
|
max_tokens=request.max_tokens,
|
||||||
top_p=request.top_p,
|
temperature=request.temperature,
|
||||||
top_k=request.top_k,
|
top_p=request.top_p,
|
||||||
)
|
top_k=request.top_k,
|
||||||
|
)
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise HTTPException(status_code=503, detail=str(e))
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
completion_tokens += 1
|
completion_tokens += 1
|
||||||
|
|
|
||||||
|
|
@ -173,5 +173,5 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
|
|
||||||
# Verify stats are consistent
|
# Verify stats are consistent
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_requests" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_requests"] >= 0
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue