Compare commits

..

No commits in common. "a3bde30fb1b40480f071e9111876ac141c813ec9" and "523eacf5fe80fc0a909599161f67e00b1371fb03" have entirely different histories.

5 changed files with 90 additions and 299 deletions

View File

@ -22,16 +22,14 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
class PagedCache:
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching.
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
def __init__(
@ -59,26 +57,6 @@ class PagedCache:
)
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages
self.lookup_hits: int = 0
self.lookup_misses: int = 0
def _touch(self, idx: int) -> None:
if self._refs[idx] == 0 and idx in self._lru:
self._lru.remove(idx)
self._lru.append(idx)
def _evict_one(self) -> int:
while self._lru:
idx = self._lru.pop(0)
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
self._pin[idx] = False
self._refs[idx] = 1
return idx
return -1
def record_page(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
@ -89,9 +67,6 @@ class PagedCache:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
self._pin[page_idx] = True
if page_idx in self._lru:
self._lru.remove(page_idx)
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
full_pages = len(token_ids) // self.page_size
@ -100,28 +75,21 @@ class PagedCache:
h = page_hash(token_ids, i, self.page_size)
p = self._hash_to_page.get(h)
if p is None:
self.lookup_misses += 1
break
self.lookup_hits += 1
self._touch(p)
hits.append(p)
return hits
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
if self._refs[idx] == 1 and idx in self._lru:
self._lru.remove(idx)
def alloc(self) -> int:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1
self._free_mask ^= lsb
self._refs[idx] = 1
if idx in self._lru:
self._lru.remove(idx)
return idx
return self._evict_one()
def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)]
@ -135,15 +103,10 @@ class PagedCache:
def free(self, idx: int) -> None:
self._refs[idx] -= 1
if self._refs[idx] == 0:
h = self._page_to_hash.get(idx)
if h is not None and self._pin[idx]:
self._lru.append(idx)
else:
self._free_mask |= 1 << idx
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
self._pin[idx] = False
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)

View File

