refactor: 重构 cache 和 inference 参数体系,分离存储与分配

- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名
- PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool
- 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView
- Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御
- Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全
- server.py uvicorn.run 改为传 app 对象修复导入错误
- benchmark.py 适配 KVCache 新 API
This commit is contained in:
ViperEkura 2026-05-14 19:47:11 +08:00
parent 18fe6e9339
commit 205b40bd28
13 changed files with 394 additions and 351 deletions

View File

@ -3,7 +3,7 @@
Layers: Layers:
- core/: Core inference loop (cache, executor, scheduler, task) - core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP protocol handlers (OpenAI, Anthropic) - api/: HTTP protocol handlers (OpenAI, Anthropic)
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) - engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
""" """
@ -22,12 +22,14 @@ from astrai.inference.api import (
) )
from astrai.inference.core import ( from astrai.inference.core import (
STOP, STOP,
CacheView, Allocator,
Executor, Executor,
InferenceScheduler, InferenceScheduler,
PagedCache, KVCache,
KvcacheView,
PagePool, PagePool,
PrefixCache, PrefixCache,
Storage,
Task, Task,
TaskManager, TaskManager,
TaskStatus, TaskStatus,
@ -35,7 +37,6 @@ from astrai.inference.core import (
page_hash, page_hash,
) )
from astrai.inference.engine import ( from astrai.inference.engine import (
GenerationParams,
GenerationRequest, GenerationRequest,
InferenceEngine, InferenceEngine,
) )
@ -52,7 +53,6 @@ __all__ = [
# Engine / Requests # Engine / Requests
"InferenceEngine", "InferenceEngine",
"GenerationRequest", "GenerationRequest",
"GenerationParams",
# Core scheduler # Core scheduler
"InferenceScheduler", "InferenceScheduler",
"Executor", "Executor",
@ -61,10 +61,12 @@ __all__ = [
"TaskManager", "TaskManager",
"TaskStatus", "TaskStatus",
# Core cache # Core cache
"CacheView", "Allocator",
"PagedCache", "KVCache",
"KvcacheView",
"PagePool", "PagePool",
"PrefixCache", "PrefixCache",
"Storage",
"TaskTable", "TaskTable",
"page_hash", "page_hash",
# Sampling (Strategy pattern) # Sampling (Strategy pattern)

View File

@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
from astrai.inference.engine import GenerationParams, InferenceEngine from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
@ -143,13 +143,13 @@ class ProtocolHandler(ABC):
prompt_tokens=self._count_prompt_tokens(), prompt_tokens=self._count_prompt_tokens(),
) )
params = GenerationParams( agen = self.engine.generate_async(
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens, max_tokens=self.request.max_tokens,
temperature=self.request.temperature, temperature=self.request.temperature,
top_p=self.request.top_p, top_p=self.request.top_p,
top_k=self.request.top_k, top_k=self.request.top_k,
) )
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
if self.request.stream: if self.request.stream:
return self._handle_stream(agen, ctx) return self._handle_stream(agen, ctx)

View File

@ -160,8 +160,7 @@ def run_server(
"max_batch_size": max_batch_size, "max_batch_size": max_batch_size,
} }
uvicorn.run( uvicorn.run(
"astrai.inference.server:app", app,
host=host, host=host,
port=port, port=port,
reload=reload,
) )

View File

