refactor: 重构 cache 和 inference 参数体系,分离存储与分配

- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名
- PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool
- 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView
- Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御
- Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全
- server.py uvicorn.run 改为传 app 对象修复导入错误
- benchmark.py 适配 KVCache 新 API
This commit is contained in:
ViperEkura 2026-05-14 19:47:11 +08:00
parent 18fe6e9339
commit 205b40bd28
13 changed files with 394 additions and 351 deletions

View File

@ -3,7 +3,7 @@
Layers:
- core/: Core inference loop (cache, executor, scheduler, task)
- api/: HTTP protocol handlers (OpenAI, Anthropic)
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
"""
@ -22,12 +22,14 @@ from astrai.inference.api import (
)
from astrai.inference.core import (
STOP,
CacheView,
Allocator,
Executor,
InferenceScheduler,
PagedCache,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
Task,
TaskManager,
TaskStatus,
@ -35,7 +37,6 @@ from astrai.inference.core import (
page_hash,
)
from astrai.inference.engine import (
GenerationParams,
GenerationRequest,
InferenceEngine,
)
@ -52,7 +53,6 @@ __all__ = [
# Engine / Requests
"InferenceEngine",
"GenerationRequest",
"GenerationParams",
# Core scheduler
"InferenceScheduler",
"Executor",
@ -61,10 +61,12 @@ __all__ = [
"TaskManager",
"TaskStatus",
# Core cache
"CacheView",
"PagedCache",
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
# Sampling (Strategy pattern)

View File

@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Union
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from astrai.inference.engine import GenerationParams, InferenceEngine
from astrai.inference.engine import InferenceEngine
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
@ -143,13 +143,13 @@ class ProtocolHandler(ABC):
prompt_tokens=self._count_prompt_tokens(),
)
params = GenerationParams(
agen = self.engine.generate_async(
prompt=self.build_prompt(),
max_tokens=self.request.max_tokens,
temperature=self.request.temperature,
top_p=self.request.top_p,
top_k=self.request.top_k,
)
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
if self.request.stream:
return self._handle_stream(agen, ctx)

View File

@ -160,8 +160,7 @@ def run_server(
"max_batch_size": max_batch_size,
}
uvicorn.run(
"astrai.inference.server:app",
app,
host=host,
port=port,
reload=reload,
)

View File

@ -1,10 +1,12 @@
"""Inference core: cache, executor, scheduler, task management."""
from astrai.inference.core.cache import (
CacheView,
PagedCache,
Allocator,
KVCache,
KvcacheView,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
@ -13,10 +15,12 @@ from astrai.inference.core.scheduler import InferenceScheduler
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
__all__ = [
"CacheView",
"PagedCache",
"Allocator",
"KVCache",
"KvcacheView",
"PagePool",
"PrefixCache",
"Storage",
"TaskTable",
"page_hash",
"Executor",

View File

@ -1,3 +1,4 @@
import threading
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple
@ -14,16 +15,18 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
return h
class PagePool:
"""Bitmask page allocator with ref-counting and LRU eviction."""
class Allocator:
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
def __init__(self, n_pages: int):
self._free_mask = (1 << n_pages) - 1
self._refs: List[int] = [0] * n_pages
self._lru: OrderedDict[int, None] = OrderedDict()
self._on_evict = on_evict
self.on_evict: Optional[Callable[[int], None]] = None
self._lock = threading.Lock()
def alloc(self) -> int:
with self._lock:
if self._free_mask:
lsb = self._free_mask & -self._free_mask
idx = lsb.bit_length() - 1
@ -32,14 +35,15 @@ class PagePool:
return idx
if self._lru:
idx, _ = self._lru.popitem(last=False)
if self._on_evict:
self._on_evict(idx)
if self.on_evict:
self.on_evict(idx)
self._refs[idx] = 1
self._free_mask &= ~(1 << idx)
return idx
return -1
def free(self, idx: int, keep_cached: bool = False) -> None:
with self._lock:
self._refs[idx] -= 1
if self._refs[idx] == 0:
if keep_cached:
@ -48,14 +52,18 @@ class PagePool:
self._free_mask |= 1 << idx
def inc_ref(self, idx: int) -> None:
with self._lock:
self._refs[idx] += 1
self._lru.pop(idx, None)
def ref_count(self, idx: int) -> int:
with self._lock:
return self._refs[idx]
def touch(self, idx: int) -> None:
with self._lock:
self._lru.move_to_end(idx)
def remove_from_lru(self, idx: int) -> None:
self._lru.pop(idx, None)
class PrefixCache:
"""Hash-based prefix matching: maps page hashes to physical page indices."""
@ -64,16 +72,20 @@ class PrefixCache:
self._page_size = page_size
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
self._lock = threading.Lock()
def on_evict(self, idx: int) -> None:
def evict(self, idx: int) -> None:
with self._lock:
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def has_page(self, idx: int) -> bool:
with self._lock:
return idx in self._page_to_hash
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
def lookup(self, token_ids: List[int]) -> List[int]:
with self._lock:
full_pages = len(token_ids) // self._page_size
hits: List[int] = []
for i in range(full_pages):
@ -81,24 +93,59 @@ class PrefixCache:
p = self._hash_to_page.get(h)
if p is None:
break
pool.touch(p)
hits.append(p)
return hits
def record(
self,
page_idx: int,
token_ids: List[int],
logical_page_idx: int,
pool: PagePool,
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
with self._lock:
h = page_hash(token_ids, logical_page_idx, self._page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
pool.remove_from_lru(page_idx)
class PagePool:
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
def __init__(self, allocator: Allocator, prefix: PrefixCache):
self._alloc = allocator
self._prefix = prefix
self._alloc.on_evict = prefix.evict
@property
def allocator(self) -> Allocator:
return self._alloc
@property
def prefix(self) -> PrefixCache:
return self._prefix
def alloc(self) -> int:
return self._alloc.alloc()
def free(self, idx: int) -> None:
keep = self._prefix.has_page(idx)
self._alloc.free(idx, keep_cached=keep)
if not keep:
self._prefix.evict(idx)
def inc_ref(self, idx: int) -> None:
self._alloc.inc_ref(idx)
def lookup(self, token_ids: List[int]) -> List[int]:
hits = self._prefix.lookup(token_ids)
for p in hits:
self._alloc.touch(p)
return hits
def record(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
self._prefix.record(page_idx, token_ids, logical_page_idx)
class TaskTable:
@ -108,34 +155,41 @@ class TaskTable:
self._page_size = page_size
self._pages: Dict[str, List[int]] = {}
self._cached: Dict[str, int] = {}
self._lock = threading.Lock()
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
with self._lock:
self._pages[task_id] = page_table
self._cached[task_id] = cached
def get(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.get(task_id, [])
def get_cached(self, task_id: str) -> int:
with self._lock:
return self._cached.get(task_id, 0)
def pop(self, task_id: str) -> Tuple[List[int], int]:
with self._lock:
pages = self._pages.pop(task_id, [])
cached = self._cached.pop(task_id, 0)
return pages, cached
def get_ref(self, task_id: str) -> List[int]:
with self._lock:
return self._pages.setdefault(task_id, [])
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
with self._lock:
states = [self._pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
class PagedCache:
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
class Storage:
"""KV-cache tensor storage with paged write/gather."""
def __init__(
self,
@ -148,10 +202,6 @@ class PagedCache:
dtype: torch.dtype,
):
self.page_size = page_size
self._prefix = PrefixCache(page_size)
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
self._table = TaskTable(page_size)
self.k_cache = torch.empty(
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
device=device,
@ -163,80 +213,6 @@ class PagedCache:
dtype=dtype,
)
def alloc_n(self, n: int) -> List[int]:
pages: List[int] = []
for _ in range(n):
p = self._pool.alloc()
if p < 0:
for page in pages:
self.free(page)
return []
pages.append(p)
return pages
def free(self, idx: int) -> None:
cached = self._prefix.has_page(idx)
self._pool.free(idx, keep_cached=cached)
if not cached:
self._prefix.on_evict(idx)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._prefix.lookup(prompt_ids, self._pool)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self.free(hp)
for np in new_pages:
self.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)
def write(
self,
layer_id: int,
@ -259,6 +235,9 @@ class PagedCache:
write_end = min(page_start + page_size, start_pos + seq_len)
offset = write_start - page_start
chunk = write_end - write_start
if (phys_pages < 0).any():
written += chunk
continue
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
:, written : written + chunk
]
@ -280,17 +259,95 @@ class PagedCache:
return k, v
class CacheView:
"""Bundles PagedCache + page_table + total_len for attention layers."""
class KvcacheView:
"""Bundles Storage + page_table + total_len for attention layers."""
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
self._cache = cache
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
self._storage = storage
self._page_table = page_table
self._total_len = total_len
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
start_pos = self._total_len - k.size(1)
self._cache.write(layer_id, self._page_table, start_pos, k, v)
self._storage.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
return self._cache.gather(layer_id, self._page_table, self._total_len)
return self._storage.gather(layer_id, self._page_table, self._total_len)
class KVCache:
"""Facade: page management + KV-cache I/O for continuous batching."""
def __init__(
self,
n_layers: int,
n_pages: int,
page_size: int,
n_kv_heads: int,
head_dim: int,
device: torch.device,
dtype: torch.dtype,
):
self.page_size = page_size
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
self._table = TaskTable(page_size)
self._storage = Storage(
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
)
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hits = self._pool.lookup(prompt_ids)
cached = len(hits) * self.page_size
for p in hits:
self._pool.inc_ref(p)
remaining = len(prompt_ids) - cached
n_new = (
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
)
new_pages: List[int] = []
if n_new > 0:
for _ in range(n_new):
p = self._pool.alloc()
if p < 0:
for hp in hits:
self._pool.free(hp)
for np in new_pages:
self._pool.free(np)
return False
new_pages.append(p)
self._table.set(task_id, hits + new_pages, cached)
return True
def task_free(self, task_id: str) -> None:
page_table, _ = self._table.pop(task_id)
for idx in page_table:
self._pool.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
page_table = self._table.get(task_id)
needed = (pos + 1 + self.page_size - 1) // self.page_size
while len(page_table) < needed:
p = self._pool.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._table.get_cached(task_id)
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._table.get(task_id)
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self._pool.record(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
return self._table.table_tensor(task_ids, device)
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
return KvcacheView(self._storage, page_table, total_len)

View File

@ -3,7 +3,7 @@ from typing import List, Optional
import torch
from astrai.inference.core.cache import PagedCache
from astrai.inference.core.cache import KVCache
from astrai.inference.core.task import Task
from astrai.inference.sample import sample
from astrai.model.automodel import AutoModel
@ -19,7 +19,7 @@ class Executor:
self,
model: AutoModel,
tokenizer: AutoTokenizer,
page_cache: PagedCache,
page_cache: KVCache,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):

View File

@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from astrai.inference.core.cache import PagedCache
from astrai.inference.core.cache import KVCache
from astrai.inference.core.executor import Executor
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
from astrai.model.automodel import AutoModel
@ -37,7 +37,7 @@ class InferenceScheduler:
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
) // page_size
self._page_cache = PagedCache(
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,

View File

@ -28,7 +28,7 @@ class Task:
self,
task_id: str,
prompt_ids: List[int],
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
@ -54,7 +54,7 @@ class Task:
return self.input_tokens + len(self.output_ids)
def is_finished(self, stop_ids: List[int]) -> bool:
if self.output_tokens >= self.max_tokens:
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
return True
if self.output_ids and self.output_ids[-1] in stop_ids:
return True
@ -88,7 +88,7 @@ class TaskManager:
def add_task(
self,
prompt: str,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
@ -104,6 +104,9 @@ class TaskManager:
stream_callback(STOP)
return task_id
if max_tokens is None:
max_tokens = self.max_seq_len - len(prompt_ids)
else:
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
task = Task(

View File

@ -3,7 +3,6 @@
import asyncio
import gc
import threading
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
import torch
@ -14,6 +13,17 @@ from astrai.inference.core.task import STOP
from astrai.tokenize import AutoTokenizer
def _validate_sampling_params(
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
):
if not (isinstance(top_k, int) and top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(temperature, (int, float)) and temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerateResult:
"""Thread-safe token accumulator for streaming and non-streaming modes."""
@ -58,24 +68,6 @@ class GenerateResult:
return self.results.copy()
@dataclass(frozen=True)
class GenerationParams:
"""Immutable value object for sampling hyperparameters."""
top_k: int = 50
top_p: float = 1.0
temperature: float = 1.0
max_tokens: int = 1024
def __post_init__(self):
if not (isinstance(self.top_k, int) and self.top_k >= 0):
raise ValueError("top_k must be a non-negative integer")
if not (0.0 <= self.top_p <= 1.0):
raise ValueError("top_p must be a float between 0.0 and 1.0")
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
raise ValueError("temperature must be a non-negative number")
class GenerationRequest:
"""Request parameters for text generation."""
@ -85,34 +77,18 @@ class GenerationRequest:
top_k: int = 50,
top_p: float = 1.0,
temperature: float = 1.0,
max_len: int = 1024,
max_tokens: Optional[int] = None,
stream: bool = False,
):
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
self.messages = messages
self.params = GenerationParams(
top_k=top_k,
top_p=top_p,
temperature=temperature,
max_tokens=max_len,
)
self.top_k = top_k
self.top_p = top_p
self.temperature = temperature
self.max_tokens = max_tokens
self.stream = stream
@property
def top_k(self) -> int:
return self.params.top_k
@property
def top_p(self) -> float:
return self.params.top_p
@property
def temperature(self) -> float:
return self.params.temperature
@property
def max_len(self) -> int:
return self.params.max_tokens
class InferenceEngine:
"""Unified inference engine backed by continuous-batching scheduler."""
@ -150,37 +126,36 @@ class InferenceEngine:
self,
prompt: Union[str, List[str]],
stream: bool = False,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> Union[Generator, str, List[str]]:
params = GenerationParams(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
)
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
is_batch = isinstance(prompt, list)
prompts = prompt if is_batch else [prompt]
if stream:
return self._generate_streaming(prompts, is_batch, params)
return self._generate_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
else:
return self._generate_non_streaming(prompts, is_batch, params)
return self._generate_non_streaming(
prompts, is_batch, max_tokens, temperature, top_p, top_k
)
def generate_async(
self,
prompt: str,
params: Optional[GenerationParams] = None,
max_tokens: int = 1024,
max_tokens: Optional[int] = None,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = 50,
) -> AsyncGenerator[str, None]:
if params is None:
params = GenerationParams(
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
sync_gen = self._generate_streaming(
[prompt], False, max_tokens, temperature, top_p, top_k
)
sync_gen = self._generate_streaming([prompt], False, params)
async def _agen():
loop = asyncio.get_event_loop()
@ -206,14 +181,19 @@ class InferenceEngine:
return self.generate(
prompt=prompt,
stream=request.stream,
max_tokens=request.params.max_tokens,
temperature=request.params.temperature,
top_p=request.params.top_p,
top_k=request.params.top_k,
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
)
def _submit_tasks(
self, prompts: List[str], params: GenerationParams
self,
prompts: List[str],
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Tuple[GenerateResult, List[str]]:
n = len(prompts)
result = GenerateResult(count=n)
@ -222,10 +202,10 @@ class InferenceEngine:
cb = self._make_callback(result, i)
task_id = self.scheduler.add_task(
prompt=p,
max_tokens=params.max_tokens,
temperature=params.temperature,
top_p=params.top_p,
top_k=params.top_k,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream_callback=cb,
)
task_ids.append(task_id)
@ -239,9 +219,17 @@ class InferenceEngine:
return cb
def _generate_streaming(
self, prompts: List[str], is_batch: bool, params: GenerationParams
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Generator:
result, task_ids = self._submit_tasks(prompts, params)
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
n = len(prompts)
remaining = n
finished = [False] * n
@ -267,9 +255,17 @@ class InferenceEngine:
return gen()
def _generate_non_streaming(
self, prompts: List[str], is_batch: bool, params: GenerationParams
self,
prompts: List[str],
is_batch: bool,
max_tokens: Optional[int],
temperature: float,
top_p: float,
top_k: int,
) -> Union[str, List[str]]:
result, task_ids = self._submit_tasks(prompts, params)
result, task_ids = self._submit_tasks(
prompts, max_tokens, temperature, top_p, top_k
)
result.wait_completion()

View File

@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.inference.core.cache import CacheView
from astrai.inference.core.cache import KvcacheView
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@ -147,7 +147,7 @@ class GQA(nn.Module):
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
is_causal = attn_mask is None
@ -227,7 +227,7 @@ class MLA(nn.Module):
x: Tensor,
rotary_emb: Tensor,
attn_mask: Tensor = None,
paged_cache: Optional[CacheView] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
bsz, seq_len, _ = x.size()
is_causal = attn_mask is None
@ -306,7 +306,7 @@ class DecoderBlock(nn.Module):
x: Tensor,
rotary_emb: Tensor,
attention_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
paged_cache: Optional[KvcacheView] = None,
) -> Tensor:
attn_output = self.attention(
self.input_norm(x),

View File

@ -5,7 +5,7 @@ import torch.nn as nn
from torch import Tensor
from astrai.config.model_config import ModelConfig
from astrai.inference.core.cache import CacheView
from astrai.inference.core.cache import KvcacheView
from astrai.model.automodel import AutoModel
from astrai.model.module import (
DecoderBlock,
@ -122,7 +122,7 @@ class Transformer(AutoModel):
self,
input_ids: Tensor,
input_mask: Optional[Tensor] = None,
paged_cache: Optional[CacheView] = None,
paged_cache: Optional[KvcacheView] = None,
position_ids: Optional[Tensor] = None,
) -> Tensor:
assert input_ids.ndim == 2

View File

@ -1,4 +1,4 @@
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
"""Benchmark Transformer with KVCache"""
from dataclasses import dataclass
from typing import Any, Dict
@ -6,7 +6,7 @@ from typing import Any, Dict
import torch
from astrai.config import ModelConfig
from astrai.inference import PagedCache
from astrai.inference import KVCache
from astrai.model.transformer import Transformer
@ -33,7 +33,7 @@ class GenerationBenchmark:
self.model.eval()
head_dim = config.dim // config.n_heads
n_pages = (config.max_len * 4 + page_size - 1) // page_size
self._page_cache = PagedCache(
self._page_cache = KVCache(
config.n_layers,
n_pages,
page_size,
@ -130,7 +130,12 @@ class GenerationBenchmark:
)
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
pages = self._page_cache.alloc_n(n_pages * batch_size)
total = n_pages * batch_size
pages = []
for _ in range(total):
p = self._page_cache._pool.alloc()
assert p >= 0, "OOM"
pages.append(p)
page_table = torch.tensor(
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
dtype=torch.long,
@ -176,7 +181,7 @@ class GenerationBenchmark:
total_time += trial_time
for idx in pages:
self._page_cache.free(idx)
self._page_cache._pool.free(idx)
print(
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
@ -225,7 +230,7 @@ if __name__ == "__main__":
benchmark = GenerationBenchmark(config)
print("=" * 80)
print("Running Transformer Generation Benchmark (PagedCache)")
print("Running Transformer Generation Benchmark (KVCache)")
print("=" * 80)
prefill_result = benchmark.run_prefill_benchmark(

View File

@ -3,14 +3,20 @@
import torch
from astrai.inference import (
PagedCache,
Allocator,
KVCache,
PagePool,
PrefixCache,
Storage,
TaskTable,
page_hash,
)
def make_pool(n_pages: int, page_size: int) -> PagePool:
return PagePool(Allocator(n_pages), PrefixCache(page_size))
def test_page_hash_full_page():
token_ids = list(range(256))
h = page_hash(token_ids, 0, 64)
@ -24,7 +30,7 @@ def test_page_hash_different_page_differs():
def test_page_pool_alloc_free_cycle():
pool = PagePool(n_pages=4)
pool = make_pool(4, 64)
a = pool.alloc()
b = pool.alloc()
assert a != b
@ -35,130 +41,101 @@ def test_page_pool_alloc_free_cycle():
def test_page_pool_alloc_when_full():
pool = PagePool(n_pages=2)
pool = make_pool(2, 64)
pool.alloc()
pool.alloc()
assert pool.alloc() == -1
def test_page_pool_lru_eviction():
evicted = []
def on_evict(idx):
evicted.append(idx)
pool = PagePool(n_pages=2, on_evict=on_evict)
pool = make_pool(2, 64)
p0 = pool.alloc()
p1 = pool.alloc()
pool.free(p0, keep_cached=True)
pool.free(p1, keep_cached=True)
pool.record(p0, list(range(64)), 0)
pool.record(p1, list(range(64, 128)), 0)
pool.free(p0)
pool.free(p1)
pool.alloc()
assert len(evicted) == 1
assert evicted[0] == p0
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
def test_page_pool_inc_ref_and_free():
pool = PagePool(n_pages=2)
pool = make_pool(2, 64)
p = pool.alloc()
pool.inc_ref(p)
assert pool._refs[p] == 2
assert pool._alloc._refs[p] == 2
pool.free(p)
assert pool._refs[p] == 1
assert pool._alloc._refs[p] == 1
pool.free(p)
assert pool._refs[p] == 0
def test_page_pool_touch_moves_to_end():
pool = PagePool(n_pages=4)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
pool.free(p0, keep_cached=True)
pool.free(p1, keep_cached=True)
pool.free(p2, keep_cached=True)
assert next(iter(pool._lru)) == p0
pool.touch(p0)
assert next(reversed(pool._lru)) == p0
def test_page_pool_remove_from_lru():
pool = PagePool(n_pages=4)
p0 = pool.alloc()
pool.free(p0, keep_cached=True)
assert p0 in pool._lru
pool.remove_from_lru(p0)
assert p0 not in pool._lru
assert pool._alloc._refs[p] == 0
def test_page_pool_keep_cached_realloc():
"""Free mask has priority over LRU; cached page returned only when no free pages."""
pool = PagePool(n_pages=3)
pool = make_pool(3, 64)
p0 = pool.alloc()
p1 = pool.alloc()
p2 = pool.alloc()
pool.free(p0, keep_cached=True)
pool.free(p1, keep_cached=True)
pool.free(p2, keep_cached=True)
for p in (p0, p1, p2):
pool.record(p, [p] * 64, 0)
pool.free(p0)
pool.free(p1)
pool.free(p2)
assert pool.alloc() == p0
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
"""Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU."""
prefix.record(page, token_ids, logical_idx, pool)
pool.free(page, keep_cached=True)
def test_prefix_cache_lookup_returns_hits():
token_ids = list(range(256))
pool = PagePool(n_pages=16)
prefix = PrefixCache(page_size=64)
pool = make_pool(16, 64)
pages = [pool.alloc() for _ in range(4)]
for i, p in enumerate(pages):
_record_then_cache(pool, prefix, p, token_ids, i)
hits = prefix.lookup(token_ids, pool)
pool.record(p, token_ids, i)
pool.free(p)
hits = pool.lookup(token_ids)
assert hits == pages
def test_prefix_cache_lookup_stops_at_first_miss():
token_ids = list(range(256))
pool = PagePool(n_pages=16)
prefix = PrefixCache(page_size=64)
pool = make_pool(16, 64)
p0 = pool.alloc()
_record_then_cache(pool, prefix, p0, token_ids, 0)
pool.record(p0, token_ids, 0)
pool.free(p0)
p1 = pool.alloc()
_record_then_cache(pool, prefix, p1, [99] * 64, 1)
hits = prefix.lookup(token_ids, pool)
pool.record(p1, [99] * 64, 1)
pool.free(p1)
hits = pool.lookup(token_ids)
assert len(hits) == 1
assert hits[0] == p0
def test_prefix_cache_ignores_partial_last_page():
token_ids = list(range(100))
pool = PagePool(n_pages=16)
prefix = PrefixCache(page_size=64)
pool = make_pool(16, 64)
p = pool.alloc()
_record_then_cache(pool, prefix, p, token_ids, 0)
hits = prefix.lookup(token_ids, pool)
pool.record(p, token_ids, 0)
pool.free(p)
hits = pool.lookup(token_ids)
assert len(hits) == 1
def test_prefix_cache_on_evict_clears_mappings():
pool = PagePool(n_pages=4)
prefix = PrefixCache(page_size=64)
pool = make_pool(4, 64)
p = pool.alloc()
_record_then_cache(pool, prefix, p, list(range(64)), 0)
assert prefix.has_page(p)
prefix.on_evict(p)
assert not prefix.has_page(p)
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
pool._prefix.evict(p)
assert p not in pool._prefix._page_to_hash
def test_prefix_cache_has_page():
pool = PagePool(n_pages=4)
prefix = PrefixCache(page_size=64)
pool = make_pool(4, 64)
p = pool.alloc()
assert not prefix.has_page(p)
_record_then_cache(pool, prefix, p, list(range(64)), 0)
assert prefix.has_page(p)
assert p not in pool._prefix._page_to_hash
pool.record(p, list(range(64)), 0)
pool.free(p)
assert p in pool._prefix._page_to_hash
def test_task_table_set_get():
@ -183,8 +160,8 @@ def test_task_table_pop():
assert table.get("task1") == []
def test_paged_cache_task_extend_allocates():
cache = PagedCache(
def test_kv_cache_task_extend_allocates():
cache = KVCache(
n_layers=1,
n_pages=8,
page_size=64,
@ -199,8 +176,8 @@ def test_paged_cache_task_extend_allocates():
assert len(cache._table.get("task1")) == 4
def test_paged_cache_task_extend_fails_when_pool_full():
cache = PagedCache(
def test_kv_cache_task_extend_fails_when_pool_full():
cache = KVCache(
n_layers=1,
n_pages=2,
page_size=64,
@ -230,8 +207,8 @@ def test_task_table_table_tensor_empty_input():
assert t.numel() == 0
def test_paged_cache_write_gather_single_page():
cache = PagedCache(
def test_storage_write_gather_single_page():
storage = Storage(
n_layers=2,
n_pages=8,
page_size=4,
@ -244,13 +221,13 @@ def test_paged_cache_write_gather_single_page():
k = torch.randn(1, 2, 2, 8)
v = torch.randn(1, 2, 2, 8)
cache.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 2)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 2)
assert torch.allclose(gk, k)
def test_paged_cache_write_cross_page():
cache = PagedCache(
def test_storage_write_cross_page():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
@ -263,13 +240,13 @@ def test_paged_cache_write_cross_page():
k = torch.randn(1, 8, 2, 8)
v = torch.randn(1, 8, 2, 8)
cache.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 8)
storage.write(0, page_table, 0, k, v)
gk, gv = storage.gather(0, page_table, 8)
assert torch.allclose(gk, k)
def test_paged_cache_gather_truncates_to_total_len():
cache = PagedCache(
def test_storage_gather_truncates_to_total_len():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
@ -281,14 +258,14 @@ def test_paged_cache_gather_truncates_to_total_len():
page_table = torch.tensor([[0, 1]], dtype=torch.long)
k = torch.randn(1, 6, 2, 8)
v = torch.randn(1, 6, 2, 8)
cache.write(0, page_table, 0, k, v)
storage.write(0, page_table, 0, k, v)
gk, gv = cache.gather(0, page_table, 5)
gk, gv = storage.gather(0, page_table, 5)
assert gk.shape == (1, 5, 2, 8)
def test_paged_cache_gather_clamps_negative_padding():
cache = PagedCache(
def test_storage_gather_clamps_negative_padding():
storage = Storage(
n_layers=1,
n_pages=8,
page_size=4,
@ -298,5 +275,5 @@ def test_paged_cache_gather_clamps_negative_padding():
dtype=torch.float32,
)
page_table = torch.tensor([[0, -1]], dtype=torch.long)
gk, gv = cache.gather(0, page_table, 4)
gk, gv = storage.gather(0, page_table, 4)
assert gk.shape == (1, 4, 2, 8)