@ -11,7 +11,7 @@ import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Union
import torch
import torch.nn as nn
@ -126,7 +126,7 @@ class _Result:
idx: Index of the generation task this token belongs to.
"""
with self._cond:
self.tokens.append((idx, token))
self.tokens.append(token)
if token is not STOP:
self.results[idx] += token
else:
@ -136,11 +136,11 @@ class _Result:
self._cond.notify_all()
self._event.set()
def pop_all(self) -> List[Tuple[int, str]]:
"""Returns and clears all accumulated (idx, token) pairs.
def pop_all(self) -> List[str]:
"""Returns and clears all accumulated tokens.
Returns:
List of (index, token_string) tuples since the last call.
List of token strings since the last call.
"""
with self._cond:
out = self.tokens.copy()
@ -195,8 +195,6 @@ class InferenceEngine:
model: nn.Module,
tokenizer: AutoTokenizer,
max_batch_size: int = 1,
max_queue_size: int = 64,
request_timeout: float = 60.0,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 2048,
page_size: int = 128,
@ -209,6 +207,7 @@ class InferenceEngine:
max_batch_size: Maximum number of concurrent tasks.
max_seq_len: Maximum sequence length.
max_prompt_len: Maximum prompt tokens.
compile: Whether to compile the model with torch.compile.
page_size: Number of tokens per KV cache page.
"""
self.model = model
@ -217,8 +216,6 @@ class InferenceEngine:
model=self.model,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
max_queue_size=max_queue_size,
request_timeout=request_timeout,
max_seq_len=max_seq_len,
max_prompt_len=max_prompt_len,
page_size=page_size,
@ -241,35 +238,31 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
timeout: Optional[float] = None,
) -> Union[Generator, str, List[str]]:
) -> Union[Generator[str, None, None], str, List[str]]:
"""Generates text from a prompt.
Args:
prompt: Single string or list of strings for batch generation.
stream: If True, returns a generator yielding tokens.
stream: If True, returns a generator yielding tokens one by one.
max_tokens: Maximum number of tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling probability threshold.
top_k: Top-k sampling count (0 disables).
timeout: Per-request timeout in seconds (None = use scheduler default).
Returns:
stream=False, single prompt: str
stream=False, batch: List[str]
stream=True, single prompt: Generator[str, None, None]
stream=True, batch: Generator[Tuple[int, str], None, None]
Generator (stream=True), single string (non-stream, single prompt),
or list of strings (non-stream, batch prompts).
"""
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k, timeout
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
@ -279,7 +272,6 @@ class InferenceEngine:
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
timeout: Optional[float] = None,
) -> AsyncGenerator[str, None]:
"""Async streaming generator that does not block the event loop.
@ -292,13 +284,12 @@ class InferenceEngine:
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Yields:
Decoded token strings as they are generated.
"""
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k, timeout
[prompt], False, max_tokens, temperature, top_p, top_k
)
async def _agen():
@ -357,68 +348,49 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
timeout: Optional[float] = None,
) -> Generator:
) -> Generator[str, None, None]:
"""Internal streaming generator.
Polls the _Result accumulator in a loop, yielding tokens as they arrive.
Single prompt yields raw token strings; batch yields (idx, token) tuples.
Cleans up the scheduler task on GeneratorExit.
Args:
prompts: List of prompts.
is_batch: If True, yields (idx, token) tuples; else yields raw tokens.
prompts: List of prompts (only first is used; batch not yet supported).
is_batch: If True, raises NotImplementedError.
max_tokens: Maximum tokens to generate.
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Yields:
Single prompt: decoded token strings.
Batch: (sequence_index, token_string) tuples.
Decoded token strings.
"""
n = len(prompts)
result = _Result(count=n)
task_ids = []
if is_batch:
raise NotImplementedError("Batch streaming not yet supported")
result = _Result()
try:
for i, p in enumerate(prompts):
task_id = self.scheduler.add_task(
prompt=p,
prompt=prompts[0],
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=lambda tok, idx=i: result.append(tok, idx),
timeout=timeout,
stream_callback=lambda tok: result.append(tok, 0),
)
task_ids.append(task_id)
except RuntimeError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
remaining = n
finished = [False] * n
def gen():
nonlocal remaining
try:
while remaining > 0:
items = result.pop_all()
for idx, token in items:
while True:
tokens = result.pop_all()
for token in tokens:
if token is STOP:
if not finished[idx]:
finished[idx] = True
remaining -= 1
else:
yield (idx, token) if is_batch else token
if remaining > 0:
return
yield token
if not result.wait(timeout=0.05):
pass
finally:
for tid in task_ids:
self.scheduler.remove_task(tid)
self.scheduler.remove_task(task_id)
return gen()
@ -430,7 +402,6 @@ class InferenceEngine:
temperature: float,
top_p: float,
top_k: int,
timeout: Optional[float] = None,
) -> Union[str, List[str]]:
"""Internal non-streaming generator.
@ -443,7 +414,6 @@ class InferenceEngine:
temperature: Sampling temperature.
top_p: Nucleus sampling threshold.
top_k: Top-k sampling count.
timeout: Per-request timeout in seconds.
Returns:
Single string for one prompt, list of strings for batch.
@ -451,7 +421,6 @@ class InferenceEngine:
result = _Result(count=len(prompts))
task_ids = []
try:
for i, p in enumerate(prompts):
def make_cb(idx):
@ -464,13 +433,8 @@ class InferenceEngine:
top_p=top_p,
top_k=top_k,
stream_callback=make_cb(i),
timeout=timeout,
)
task_ids.append(task_id)
except RuntimeError:
for tid in task_ids:
self.scheduler.remove_task(tid)
raise
result.wait_completion()

View File