@ -1,10 +1,12 @@
"""Inference core: cache, executor, scheduler, task management.""" """Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import ( from astrai.inference.core.cache import (
CacheView, Allocator,
PagedCache, KVCache,
KvcacheView,
PagePool, PagePool,
PrefixCache, PrefixCache,
Storage,
TaskTable, TaskTable,
page_hash, page_hash,
) )
@ -13,10 +15,12 @@ from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [ __all__ = [
"CacheView", "Allocator",
"PagedCache", "KVCache",
"KvcacheView",
"PagePool", "PagePool",
"PrefixCache", "PrefixCache",
"Storage",
"TaskTable", "TaskTable",
"page_hash", "page_hash",
"Executor", "Executor",

View File

@ -1,3 +1,4 @@
import threading
from collections import OrderedDict from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple from typing import Callable, Dict, List, Optional, Tuple
@ -14,16 +15,18 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
return h return h
class PagePool: class Allocator:
"""Bitmask page allocator with ref-counting and LRU eviction.""" """Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None): def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1 self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict() self._lru: OrderedDict[int, None] = OrderedDict()
self._on_evict = on_evict self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int: def alloc(self) -> int:
with self._lock:
if self._free_mask: if self._free_mask:
lsb = self._free_mask & -self._free_mask lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1 idx = lsb.bit_length() - 1
@ -32,14 +35,15 @@ class PagePool:
return idx return idx
if self._lru: if self._lru:
idx, _ = self._lru.popitem(last=False) idx, _ = self._lru.popitem(last=False)
if self._on_evict: if self.on_evict:
self._on_evict(idx) self.on_evict(idx)
self._refs[idx] = 1 self._refs[idx] = 1
self._free_mask &= ~(1 << idx) self._free_mask &= ~(1 << idx)
return idx return idx
return -1 return -1
def free(self, idx: int, keep_cached: bool = False) -> None: def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
if keep_cached: if keep_cached:
@ -48,14 +52,18 @@ class PagePool:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1 self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None: def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx) self._lru.move_to_end(idx)
def remove_from_lru(self, idx: int) -> None:
self._lru.pop(idx, None)
class PrefixCache: class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices.""" """Hash-based prefix matching: maps page hashes to physical page indices."""
@ -64,16 +72,20 @@ class PrefixCache:
self._page_size = page_size self._page_size = page_size
self._page_to_hash: Dict[int, int] = {} self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def on_evict(self, idx: int) -> None: def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None) h = self._page_to_hash.pop(idx, None)
if h is not None: if h is not None:
self._hash_to_page.pop(h, None) self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool: def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash return idx in self._page_to_hash
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]: def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size full_pages = len(token_ids) // self._page_size
hits: List[int] = [] hits: List[int] = []
for i in range(full_pages): for i in range(full_pages):
@ -81,24 +93,59 @@ class PrefixCache:
p = self._hash_to_page.get(h) p = self._hash_to_page.get(h)
if p is None: if p is None:
break break
pool.touch(p)
hits.append(p) hits.append(p)
return hits return hits
def record( def record(
self, self, page_idx: int, token_ids: List[int], logical_page_idx: int
page_idx: int,
token_ids: List[int],
logical_page_idx: int,
pool: PagePool,
) -> None: ) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size) h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None) old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None: if old_h is not None:
self._hash_to_page.pop(old_h, None) self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx self._hash_to_page[h] = page_idx
pool.remove_from_lru(page_idx)
class PagePool:
    """Coordinate an ``Allocator`` (page lifetime) with a ``PrefixCache`` (content hashes).

    The pool wires the allocator's eviction hook to the prefix cache so that a
    page reclaimed by the allocator also loses its hash mapping, and it exposes
    the combined operations used by the KV-cache facade.
    """

    def __init__(self, allocator: Allocator, prefix: PrefixCache):
        """Bind the two collaborators and install the eviction callback.

        Args:
            allocator: Page allocator; its ``on_evict`` attribute is overwritten.
            prefix: Prefix cache whose ``evict`` method becomes the callback.
        """
        self._alloc = allocator
        self._prefix = prefix
        # When the allocator recycles a page, its hash entry must go as well.
        self._alloc.on_evict = prefix.evict

    @property
    def allocator(self) -> Allocator:
        """The underlying page allocator."""
        return self._alloc

    @property
    def prefix(self) -> PrefixCache:
        """The underlying prefix cache."""
        return self._prefix

    def alloc(self) -> int:
        """Allocate one page and return whatever the allocator returns."""
        return self._alloc.alloc()

    def free(self, idx: int) -> None:
        """Release one reference to page ``idx``.

        A page that still has a recorded hash is freed with ``keep_cached=True``;
        otherwise it is freed normally and ``evict`` is called on the prefix
        cache for that index.
        """
        keep = self._prefix.has_page(idx)
        self._alloc.free(idx, keep_cached=keep)
        if not keep:
            self._prefix.evict(idx)

    def inc_ref(self, idx: int) -> None:
        """Add one reference to page ``idx``."""
        self._alloc.inc_ref(idx)

    def lookup(self, token_ids: List[int]) -> List[int]:
        """Return the pages matching the leading full pages of ``token_ids``.

        Every hit is also "touched" in the allocator to refresh its LRU
        position. ``Allocator.touch`` uses ``OrderedDict.move_to_end``, which
        raises ``KeyError`` for a key that is not in the LRU. A hit that is not
        in the LRU (``Allocator.inc_ref`` removes a page from it) therefore has
        no LRU position to refresh and is skipped instead of aborting the
        lookup.

        Args:
            token_ids: Prompt token ids to match against recorded page hashes.

        Returns:
            The hit page indices, in the order the prefix cache returned them.
        """
        hits = self._prefix.lookup(token_ids)
        for p in hits:
            try:
                self._alloc.touch(p)
            except KeyError:
                # Page is not in the LRU, so there is nothing to reorder.
                continue
        return hits

    def record(
        self, page_idx: int, token_ids: List[int], logical_page_idx: int
    ) -> None:
        """Record the content hash of ``page_idx`` in the prefix cache.

        Args:
            page_idx: Physical page index holding the tokens.
            token_ids: Token ids the hash is computed from.
            logical_page_idx: Index of the page within ``token_ids``.
        """
        self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable: class TaskTable:
@ -108,34 +155,41 @@ class TaskTable:
self._page_size = page_size self._page_size = page_size
self._pages: Dict[str, List[int]] = {} self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {} self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None: def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table self._pages[task_id] = page_table
self._cached[task_id] = cached self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]: def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, []) return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int: def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0) return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]: def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, []) pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0) cached = self._cached.pop(task_id, 0)
return pages, cached return pages, cached
def get_ref(self, task_id: str) -> List[int]: def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, []) return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids] states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0) max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states] rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device) return torch.tensor(rows, dtype=torch.long, device=device)
class PagedCache: class Storage:
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable.""" """KV-cache tensor storage with paged write/gather."""
def __init__( def __init__(
self, self,
@ -148,10 +202,6 @@ class PagedCache:
dtype: torch.dtype, dtype: torch.dtype,
): ):
self.page_size = page_size self.page_size = page_size
self._prefix = PrefixCache(page_size)
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(page_size)
self.k_cache = torch.empty( self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim), (n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device, device=device,
@ -163,80 +213,6 @@ class PagedCache:
dtype=dtype, dtype=dtype,
) )
def alloc_n(self, n: int) -> List[int]:
pages: List[int] = []
for _ in range(n):
p = self._pool.alloc()
if p < 0:
for page in pages:
self.free(page)
return []
pages.append(p)
return pages
def free(self, idx: int) -> None:
cached = self._prefix.has_page(idx)
self._pool.free(idx, keep_cached=cached)
if not cached:
self._prefix.on_evict(idx)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._prefix.lookup(prompt_ids, self._pool)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self.free(hp)
for np in new_pages:
self.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write( def write(
self, self,
layer_id: int, layer_id: int,
@ -259,6 +235,9 @@ class PagedCache:
write_end = min(page_start + page_size, start_pos + seq_len) write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start offset = write_start - page_start
chunk = write_end - write_start chunk = write_end - write_start
if (phys_pages < 0).any():
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[ self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk :, written : written + chunk
] ]
@ -280,17 +259,95 @@ class PagedCache:
return k, v return k, v
class CacheView: class KvcacheView:
"""Bundles PagedCache + page_table + total_len for attention layers.""" """Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0): def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._cache = cache self._storage = storage
self._page_table = page_table self._page_table = page_table
self._total_len = total_len self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1) start_pos = self._total_len - k.size(1)
self._cache.write(layer_id, self._page_table, start_pos, k, v) self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len) return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
    """Facade that pairs per-task page bookkeeping with KV tensor storage.

    It owns a ``PagePool`` (allocator plus prefix cache), a ``TaskTable``
    (task id to page list) and a ``Storage`` (the K/V tensors).
    """

    def __init__(
        self,
        n_layers: int,
        n_pages: int,
        page_size: int,
        n_kv_heads: int,
        head_dim: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Build the pool, the task table and the tensor storage.

        All arguments are forwarded unchanged to the collaborating classes.
        """
        self.page_size = page_size
        allocator = Allocator(n_pages)
        prefix = PrefixCache(page_size)
        self._pool = PagePool(allocator, prefix)
        self._table = TaskTable(page_size)
        self._storage = Storage(
            n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
        )

    def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
        """Reserve pages for a new task, reusing prefix-cache hits first.

        Returns:
            ``True`` once the page table is stored. ``False`` if the pool ran
            out of pages, in which case every page taken so far is released.
        """
        shared = self._pool.lookup(prompt_ids)
        cached = len(shared) * self.page_size
        for page in shared:
            self._pool.inc_ref(page)

        # Tokens not covered by the reused pages need fresh pages (ceil division).
        missing = len(prompt_ids) - cached
        n_new = 0
        if missing > 0:
            n_new = (missing + self.page_size - 1) // self.page_size

        fresh: List[int] = []
        for _ in range(n_new):
            page = self._pool.alloc()
            if page < 0:
                # Roll back: drop the references on hits, then the new pages.
                for held in shared + fresh:
                    self._pool.free(held)
                return False
            fresh.append(page)

        self._table.set(task_id, shared + fresh, cached)
        return True

    def task_free(self, task_id: str) -> None:
        """Remove the task's table entry and free each page it listed."""
        pages, _cached = self._table.pop(task_id)
        for page in pages:
            self._pool.free(page)

    def task_extend(self, task_id: str, pos: int) -> bool:
        """Append pages to the task's page list until position ``pos`` is covered.

        Returns:
            ``True`` when enough pages are present, ``False`` if an allocation
            fails. Pages appended before the failure remain in the list.
        """
        pages = self._table.get(task_id)
        required = (pos + 1 + self.page_size - 1) // self.page_size
        while len(pages) < required:
            page = self._pool.alloc()
            if page < 0:
                return False
            pages.append(page)
        return True

    def task_cached(self, task_id: str) -> int:
        """Return the cached-token count stored for ``task_id``."""
        return self._table.get_cached(task_id)

    def task_record_hashes(
        self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
    ) -> None:
        """Record hashes for the task's full prompt pages.

        Only logical pages from ``start_logical_page`` up to the number of
        complete pages in ``prompt_ids`` are recorded.
        """
        pages = self._table.get(task_id)
        n_full = len(prompt_ids) // self.page_size
        for logical in range(start_logical_page, n_full):
            self._pool.record(pages[logical], prompt_ids, logical)

    def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        """Build the page-table tensor for ``task_ids`` via the task table."""
        return self._table.table_tensor(task_ids, device)

    def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
        """Wrap the storage, a page table and a length into a ``KvcacheView``."""
        return KvcacheView(self._storage, page_table, total_len)

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import torch import torch
from astrai.inference.core.cache import PagedCache from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task from astrai.inference.core.task import Task
from astrai.inference.sample import sample from astrai.inference.sample import sample
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
@ -19,7 +19,7 @@ class Executor:
self, self,
model: AutoModel, model: AutoModel,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
page_cache: PagedCache, page_cache: KVCache,
device: Optional[str] = None, device: Optional[str] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
): ):

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch import torch
from astrai.inference.core.cache import PagedCache from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
@ -37,7 +37,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1 max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size ) // page_size
self._page_cache = PagedCache( self._page_cache = KVCache(
config.n_layers, config.n_layers,
n_pages, n_pages,
page_size, page_size,

View File

@ -28,7 +28,7 @@ class Task:
self, self,
task_id: str, task_id: str,
prompt_ids: List[int], prompt_ids: List[int],
max_tokens: int = 1024, max_tokens: Optional[int] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
@ -54,7 +54,7 @@ class Task:
return self.input_tokens + len(self.output_ids) return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool: def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens: if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
return True return True
if self.output_ids and self.output_ids[-1] in stop_ids: if self.output_ids and self.output_ids[-1] in stop_ids:
return True return True
@ -88,7 +88,7 @@ class TaskManager:
def add_task( def add_task(
self, self,
prompt: str, prompt: str,
max_tokens: int = 1024, max_tokens: Optional[int] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
@ -104,6 +104,9 @@ class TaskManager:
stream_callback(STOP) stream_callback(STOP)
return task_id return task_id
if max_tokens is None:
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task( task = Task(

View File

@ -3,7 +3,6 @@
import asyncio import asyncio
import gc import gc
import threading import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch import torch
@ -14,6 +13,17 @@ from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerateResult: class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes.""" """Thread-safe token accumulator for streaming and non-streaming modes."""
@ -58,24 +68,6 @@ class GenerateResult:
return self.results.copy() return self.results.copy()
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
def __post_init__(self):
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerationRequest: class GenerationRequest:
"""Request parameters for text generation.""" """Request parameters for text generation."""
@ -85,34 +77,18 @@ class GenerationRequest:
top_k: int = 50, top_k: int = 50,
top_p: float = 1.0, top_p: float = 1.0,
temperature: float = 1.0, temperature: float = 1.0,
max_len: int = 1024, max_tokens: Optional[int] = None,
stream: bool = False, stream: bool = False,
): ):
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
self.messages = messages self.messages = messages
self.params = GenerationParams( self.top_k = top_k
top_k=top_k, self.top_p = top_p
top_p=top_p, self.temperature = temperature
temperature=temperature, self.max_tokens = max_tokens
max_tokens=max_len,
)
self.stream = stream self.stream = stream
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
class InferenceEngine: class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler.""" """Unified inference engine backed by continuous-batching scheduler."""
@ -150,37 +126,36 @@ class InferenceEngine:
self, self,
prompt: Union[str, List[str]], prompt: Union[str, List[str]],
stream: bool = False, stream: bool = False,
max_tokens: int = 1024, max_tokens: Optional[int] = None,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> Union[Generator, str, List[str]]: ) -> Union[Generator, str, List[str]]:
params = GenerationParams( _validate_sampling_params(top_k, top_p, temperature, max_tokens)
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
is_batch = isinstance(prompt, list) is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt] prompts = prompt if is_batch else [prompt]
if stream: if stream:
return self._generate_streaming(prompts, is_batch, params) return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else: else:
return self._generate_non_streaming(prompts, is_batch, params) return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async( def generate_async(
self, self,
prompt: str, prompt: str,
params: Optional[GenerationParams] = None, max_tokens: Optional[int] = None,
max_tokens: int = 1024,
temperature: float = 1.0, temperature: float = 1.0,
top_p: float = 1.0, top_p: float = 1.0,
top_k: int = 50, top_k: int = 50,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
if params is None: _validate_sampling_params(top_k, top_p, temperature, max_tokens)
params = GenerationParams( sync_gen = self._generate_streaming(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens [prompt], False, max_tokens, temperature, top_p, top_k
) )
sync_gen = self._generate_streaming([prompt], False, params)
async def _agen(): async def _agen():
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@ -206,14 +181,19 @@ class InferenceEngine:
return self.generate( return self.generate(
prompt=prompt, prompt=prompt,
stream=request.stream, stream=request.stream,
max_tokens=request.params.max_tokens, max_tokens=request.max_tokens,
temperature=request.params.temperature, temperature=request.temperature,
top_p=request.params.top_p, top_p=request.top_p,
top_k=request.params.top_k, top_k=request.top_k,
) )
def _submit_tasks( def _submit_tasks(
self, prompts: List[str], params: GenerationParams self,
prompts: List[str],
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]: ) -> Tuple[GenerateResult, List[str]]:
n = len(prompts) n = len(prompts)
result = GenerateResult(count=n) result = GenerateResult(count=n)
@ -222,10 +202,10 @@ class InferenceEngine:
cb = self._make_callback(result, i) cb = self._make_callback(result, i)
task_id = self.scheduler.add_task( task_id = self.scheduler.add_task(
prompt=p, prompt=p,
max_tokens=params.max_tokens, max_tokens=max_tokens,
temperature=params.temperature, temperature=temperature,
top_p=params.top_p, top_p=top_p,
top_k=params.top_k, top_k=top_k,
stream_callback=cb, stream_callback=cb,
) )
task_ids.append(task_id) task_ids.append(task_id)
@ -239,9 +219,17 @@ class InferenceEngine:
return cb return cb
def _generate_streaming( def _generate_streaming(
self, prompts: List[str], is_batch: bool, params: GenerationParams self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator: ) -> Generator:
result, task_ids = self._submit_tasks(prompts, params) result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts) n = len(prompts)
remaining = n remaining = n
finished = [False] * n finished = [False] * n
@ -267,9 +255,17 @@ class InferenceEngine:
return gen() return gen()
def _generate_non_streaming( def _generate_non_streaming(
self, prompts: List[str], is_batch: bool, params: GenerationParams self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]: ) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks(prompts, params) result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
result.wait_completion() result.wait_completion()

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.inference.core.cache import CacheView from astrai.inference.core.cache import KvcacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor: def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@ -147,7 +147,7 @@ class GQA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tensor,
attn_mask: Tensor = None, attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[KvcacheView] = None,
) -> Tensor: ) -> Tensor:
is_causal = attn_mask is None is_causal = attn_mask is None
@ -227,7 +227,7 @@ class MLA(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tensor,
attn_mask: Tensor = None, attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[KvcacheView] = None,
) -> Tensor: ) -> Tensor:
bsz, seq_len, _ = x.size() bsz, seq_len, _ = x.size()
is_causal = attn_mask is None is_causal = attn_mask is None
@ -306,7 +306,7 @@ class DecoderBlock(nn.Module):
x: Tensor, x: Tensor,
rotary_emb: Tensor, rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[KvcacheView] = None,
) -> Tensor: ) -> Tensor:
attn_output = self.attention( attn_output = self.attention(
self.input_norm(x), self.input_norm(x),

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.config.model_config import ModelConfig from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import CacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel from astrai.model.automodel import AutoModel
from astrai.model.module import ( from astrai.model.module import (
DecoderBlock, DecoderBlock,
@ -122,7 +122,7 @@ class Transformer(AutoModel):
self, self,
input_ids: Tensor, input_ids: Tensor,
input_mask: Optional[Tensor] = None, input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None, paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None, position_ids: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
assert input_ids.ndim == 2 assert input_ids.ndim == 2

View File

@ -1,4 +1,4 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values).""" """Benchmark Transformer with KVCache"""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict from typing import Any, Dict
@ -6,7 +6,7 @@ from typing import Any, Dict
import torch import torch
from astrai.config import ModelConfig from astrai.config import ModelConfig
from astrai.inference import PagedCache from astrai.inference import KVCache
from astrai.model.transformer import Transformer from astrai.model.transformer import Transformer
@ -33,7 +33,7 @@ class GenerationBenchmark:
self.model.eval() self.model.eval()
head_dim = config.dim // config.n_heads head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = PagedCache( self._page_cache = KVCache(
config.n_layers, config.n_layers,
n_pages, n_pages,
page_size, page_size,
@ -130,7 +130,12 @@ class GenerationBenchmark:
) )
n_pages = (prompt_length + gen_length + page_size - 1) // page_size n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size) total = n_pages * batch_size
pages = []
for _ in range(total):
p = self._page_cache._pool.alloc()
assert p >= 0, "OOM"
pages.append(p)
page_table = torch.tensor( page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)], [pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long, dtype=torch.long,
@ -176,7 +181,7 @@ class GenerationBenchmark:
total_time += trial_time total_time += trial_time
for idx in pages: for idx in pages:
self._page_cache.free(idx) self._page_cache._pool.free(idx)
print( print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
@ -225,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config) benchmark = GenerationBenchmark(config)
print("=" * 80) print("=" * 80)
print("Running Transformer Generation Benchmark (PagedCache)") print("Running Transformer Generation Benchmark (KVCache)")
print("=" * 80) print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark( prefill_result = benchmark.run_prefill_benchmark(

View File

@ -3,14 +3,20 @@
import torch import torch
from astrai.inference import ( from astrai.inference import (
PagedCache, Allocator,
KVCache,
PagePool, PagePool,
PrefixCache, PrefixCache,
Storage,
TaskTable, TaskTable,
page_hash, page_hash,
) )
def make_pool(n_pages: int, page_size: int) -> PagePool:
return PagePool(Allocator(n_pages), PrefixCache(page_size))
def test_page_hash_full_page(): def test_page_hash_full_page():
token_ids = list(range(256)) token_ids = list(range(256))
h = page_hash(token_ids, 0, 64) h = page_hash(token_ids, 0, 64)
@ -24,7 +30,7 @@ def test_page_hash_different_page_differs():
def test_page_pool_alloc_free_cycle(): def test_page_pool_alloc_free_cycle():
pool = PagePool(n_pages=4) pool = make_pool(4, 64)
a = pool.alloc() a = pool.alloc()
b = pool.alloc() b = pool.alloc()
assert a != b assert a != b
@ -35,130 +41,101 @@ def test_page_pool_alloc_free_cycle():
def test_page_pool_alloc_when_full(): def test_page_pool_alloc_when_full():
pool = PagePool(n_pages=2) pool = make_pool(2, 64)
pool.alloc() pool.alloc()
pool.alloc() pool.alloc()
assert pool.alloc() == -1 assert pool.alloc() == -1
def test_page_pool_lru_eviction(): def test_page_pool_lru_eviction():
evicted = [] pool = make_pool(2, 64)
def on_evict(idx):
evicted.append(idx)
pool = PagePool(n_pages=2, on_evict=on_evict)
p0 = pool.alloc() p0 = pool.alloc()
p1 = pool.alloc() p1 = pool.alloc()
pool.free(p0, keep_cached=True) pool.record(p0, list(range(64)), 0)
pool.free(p1, keep_cached=True) pool.record(p1, list(range(64, 128)), 0)
pool.free(p0)
pool.free(p1)
pool.alloc() pool.alloc()
assert len(evicted) == 1 assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
assert evicted[0] == p0
def test_page_pool_inc_ref_and_free(): def test_page_pool_inc_ref_and_free():
pool = PagePool(n_pages=2) pool = make_pool(2, 64)
p = pool.alloc() p = pool.alloc()
pool.inc_ref(p) pool.inc_ref(p)
assert pool._refs[p] == 2 assert pool._alloc._refs[p] == 2
pool.free(p) pool.free(p)
assert pool._refs[p] == 1 assert pool._alloc._refs[p] == 1
pool.free(p) pool.free(p)
assert pool._refs[p] == 0 assert pool._alloc._refs[p] == 0
def test_page_pool_touch_moves_to_end():
pool = PagePool(n_pages=4)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
pool.free(p0, keep_cached=True)
pool.free(p1, keep_cached=True)
pool.free(p2, keep_cached=True)
assert next(iter(pool._lru)) == p0
pool.touch(p0)
assert next(reversed(pool._lru)) == p0
def test_page_pool_remove_from_lru():
pool = PagePool(n_pages=4)
p0 = pool.alloc()
pool.free(p0, keep_cached=True)
assert p0 in pool._lru
pool.remove_from_lru(p0)
assert p0 not in pool._lru
def test_page_pool_keep_cached_realloc(): def test_page_pool_keep_cached_realloc():
"""Free mask has priority over LRU; cached page returned only when no free pages.""" """Free mask has priority over LRU; cached page returned only when no free pages."""
pool = PagePool(n_pages=3) pool = make_pool(3, 64)
p0 = pool.alloc() p0 = pool.alloc()
p1 = pool.alloc() p1 = pool.alloc()
p2 = pool.alloc() p2 = pool.alloc()
pool.free(p0, keep_cached=True) for p in (p0, p1, p2):
pool.free(p1, keep_cached=True) pool.record(p, [p] * 64, 0)
pool.free(p2, keep_cached=True) pool.free(p0)
pool.free(p1)
pool.free(p2)
assert pool.alloc() == p0 assert pool.alloc() == p0
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
"""Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU."""
prefix.record(page, token_ids, logical_idx, pool)
pool.free(page, keep_cached=True)
def test_prefix_cache_lookup_returns_hits(): def test_prefix_cache_lookup_returns_hits():
token_ids = list(range(256)) token_ids = list(range(256))
pool = PagePool(n_pages=16) pool = make_pool(16, 64)
prefix = PrefixCache(page_size=64)
pages = [pool.alloc() for _ in range(4)] pages = [pool.alloc() for _ in range(4)]
for i, p in enumerate(pages): for i, p in enumerate(pages):
_record_then_cache(pool, prefix, p, token_ids, i) pool.record(p, token_ids, i)
hits = prefix.lookup(token_ids, pool) pool.free(p)
hits = pool.lookup(token_ids)
assert hits == pages assert hits == pages
def test_prefix_cache_lookup_stops_at_first_miss(): def test_prefix_cache_lookup_stops_at_first_miss():
token_ids = list(range(256)) token_ids = list(range(256))
pool = PagePool(n_pages=16) pool = make_pool(16, 64)
prefix = PrefixCache(page_size=64)
p0 = pool.alloc() p0 = pool.alloc()
_record_then_cache(pool, prefix, p0, token_ids, 0) pool.record(p0, token_ids, 0)
pool.free(p0)
p1 = pool.alloc() p1 = pool.alloc()
_record_then_cache(pool, prefix, p1, [99] * 64, 1) pool.record(p1, [99] * 64, 1)
hits = prefix.lookup(token_ids, pool) pool.free(p1)
hits = pool.lookup(token_ids)
assert len(hits) == 1 assert len(hits) == 1
assert hits[0] == p0 assert hits[0] == p0
def test_prefix_cache_ignores_partial_last_page(): def test_prefix_cache_ignores_partial_last_page():
token_ids = list(range(100)) token_ids = list(range(100))
pool = PagePool(n_pages=16) pool = make_pool(16, 64)
prefix = PrefixCache(page_size=64)
p = pool.alloc() p = pool.alloc()
_record_then_cache(pool, prefix, p, token_ids, 0) pool.record(p, token_ids, 0)
hits = prefix.lookup(token_ids, pool) pool.free(p)
hits = pool.lookup(token_ids)
assert len(hits) == 1 assert len(hits) == 1
def test_prefix_cache_on_evict_clears_mappings(): def test_prefix_cache_on_evict_clears_mappings():
pool = PagePool(n_pages=4) pool = make_pool(4, 64)
prefix = PrefixCache(page_size=64)
p = pool.alloc() p = pool.alloc()
_record_then_cache(pool, prefix, p, list(range(64)), 0) pool.record(p, list(range(64)), 0)
assert prefix.has_page(p) pool.free(p)
prefix.on_evict(p) assert p in pool._prefix._page_to_hash
assert not prefix.has_page(p) pool._prefix.evict(p)
assert p not in pool._prefix._page_to_hash
def test_prefix_cache_has_page(): def test_prefix_cache_has_page():
pool = PagePool(n_pages=4) pool = make_pool(4, 64)
prefix = PrefixCache(page_size=64)
p = pool.alloc() p = pool.alloc()
assert not prefix.has_page(p) assert p not in pool._prefix._page_to_hash
_record_then_cache(pool, prefix, p, list(range(64)), 0) pool.record(p, list(range(64)), 0)
assert prefix.has_page(p) pool.free(p)
assert p in pool._prefix._page_to_hash
def test_task_table_set_get(): def test_task_table_set_get():
@ -183,8 +160,8 @@ def test_task_table_pop():
assert table.get("task1") == [] assert table.get("task1") == []
def test_paged_cache_task_extend_allocates(): def test_kv_cache_task_extend_allocates():
cache = PagedCache( cache = KVCache(
n_layers=1, n_layers=1,
n_pages=8, n_pages=8,
page_size=64, page_size=64,
@ -199,8 +176,8 @@ def test_paged_cache_task_extend_allocates():
assert len(cache._table.get("task1")) == 4 assert len(cache._table.get("task1")) == 4
def test_paged_cache_task_extend_fails_when_pool_full(): def test_kv_cache_task_extend_fails_when_pool_full():
cache = PagedCache( cache = KVCache(
n_layers=1, n_layers=1,
n_pages=2, n_pages=2,
page_size=64, page_size=64,
@ -230,8 +207,8 @@ def test_task_table_table_tensor_empty_input():
assert t.numel() == 0 assert t.numel() == 0
def test_paged_cache_write_gather_single_page(): def test_storage_write_gather_single_page():
cache = PagedCache( storage = Storage(
n_layers=2, n_layers=2,
n_pages=8, n_pages=8,
page_size=4, page_size=4,
@ -244,13 +221,13 @@ def test_paged_cache_write_gather_single_page():
k = torch.randn(1, 2, 2, 8) k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8)
cache.write(0, page_table, 0, k, v) storage.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 2) gk, gv = storage.gather(0, page_table, 2)
assert torch.allclose(gk, k) assert torch.allclose(gk, k)
def test_paged_cache_write_cross_page(): def test_storage_write_cross_page():
cache = PagedCache( storage = Storage(
n_layers=1, n_layers=1,
n_pages=8, n_pages=8,
page_size=4, page_size=4,
@ -263,13 +240,13 @@ def test_paged_cache_write_cross_page():
k = torch.randn(1, 8, 2, 8) k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8)
cache.write(0, page_table, 0, k, v) storage.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 8) gk, gv = storage.gather(0, page_table, 8)
assert torch.allclose(gk, k) assert torch.allclose(gk, k)
def test_paged_cache_gather_truncates_to_total_len(): def test_storage_gather_truncates_to_total_len():
cache = PagedCache( storage = Storage(
n_layers=1, n_layers=1,
n_pages=8, n_pages=8,
page_size=4, page_size=4,
@ -281,14 +258,14 @@ def test_paged_cache_gather_truncates_to_total_len():
page_table = torch.tensor([[0, 1]], dtype=torch.long) page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8) k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8)
cache.write(0, page_table, 0, k, v) storage.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 5) gk, gv = storage.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8) assert gk.shape == (1, 5, 2, 8)
def test_paged_cache_gather_clamps_negative_padding(): def test_storage_gather_clamps_negative_padding():
cache = PagedCache( storage = Storage(
n_layers=1, n_layers=1,
n_pages=8, n_pages=8,
page_size=4, page_size=4,
@ -298,5 +275,5 @@ def test_paged_cache_gather_clamps_negative_padding():
dtype=torch.float32, dtype=torch.float32,
) )
page_table = torch.tensor([[0, -1]], dtype=torch.long) page_table = torch.tensor([[0, -1]], dtype=torch.long)
gk, gv = cache.gather(0, page_table, 4) gk, gv = storage.gather(0, page_table, 4)
assert gk.shape == (1, 4, 2, 8) assert gk.shape == (1, 4, 2, 8)