@ -55,7 +55,6 @@ class Task:
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.deadline: float = 0.0
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@ -87,8 +86,6 @@ class InferenceScheduler:
model: AutoModel,
tokenizer: AutoTokenizer,
max_batch_size: int = 16,
max_queue_size: int = 64,
request_timeout: float = 60.0,
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
page_size: int = 64,
@ -100,8 +97,6 @@ class InferenceScheduler:
self.model = model
self.tokenizer = tokenizer
self.max_batch_size = max_batch_size
self.max_queue_size = max_queue_size
self.request_timeout = request_timeout
self.max_seq_len = max_seq_len or config.max_len
self.max_prompt_len = max_prompt_len
self.page_size = page_size
@ -129,16 +124,11 @@ class InferenceScheduler:
self.active_tasks: List[Task] = []
self._running = False
self._draining = False
self._task_event = threading.Event()
self._lock = threading.Lock()
self._total_tasks = 0
self._total_tokens = 0
self._total_requests = 0
self._total_rejected = 0
self._total_timeouts = 0
self._request_latencies: List[float] = []
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
@ -151,7 +141,6 @@ class InferenceScheduler:
top_p: float = 1.0,
top_k: int = 50,
stream_callback: Optional[Callable[[str], None]] = None,
timeout: Optional[float] = None,
) -> str:
task_id = f"task_{int(time.time())}_{uuid.uuid4().hex[:8]}"
prompt_ids = self.tokenizer.encode(prompt)
@ -167,16 +156,9 @@ class InferenceScheduler:
top_k=top_k,
stream_callback=stream_callback,
)
task.deadline = time.time() + (
timeout if timeout is not None else self.request_timeout
)
with self._lock:
if len(self.waiting_queue) >= self.max_queue_size:
self._total_rejected += 1
raise RuntimeError("Request queue is full")
self.waiting_queue.append(task)
self._total_requests += 1
self._total_tasks += 1
self._task_event.set()
@ -199,40 +181,6 @@ class InferenceScheduler:
for idx in indices:
self.page_cache.free(idx)
def _abort_task(self, task: Task) -> None:
task.status = TaskStatus.ABORTED
task.finish_time = time.time()
if not task._pages_freed:
self._free_pages(task.page_table)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
if task.stream_callback:
task.stream_callback(STOP)
def _abort_expired_tasks(self) -> None:
now = time.time()
alive = []
for t in self.active_tasks:
if now > t.deadline:
self._abort_task(t)
self._total_timeouts += 1
else:
alive.append(t)
self.active_tasks = alive
with self._lock:
keep = []
for t in self.waiting_queue:
if now > t.deadline:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
self._total_timeouts += 1
else:
keep.append(t)
self.waiting_queue = keep
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
@ -246,9 +194,6 @@ class InferenceScheduler:
task.finish_time = time.time()
finished.append(task)
self._total_tokens += task.output_tokens
self._request_latencies.append(task.finish_time - task.arrival_time)
if len(self._request_latencies) > 1000:
self._request_latencies.pop(0)
for task in finished:
if not task._pages_freed:
@ -400,16 +345,11 @@ class InferenceScheduler:
def _run_generation_loop(self) -> None:
try:
while self._running or (self._draining and self.active_tasks):
self._abort_expired_tasks()
while self._running:
self._remove_finished_tasks()
if not self._draining:
self._refill_active_batch()
if not self.active_tasks:
if self._draining:
break
if not self.waiting_queue:
if not self.active_tasks and not self.waiting_queue:
self._task_event.clear()
self._task_event.wait(timeout=1.0)
continue
@ -452,54 +392,20 @@ class InferenceScheduler:
t.start()
self._loop_thread = t
def stop(self, timeout: float = 30.0) -> None:
self._draining = True
def stop(self) -> None:
self._running = False
self._task_event.set()
if hasattr(self, "_loop_thread"):
self._loop_thread.join(timeout=timeout)
for task in self.active_tasks:
if not task._pages_freed:
self._free_pages(task.page_table)
task._pages_freed = True
if task.stream_callback:
task.stream_callback(STOP)
with self._lock:
for task in self.waiting_queue:
task.status = TaskStatus.ABORTED
if task.stream_callback:
task.stream_callback(STOP)
self._loop_thread.join(timeout=2.0)
self.waiting_queue.clear()
self.active_tasks.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
def get_stats(self) -> Dict[str, Any]:
latencies = self._request_latencies
sorted_lat = sorted(latencies) if latencies else []
n = len(sorted_lat)
p50 = sorted_lat[n // 2] if n > 0 else 0.0
p95 = sorted_lat[int(n * 0.95)] if n > 0 else 0.0
p99 = sorted_lat[int(n * 0.99)] if n > 0 else 0.0
cache = self.page_cache
total_lookups = cache.lookup_hits + cache.lookup_misses
hit_rate = cache.lookup_hits / total_lookups if total_lookups > 0 else 0.0
return {
"total_requests": self._total_requests,
"total_rejected": self._total_rejected,
"total_timeouts": self._total_timeouts,
"total_tasks": self._total_tasks,
"total_tokens": self._total_tokens,
"active_tasks": len(self.active_tasks),
"waiting_queue": len(self.waiting_queue),
"latency_p50": p50,
"latency_p95": p95,
"latency_p99": p99,
"cache_hit_rate": hit_rate,
"cache_hits": cache.lookup_hits,
"cache_misses": cache.lookup_misses,
}

View File

@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Union
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import PlainTextResponse, StreamingResponse
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from astrai.inference.engine import InferenceEngine
@ -92,8 +92,6 @@ def configure_server(
dtype=dtype,
param_path=param_path,
max_batch_size=max_batch_size,
max_queue_size=64,
request_timeout=60.0,
)
@ -187,40 +185,6 @@ async def get_stats():
return _get_engine().get_stats()
@app.get("/metrics")
async def metrics():
s = _get_engine().get_stats()
lines = [
"# HELP astrai_requests_total Total requests received",
"# TYPE astrai_requests_total counter",
f'astrai_requests_total{{status="accepted"}} {s["total_requests"]}',
f'astrai_requests_total{{status="rejected"}} {s["total_rejected"]}',
f'astrai_requests_total{{status="timeout"}} {s["total_timeouts"]}',
"# HELP astrai_tokens_generated Total generated tokens",
"# TYPE astrai_tokens_generated counter",
f"astrai_tokens_generated {s['total_tokens']}",
"# HELP astrai_active_tasks Currently active tasks",
"# TYPE astrai_active_tasks gauge",
f"astrai_active_tasks {s['active_tasks']}",
"# HELP astrai_queue_depth Waiting queue depth",
"# TYPE astrai_queue_depth gauge",
f"astrai_queue_depth {s['waiting_queue']}",
"# HELP astrai_request_latency_seconds Request latency quantiles",
"# TYPE astrai_request_latency_seconds gauge",
f'astrai_request_latency_seconds{{quantile="0.5"}} {s["latency_p50"]:.3f}',
f'astrai_request_latency_seconds{{quantile="0.95"}} {s["latency_p95"]:.3f}',
f'astrai_request_latency_seconds{{quantile="0.99"}} {s["latency_p99"]:.3f}',
"# HELP astrai_cache_hit_rate Prefix cache hit ratio",
"# TYPE astrai_cache_hit_rate gauge",
f"astrai_cache_hit_rate {s['cache_hit_rate']:.3f}",
"# HELP astrai_cache_lookups_total Prefix cache page lookups",
"# TYPE astrai_cache_lookups_total counter",
f'astrai_cache_lookups_total{{result="hit"}} {s["cache_hits"]}',
f'astrai_cache_lookups_total{{result="miss"}} {s["cache_misses"]}',
]
return PlainTextResponse("\n".join(lines) + "\n")
@app.post("/v1/chat/completions")
async def chat_completion(request: ChatCompletionRequest):
"""OpenAI-compatible chat completion endpoint (streaming + non-streaming)."""
@ -236,7 +200,6 @@ async def chat_completion(request: ChatCompletionRequest):
prompt_tokens = len(engine.tokenizer.encode(prompt))
if request.stream:
try:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
@ -244,8 +207,6 @@ async def chat_completion(request: ChatCompletionRequest):
top_p=request.top_p,
top_k=request.top_k,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
async def event_stream():
yield _make_chunk(
@ -291,7 +252,6 @@ async def chat_completion(request: ChatCompletionRequest):
completion_tokens = 0
chunks: List[str] = []
try:
agen = engine.generate_async(
prompt=prompt,
max_tokens=request.max_tokens,
@ -299,8 +259,6 @@ async def chat_completion(request: ChatCompletionRequest):
top_p=request.top_p,
top_k=request.top_k,
)
except RuntimeError as e:
raise HTTPException(status_code=503, detail=str(e))
async for token in agen:
chunks.append(token)
completion_tokens += 1

View File

@ -173,5 +173,5 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
# Verify stats are consistent
for stats in results["stats"]:
assert "total_requests" in stats
assert stats["total_requests"] >= 0
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0