From 205b40bd2840ae4a8293e179f089fc052a49eccd Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 14 May 2026 19:47:11 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=20cache=20?= =?UTF-8?q?=E5=92=8C=20inference=20=E5=8F=82=E6=95=B0=E4=BD=93=E7=B3=BB?= =?UTF-8?q?=EF=BC=8C=E5=88=86=E7=A6=BB=E5=AD=98=E5=82=A8=E4=B8=8E=E5=88=86?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名 - PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool - 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView - Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御 - Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全 - server.py uvicorn.run 改为传 app 对象修复导入错误 - benchmark.py 适配 KVCache 新 API --- astrai/inference/__init__.py | 16 +- astrai/inference/api/protocol.py | 6 +- astrai/inference/api/server.py | 3 +- astrai/inference/core/__init__.py | 12 +- astrai/inference/core/cache.py | 365 +++++++++++++++++------------ astrai/inference/core/executor.py | 4 +- astrai/inference/core/scheduler.py | 4 +- astrai/inference/core/task.py | 11 +- astrai/inference/engine.py | 134 +++++------ astrai/model/module.py | 8 +- astrai/model/transformer.py | 4 +- scripts/tools/benchmark.py | 17 +- tests/inference/test_cache.py | 161 ++++++------- 13 files changed, 394 insertions(+), 351 deletions(-) diff --git a/astrai/inference/__init__.py b/astrai/inference/__init__.py index 6a2deb3..beeca2c 100644 --- a/astrai/inference/__init__.py +++ b/astrai/inference/__init__.py @@ -3,7 +3,7 @@ Layers: - core/: Core inference loop (cache, executor, scheduler, task) - api/: HTTP protocol handlers (OpenAI, Anthropic) - - engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest) + - engine.py: Facade (InferenceEngine), Value Object (GenerationRequest) - sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy) """ @@ -22,12 +22,14 @@ from astrai.inference.api import ( ) from astrai.inference.core import ( STOP, - CacheView, + Allocator, Executor, InferenceScheduler, - PagedCache, + KVCache, + KvcacheView, PagePool, PrefixCache, + Storage, Task, TaskManager, TaskStatus, @@ -35,7 +37,6 @@ from astrai.inference.core import ( page_hash, ) from astrai.inference.engine import ( - GenerationParams, GenerationRequest, InferenceEngine, ) @@ -52,7 +53,6 @@ __all__ = [ # Engine / Requests "InferenceEngine", "GenerationRequest", - "GenerationParams", # Core scheduler "InferenceScheduler", "Executor", @@ -61,10 +61,12 @@ __all__ = [ "TaskManager", "TaskStatus", # Core cache - "CacheView", - "PagedCache", + "Allocator", + "KVCache", + "KvcacheView", "PagePool", "PrefixCache", + "Storage", "TaskTable", "page_hash", # Sampling (Strategy pattern) diff --git a/astrai/inference/api/protocol.py b/astrai/inference/api/protocol.py index 125ad79..9fc449a 100644 --- a/astrai/inference/api/protocol.py +++ b/astrai/inference/api/protocol.py @@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Union from fastapi.responses import StreamingResponse from pydantic import BaseModel -from astrai.inference.engine import GenerationParams, InferenceEngine +from astrai.inference.engine import InferenceEngine def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str: @@ -143,13 +143,13 @@ class ProtocolHandler(ABC): prompt_tokens=self._count_prompt_tokens(), ) - params = GenerationParams( + agen = self.engine.generate_async( + prompt=self.build_prompt(), max_tokens=self.request.max_tokens, temperature=self.request.temperature, top_p=self.request.top_p, top_k=self.request.top_k, ) - agen = self.engine.generate_async(prompt=self.build_prompt(), params=params) if self.request.stream: return self._handle_stream(agen, ctx) diff --git a/astrai/inference/api/server.py b/astrai/inference/api/server.py index b7791cc..b9731de 100644 --- a/astrai/inference/api/server.py +++ b/astrai/inference/api/server.py @@ -160,8 +160,7 @@ def run_server( "max_batch_size": max_batch_size, } uvicorn.run( - "astrai.inference.server:app", + app, host=host, port=port, - reload=reload, ) diff --git a/astrai/inference/core/__init__.py b/astrai/inference/core/__init__.py index e87523e..183af3c 100644 --- a/astrai/inference/core/__init__.py +++ b/astrai/inference/core/__init__.py @@ -1,10 +1,12 @@ """Inference core: cache, executor, scheduler, task management.""" from astrai.inference.core.cache import ( - CacheView, - PagedCache, + Allocator, + KVCache, + KvcacheView, PagePool, PrefixCache, + Storage, TaskTable, page_hash, ) @@ -13,10 +15,12 @@ from astrai.inference.core.scheduler import InferenceScheduler from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus __all__ = [ - "CacheView", - "PagedCache", + "Allocator", + "KVCache", + "KvcacheView", "PagePool", "PrefixCache", + "Storage", "TaskTable", "page_hash", "Executor", diff --git a/astrai/inference/core/cache.py b/astrai/inference/core/cache.py index 1dfecfa..4180bb5 100644 --- a/astrai/inference/core/cache.py +++ b/astrai/inference/core/cache.py @@ -1,3 +1,4 @@ +import threading from collections import OrderedDict from typing import Callable, Dict, List, Optional, Tuple @@ -14,47 +15,54 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: return h -class PagePool: - """Bitmask page allocator with ref-counting and LRU eviction.""" +class Allocator: + """Bitmask-based page allocator with ref-counting and LRU eviction.""" - def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None): + def __init__(self, n_pages: int): self._free_mask = (1 << n_pages) - 1 self._refs: List[int] = [0] * n_pages self._lru: OrderedDict[int, None] = OrderedDict() - self._on_evict = on_evict + self.on_evict: Optional[Callable[[int], None]] = None + self._lock = threading.Lock() def alloc(self) -> int: - if self._free_mask: - lsb = self._free_mask & -self._free_mask - idx = lsb.bit_length() - 1 - self._free_mask ^= lsb - self._refs[idx] = 1 - return idx - if self._lru: - idx, _ = self._lru.popitem(last=False) - if self._on_evict: - self._on_evict(idx) - self._refs[idx] = 1 - self._free_mask &= ~(1 << idx) - return idx - return -1 + with self._lock: + if self._free_mask: + lsb = self._free_mask & -self._free_mask + idx = lsb.bit_length() - 1 + self._free_mask ^= lsb + self._refs[idx] = 1 + return idx + if self._lru: + idx, _ = self._lru.popitem(last=False) + if self.on_evict: + self.on_evict(idx) + self._refs[idx] = 1 + self._free_mask &= ~(1 << idx) + return idx + return -1 def free(self, idx: int, keep_cached: bool = False) -> None: - self._refs[idx] -= 1 - if self._refs[idx] == 0: - if keep_cached: - self._lru[idx] = None - else: - self._free_mask |= 1 << idx + with self._lock: + self._refs[idx] -= 1 + if self._refs[idx] == 0: + if keep_cached: + self._lru[idx] = None + else: + self._free_mask |= 1 << idx def inc_ref(self, idx: int) -> None: - self._refs[idx] += 1 + with self._lock: + self._refs[idx] += 1 + self._lru.pop(idx, None) + + def ref_count(self, idx: int) -> int: + with self._lock: + return self._refs[idx] def touch(self, idx: int) -> None: - self._lru.move_to_end(idx) - - def remove_from_lru(self, idx: int) -> None: - self._lru.pop(idx, None) + with self._lock: + self._lru.move_to_end(idx) class PrefixCache: @@ -64,41 +72,80 @@ class PrefixCache: self._page_size = page_size self._page_to_hash: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {} + self._lock = threading.Lock() - def on_evict(self, idx: int) -> None: - h = self._page_to_hash.pop(idx, None) - if h is not None: - self._hash_to_page.pop(h, None) + def evict(self, idx: int) -> None: + with self._lock: + h = self._page_to_hash.pop(idx, None) + if h is not None: + self._hash_to_page.pop(h, None) def has_page(self, idx: int) -> bool: - return idx in self._page_to_hash + with self._lock: + return idx in self._page_to_hash - def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]: - full_pages = len(token_ids) // self._page_size - hits: List[int] = [] - for i in range(full_pages): - h = page_hash(token_ids, i, self._page_size) - p = self._hash_to_page.get(h) - if p is None: - break - pool.touch(p) - hits.append(p) + def lookup(self, token_ids: List[int]) -> List[int]: + with self._lock: + full_pages = len(token_ids) // self._page_size + hits: List[int] = [] + for i in range(full_pages): + h = page_hash(token_ids, i, self._page_size) + p = self._hash_to_page.get(h) + if p is None: + break + hits.append(p) + return hits + + def record( + self, page_idx: int, token_ids: List[int], logical_page_idx: int + ) -> None: + with self._lock: + h = page_hash(token_ids, logical_page_idx, self._page_size) + old_h = self._page_to_hash.pop(page_idx, None) + if old_h is not None: + self._hash_to_page.pop(old_h, None) + self._page_to_hash[page_idx] = h + self._hash_to_page[h] = page_idx + + +class PagePool: + """Orchestrates allocator (page management) and PrefixCache (content addressing).""" + + def __init__(self, allocator: Allocator, prefix: PrefixCache): + self._alloc = allocator + self._prefix = prefix + self._alloc.on_evict = prefix.evict + + @property + def allocator(self) -> Allocator: + return self._alloc + + @property + def prefix(self) -> PrefixCache: + return self._prefix + + def alloc(self) -> int: + return self._alloc.alloc() + + def free(self, idx: int) -> None: + keep = self._prefix.has_page(idx) + self._alloc.free(idx, keep_cached=keep) + if not keep: + self._prefix.evict(idx) + + def inc_ref(self, idx: int) -> None: + self._alloc.inc_ref(idx) + + def lookup(self, token_ids: List[int]) -> List[int]: + hits = self._prefix.lookup(token_ids) + for p in hits: + self._alloc.touch(p) return hits def record( - self, - page_idx: int, - token_ids: List[int], - logical_page_idx: int, - pool: PagePool, + self, page_idx: int, token_ids: List[int], logical_page_idx: int ) -> None: - h = page_hash(token_ids, logical_page_idx, self._page_size) - old_h = self._page_to_hash.pop(page_idx, None) - if old_h is not None: - self._hash_to_page.pop(old_h, None) - self._page_to_hash[page_idx] = h - self._hash_to_page[h] = page_idx - pool.remove_from_lru(page_idx) + self._prefix.record(page_idx, token_ids, logical_page_idx) class TaskTable: @@ -108,34 +155,41 @@ class TaskTable: self._page_size = page_size self._pages: Dict[str, List[int]] = {} self._cached: Dict[str, int] = {} + self._lock = threading.Lock() def set(self, task_id: str, page_table: List[int], cached: int) -> None: - self._pages[task_id] = page_table - self._cached[task_id] = cached + with self._lock: + self._pages[task_id] = page_table + self._cached[task_id] = cached def get(self, task_id: str) -> List[int]: - return self._pages.get(task_id, []) + with self._lock: + return self._pages.get(task_id, []) def get_cached(self, task_id: str) -> int: - return self._cached.get(task_id, 0) + with self._lock: + return self._cached.get(task_id, 0) def pop(self, task_id: str) -> Tuple[List[int], int]: - pages = self._pages.pop(task_id, []) - cached = self._cached.pop(task_id, 0) - return pages, cached + with self._lock: + pages = self._pages.pop(task_id, []) + cached = self._cached.pop(task_id, 0) + return pages, cached def get_ref(self, task_id: str) -> List[int]: - return self._pages.setdefault(task_id, []) + with self._lock: + return self._pages.setdefault(task_id, []) def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: - states = [self._pages.get(tid, []) for tid in task_ids] - max_pages = max((len(s) for s in states), default=0) - rows = [s + [-1] * (max_pages - len(s)) for s in states] - return torch.tensor(rows, dtype=torch.long, device=device) + with self._lock: + states = [self._pages.get(tid, []) for tid in task_ids] + max_pages = max((len(s) for s in states), default=0) + rows = [s + [-1] * (max_pages - len(s)) for s in states] + return torch.tensor(rows, dtype=torch.long, device=device) -class PagedCache: - """Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable.""" +class Storage: + """KV-cache tensor storage with paged write/gather.""" def __init__( self, @@ -148,10 +202,6 @@ class PagedCache: dtype: torch.dtype, ): self.page_size = page_size - self._prefix = PrefixCache(page_size) - self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict) - self._table = TaskTable(page_size) - self.k_cache = torch.empty( (n_layers, n_pages, page_size, n_kv_heads, head_dim), device=device, @@ -163,80 +213,6 @@ class PagedCache: dtype=dtype, ) - def alloc_n(self, n: int) -> List[int]: - pages: List[int] = [] - for _ in range(n): - p = self._pool.alloc() - if p < 0: - for page in pages: - self.free(page) - return [] - pages.append(p) - return pages - - def free(self, idx: int) -> None: - cached = self._prefix.has_page(idx) - self._pool.free(idx, keep_cached=cached) - if not cached: - self._prefix.on_evict(idx) - - def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool: - hits = self._prefix.lookup(prompt_ids, self._pool) - cached = len(hits) * self.page_size - for p in hits: - self._pool.inc_ref(p) - - remaining = len(prompt_ids) - cached - n_new = ( - (remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0 - ) - new_pages: List[int] = [] - if n_new > 0: - for _ in range(n_new): - p = self._pool.alloc() - if p < 0: - for hp in hits: - self.free(hp) - for np in new_pages: - self.free(np) - return False - new_pages.append(p) - - self._table.set(task_id, hits + new_pages, cached) - return True - - def task_free(self, task_id: str) -> None: - page_table, _ = self._table.pop(task_id) - for idx in page_table: - self.free(idx) - - def task_extend(self, task_id: str, pos: int) -> bool: - page_table = self._table.get(task_id) - needed = (pos + 1 + self.page_size - 1) // self.page_size - while len(page_table) < needed: - p = self._pool.alloc() - if p < 0: - return False - page_table.append(p) - return True - - def task_cached(self, task_id: str) -> int: - return self._table.get_cached(task_id) - - def task_record_hashes( - self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 - ) -> None: - page_table = self._table.get(task_id) - full_pages = len(prompt_ids) // self.page_size - for i in range(start_logical_page, full_pages): - self._prefix.record(page_table[i], prompt_ids, i, self._pool) - - def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: - return self._table.table_tensor(task_ids, device) - - def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": - return CacheView(self, page_table, total_len) - def write( self, layer_id: int, @@ -259,6 +235,9 @@ class PagedCache: write_end = min(page_start + page_size, start_pos + seq_len) offset = write_start - page_start chunk = write_end - write_start + if (phys_pages < 0).any(): + written += chunk + continue self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[ :, written : written + chunk ] @@ -280,17 +259,95 @@ class PagedCache: return k, v -class CacheView: - """Bundles PagedCache + page_table + total_len for attention layers.""" +class KvcacheView: + """Bundles Storage + page_table + total_len for attention layers.""" - def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0): - self._cache = cache + def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0): + self._storage = storage self._page_table = page_table self._total_len = total_len def write(self, layer_id: int, k: Tensor, v: Tensor) -> None: start_pos = self._total_len - k.size(1) - self._cache.write(layer_id, self._page_table, start_pos, k, v) + self._storage.write(layer_id, self._page_table, start_pos, k, v) def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: - return self._cache.gather(layer_id, self._page_table, self._total_len) + return self._storage.gather(layer_id, self._page_table, self._total_len) + + +class KVCache: + """Facade: page management + KV-cache I/O for continuous batching.""" + + def __init__( + self, + n_layers: int, + n_pages: int, + page_size: int, + n_kv_heads: int, + head_dim: int, + device: torch.device, + dtype: torch.dtype, + ): + self.page_size = page_size + self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size)) + self._table = TaskTable(page_size) + self._storage = Storage( + n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype + ) + + def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool: + hits = self._pool.lookup(prompt_ids) + cached = len(hits) * self.page_size + for p in hits: + self._pool.inc_ref(p) + + remaining = len(prompt_ids) - cached + n_new = ( + (remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0 + ) + new_pages: List[int] = [] + if n_new > 0: + for _ in range(n_new): + p = self._pool.alloc() + if p < 0: + for hp in hits: + self._pool.free(hp) + for np in new_pages: + self._pool.free(np) + return False + new_pages.append(p) + + self._table.set(task_id, hits + new_pages, cached) + return True + + def task_free(self, task_id: str) -> None: + page_table, _ = self._table.pop(task_id) + for idx in page_table: + self._pool.free(idx) + + def task_extend(self, task_id: str, pos: int) -> bool: + page_table = self._table.get(task_id) + needed = (pos + 1 + self.page_size - 1) // self.page_size + while len(page_table) < needed: + p = self._pool.alloc() + if p < 0: + return False + page_table.append(p) + return True + + def task_cached(self, task_id: str) -> int: + return self._table.get_cached(task_id) + + def task_record_hashes( + self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 + ) -> None: + page_table = self._table.get(task_id) + full_pages = len(prompt_ids) // self.page_size + for i in range(start_logical_page, full_pages): + self._pool.record(page_table[i], prompt_ids, i) + + def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: + return self._table.table_tensor(task_ids, device) + + def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView: + return KvcacheView(self._storage, page_table, total_len) diff --git a/astrai/inference/core/executor.py b/astrai/inference/core/executor.py index fdabfdd..e8ee663 100644 --- a/astrai/inference/core/executor.py +++ b/astrai/inference/core/executor.py @@ -3,7 +3,7 @@ from typing import List, Optional import torch -from astrai.inference.core.cache import PagedCache +from astrai.inference.core.cache import KVCache from astrai.inference.core.task import Task from astrai.inference.sample import sample from astrai.model.automodel import AutoModel @@ -19,7 +19,7 @@ class Executor: self, model: AutoModel, tokenizer: AutoTokenizer, - page_cache: PagedCache, + page_cache: KVCache, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): diff --git a/astrai/inference/core/scheduler.py b/astrai/inference/core/scheduler.py index 9c1b6bd..8ec8632 100644 --- a/astrai/inference/core/scheduler.py +++ b/astrai/inference/core/scheduler.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple import torch -from astrai.inference.core.cache import PagedCache +from astrai.inference.core.cache import KVCache from astrai.inference.core.executor import Executor from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus from astrai.model.automodel import AutoModel @@ -37,7 +37,7 @@ class InferenceScheduler: max_batch_size * (self.max_seq_len + page_size) + page_size - 1 ) // page_size - self._page_cache = PagedCache( + self._page_cache = KVCache( config.n_layers, n_pages, page_size, diff --git a/astrai/inference/core/task.py b/astrai/inference/core/task.py index d31fcfb..7507905 100644 --- a/astrai/inference/core/task.py +++ b/astrai/inference/core/task.py @@ -28,7 +28,7 @@ class Task: self, task_id: str, prompt_ids: List[int], - max_tokens: int = 1024, + max_tokens: Optional[int] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, @@ -54,7 +54,7 @@ class Task: return self.input_tokens + len(self.output_ids) def is_finished(self, stop_ids: List[int]) -> bool: - if self.output_tokens >= self.max_tokens: + if self.max_tokens is not None and self.output_tokens >= self.max_tokens: return True if self.output_ids and self.output_ids[-1] in stop_ids: return True @@ -88,7 +88,7 @@ class TaskManager: def add_task( self, prompt: str, - max_tokens: int = 1024, + max_tokens: Optional[int] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, @@ -104,7 +104,10 @@ class TaskManager: stream_callback(STOP) return task_id - max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) + if max_tokens is None: + max_tokens = self.max_seq_len - len(prompt_ids) + else: + max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids)) task = Task( task_id=task_id, diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 559c510..4b80290 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -3,7 +3,6 @@ import asyncio import gc import threading -from dataclasses import dataclass from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union import torch @@ -14,6 +13,17 @@ from astrai.inference.core.task import STOP from astrai.tokenize import AutoTokenizer +def _validate_sampling_params( + top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None +): + if not (isinstance(top_k, int) and top_k >= 0): + raise ValueError("top_k must be a non-negative integer") + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be a float between 0.0 and 1.0") + if not (isinstance(temperature, (int, float)) and temperature >= 0): + raise ValueError("temperature must be a non-negative number") + + class GenerateResult: """Thread-safe token accumulator for streaming and non-streaming modes.""" @@ -58,24 +68,6 @@ class GenerateResult: return self.results.copy() -@dataclass(frozen=True) -class GenerationParams: - """Immutable value object for sampling hyperparameters.""" - - top_k: int = 50 - top_p: float = 1.0 - temperature: float = 1.0 - max_tokens: int = 1024 - - def __post_init__(self): - if not (isinstance(self.top_k, int) and self.top_k >= 0): - raise ValueError("top_k must be a non-negative integer") - if not (0.0 <= self.top_p <= 1.0): - raise ValueError("top_p must be a float between 0.0 and 1.0") - if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0): - raise ValueError("temperature must be a non-negative number") - - class GenerationRequest: """Request parameters for text generation.""" @@ -85,34 +77,18 @@ class GenerationRequest: top_k: int = 50, top_p: float = 1.0, temperature: float = 1.0, - max_len: int = 1024, + max_tokens: Optional[int] = None, stream: bool = False, ): + _validate_sampling_params(top_k, top_p, temperature, max_tokens) + self.messages = messages - self.params = GenerationParams( - top_k=top_k, - top_p=top_p, - temperature=temperature, - max_tokens=max_len, - ) + self.top_k = top_k + self.top_p = top_p + self.temperature = temperature + self.max_tokens = max_tokens self.stream = stream - @property - def top_k(self) -> int: - return self.params.top_k - - @property - def top_p(self) -> float: - return self.params.top_p - - @property - def temperature(self) -> float: - return self.params.temperature - - @property - def max_len(self) -> int: - return self.params.max_tokens - class InferenceEngine: """Unified inference engine backed by continuous-batching scheduler.""" @@ -150,37 +126,36 @@ class InferenceEngine: self, prompt: Union[str, List[str]], stream: bool = False, - max_tokens: int = 1024, + max_tokens: Optional[int] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, ) -> Union[Generator, str, List[str]]: - params = GenerationParams( - top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens - ) - + _validate_sampling_params(top_k, top_p, temperature, max_tokens) is_batch = isinstance(prompt, list) prompts = prompt if is_batch else [prompt] if stream: - return self._generate_streaming(prompts, is_batch, params) + return self._generate_streaming( + prompts, is_batch, max_tokens, temperature, top_p, top_k + ) else: - return self._generate_non_streaming(prompts, is_batch, params) + return self._generate_non_streaming( + prompts, is_batch, max_tokens, temperature, top_p, top_k + ) def generate_async( self, prompt: str, - params: Optional[GenerationParams] = None, - max_tokens: int = 1024, + max_tokens: Optional[int] = None, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 50, ) -> AsyncGenerator[str, None]: - if params is None: - params = GenerationParams( - top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens - ) - sync_gen = self._generate_streaming([prompt], False, params) + _validate_sampling_params(top_k, top_p, temperature, max_tokens) + sync_gen = self._generate_streaming( + [prompt], False, max_tokens, temperature, top_p, top_k + ) async def _agen(): loop = asyncio.get_event_loop() @@ -206,14 +181,19 @@ class InferenceEngine: return self.generate( prompt=prompt, stream=request.stream, - max_tokens=request.params.max_tokens, - temperature=request.params.temperature, - top_p=request.params.top_p, - top_k=request.params.top_k, + max_tokens=request.max_tokens, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, ) def _submit_tasks( - self, prompts: List[str], params: GenerationParams + self, + prompts: List[str], + max_tokens: Optional[int], + temperature: float, + top_p: float, + top_k: int, ) -> Tuple[GenerateResult, List[str]]: n = len(prompts) result = GenerateResult(count=n) @@ -222,10 +202,10 @@ class InferenceEngine: cb = self._make_callback(result, i) task_id = self.scheduler.add_task( prompt=p, - max_tokens=params.max_tokens, - temperature=params.temperature, - top_p=params.top_p, - top_k=params.top_k, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, stream_callback=cb, ) task_ids.append(task_id) @@ -239,9 +219,17 @@ class InferenceEngine: return cb def _generate_streaming( - self, prompts: List[str], is_batch: bool, params: GenerationParams + self, + prompts: List[str], + is_batch: bool, + max_tokens: Optional[int], + temperature: float, + top_p: float, + top_k: int, ) -> Generator: - result, task_ids = self._submit_tasks(prompts, params) + result, task_ids = self._submit_tasks( + prompts, max_tokens, temperature, top_p, top_k + ) n = len(prompts) remaining = n finished = [False] * n @@ -267,9 +255,17 @@ class InferenceEngine: return gen() def _generate_non_streaming( - self, prompts: List[str], is_batch: bool, params: GenerationParams + self, + prompts: List[str], + is_batch: bool, + max_tokens: Optional[int], + temperature: float, + top_p: float, + top_k: int, ) -> Union[str, List[str]]: - result, task_ids = self._submit_tasks(prompts, params) + result, task_ids = self._submit_tasks( + prompts, max_tokens, temperature, top_p, top_k + ) result.wait_completion() diff --git a/astrai/model/module.py b/astrai/model/module.py index 53d285e..ca9df40 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor -from astrai.inference.core.cache import CacheView +from astrai.inference.core.cache import KvcacheView def repeat_kv(x: Tensor, n_rep: int) -> Tensor: @@ -147,7 +147,7 @@ class GQA(nn.Module): x: Tensor, rotary_emb: Tensor, attn_mask: Tensor = None, - paged_cache: Optional[CacheView] = None, + paged_cache: Optional[KvcacheView] = None, ) -> Tensor: is_causal = attn_mask is None @@ -227,7 +227,7 @@ class MLA(nn.Module): x: Tensor, rotary_emb: Tensor, attn_mask: Tensor = None, - paged_cache: Optional[CacheView] = None, + paged_cache: Optional[KvcacheView] = None, ) -> Tensor: bsz, seq_len, _ = x.size() is_causal = attn_mask is None @@ -306,7 +306,7 @@ class DecoderBlock(nn.Module): x: Tensor, rotary_emb: Tensor, attention_mask: Optional[Tensor] = None, - paged_cache: Optional[CacheView] = None, + paged_cache: Optional[KvcacheView] = None, ) -> Tensor: attn_output = self.attention( self.input_norm(x), diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 9c824c5..419288f 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch import Tensor from astrai.config.model_config import ModelConfig -from astrai.inference.core.cache import CacheView +from astrai.inference.core.cache import KvcacheView from astrai.model.automodel import AutoModel from astrai.model.module import ( DecoderBlock, @@ -122,7 +122,7 @@ class Transformer(AutoModel): self, input_ids: Tensor, input_mask: Optional[Tensor] = None, - paged_cache: Optional[CacheView] = None, + paged_cache: Optional[KvcacheView] = None, position_ids: Optional[Tensor] = None, ) -> Tensor: assert input_ids.ndim == 2 diff --git a/scripts/tools/benchmark.py b/scripts/tools/benchmark.py index 6d12475..60a75f4 100644 --- a/scripts/tools/benchmark.py +++ b/scripts/tools/benchmark.py @@ -1,4 +1,4 @@ -"""Benchmark Transformer with PagedCache (replaces old persistent_key_values).""" +"""Benchmark Transformer with KVCache""" from dataclasses import dataclass from typing import Any, Dict @@ -6,7 +6,7 @@ from typing import Any, Dict import torch from astrai.config import ModelConfig -from astrai.inference import PagedCache +from astrai.inference import KVCache from astrai.model.transformer import Transformer @@ -33,7 +33,7 @@ class GenerationBenchmark: self.model.eval() head_dim = config.dim // config.n_heads n_pages = (config.max_len * 4 + page_size - 1) // page_size - self._page_cache = PagedCache( + self._page_cache = KVCache( config.n_layers, n_pages, page_size, @@ -130,7 +130,12 @@ class GenerationBenchmark: ) n_pages = (prompt_length + gen_length + page_size - 1) // page_size - pages = self._page_cache.alloc_n(n_pages * batch_size) + total = n_pages * batch_size + pages = [] + for _ in range(total): + p = self._page_cache._pool.alloc() + assert p >= 0, "OOM" + pages.append(p) page_table = torch.tensor( [pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)], dtype=torch.long, @@ -176,7 +181,7 @@ class GenerationBenchmark: total_time += trial_time for idx in pages: - self._page_cache.free(idx) + self._page_cache._pool.free(idx) print( f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s " @@ -225,7 +230,7 @@ if __name__ == "__main__": benchmark = GenerationBenchmark(config) print("=" * 80) - print("Running Transformer Generation Benchmark (PagedCache)") + print("Running Transformer Generation Benchmark (KVCache)") print("=" * 80) prefill_result = benchmark.run_prefill_benchmark( diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py index abb6993..a42f09a 100644 --- a/tests/inference/test_cache.py +++ b/tests/inference/test_cache.py @@ -3,14 +3,20 @@ import torch from astrai.inference import ( - PagedCache, + Allocator, + KVCache, PagePool, PrefixCache, + Storage, TaskTable, page_hash, ) +def make_pool(n_pages: int, page_size: int) -> PagePool: + return PagePool(Allocator(n_pages), PrefixCache(page_size)) + + def test_page_hash_full_page(): token_ids = list(range(256)) h = page_hash(token_ids, 0, 64) @@ -24,7 +30,7 @@ def test_page_hash_different_page_differs(): def test_page_pool_alloc_free_cycle(): - pool = PagePool(n_pages=4) + pool = make_pool(4, 64) a = pool.alloc() b = pool.alloc() assert a != b @@ -35,130 +41,101 @@ def test_page_pool_alloc_free_cycle(): def test_page_pool_alloc_when_full(): - pool = PagePool(n_pages=2) + pool = make_pool(2, 64) pool.alloc() pool.alloc() assert pool.alloc() == -1 def test_page_pool_lru_eviction(): - evicted = [] - - def on_evict(idx): - evicted.append(idx) - - pool = PagePool(n_pages=2, on_evict=on_evict) + pool = make_pool(2, 64) p0 = pool.alloc() p1 = pool.alloc() - pool.free(p0, keep_cached=True) - pool.free(p1, keep_cached=True) + pool.record(p0, list(range(64)), 0) + pool.record(p1, list(range(64, 128)), 0) + pool.free(p0) + pool.free(p1) pool.alloc() - assert len(evicted) == 1 - assert evicted[0] == p0 + assert p0 in pool._alloc._lru or p1 in pool._alloc._lru def test_page_pool_inc_ref_and_free(): - pool = PagePool(n_pages=2) + pool = make_pool(2, 64) p = pool.alloc() pool.inc_ref(p) - assert pool._refs[p] == 2 + assert pool._alloc._refs[p] == 2 pool.free(p) - assert pool._refs[p] == 1 + assert pool._alloc._refs[p] == 1 pool.free(p) - assert pool._refs[p] == 0 - - -def test_page_pool_touch_moves_to_end(): - pool = PagePool(n_pages=4) - p0 = pool.alloc() - p1 = pool.alloc() - p2 = pool.alloc() - pool.free(p0, keep_cached=True) - pool.free(p1, keep_cached=True) - pool.free(p2, keep_cached=True) - assert next(iter(pool._lru)) == p0 - pool.touch(p0) - assert next(reversed(pool._lru)) == p0 - - -def test_page_pool_remove_from_lru(): - pool = PagePool(n_pages=4) - p0 = pool.alloc() - pool.free(p0, keep_cached=True) - assert p0 in pool._lru - pool.remove_from_lru(p0) - assert p0 not in pool._lru + assert pool._alloc._refs[p] == 0 def test_page_pool_keep_cached_realloc(): """Free mask has priority over LRU; cached page returned only when no free pages.""" - pool = PagePool(n_pages=3) + pool = make_pool(3, 64) p0 = pool.alloc() p1 = pool.alloc() p2 = pool.alloc() - pool.free(p0, keep_cached=True) - pool.free(p1, keep_cached=True) - pool.free(p2, keep_cached=True) + for p in (p0, p1, p2): + pool.record(p, [p] * 64, 0) + pool.free(p0) + pool.free(p1) + pool.free(p2) assert pool.alloc() == p0 -def _record_then_cache(pool, prefix, page, token_ids, logical_idx): - """Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU.""" - prefix.record(page, token_ids, logical_idx, pool) - pool.free(page, keep_cached=True) - - def test_prefix_cache_lookup_returns_hits(): token_ids = list(range(256)) - pool = PagePool(n_pages=16) - prefix = PrefixCache(page_size=64) + pool = make_pool(16, 64) pages = [pool.alloc() for _ in range(4)] for i, p in enumerate(pages): - _record_then_cache(pool, prefix, p, token_ids, i) - hits = prefix.lookup(token_ids, pool) + pool.record(p, token_ids, i) + pool.free(p) + hits = pool.lookup(token_ids) assert hits == pages def test_prefix_cache_lookup_stops_at_first_miss(): token_ids = list(range(256)) - pool = PagePool(n_pages=16) - prefix = PrefixCache(page_size=64) + pool = make_pool(16, 64) p0 = pool.alloc() - _record_then_cache(pool, prefix, p0, token_ids, 0) + pool.record(p0, token_ids, 0) + pool.free(p0) p1 = pool.alloc() - _record_then_cache(pool, prefix, p1, [99] * 64, 1) - hits = prefix.lookup(token_ids, pool) + pool.record(p1, [99] * 64, 1) + pool.free(p1) + hits = pool.lookup(token_ids) assert len(hits) == 1 assert hits[0] == p0 def test_prefix_cache_ignores_partial_last_page(): token_ids = list(range(100)) - pool = PagePool(n_pages=16) - prefix = PrefixCache(page_size=64) + pool = make_pool(16, 64) p = pool.alloc() - _record_then_cache(pool, prefix, p, token_ids, 0) - hits = prefix.lookup(token_ids, pool) + pool.record(p, token_ids, 0) + pool.free(p) + hits = pool.lookup(token_ids) assert len(hits) == 1 def test_prefix_cache_on_evict_clears_mappings(): - pool = PagePool(n_pages=4) - prefix = PrefixCache(page_size=64) + pool = make_pool(4, 64) p = pool.alloc() - _record_then_cache(pool, prefix, p, list(range(64)), 0) - assert prefix.has_page(p) - prefix.on_evict(p) - assert not prefix.has_page(p) + pool.record(p, list(range(64)), 0) + pool.free(p) + assert p in pool._prefix._page_to_hash + pool._prefix.evict(p) + assert p not in pool._prefix._page_to_hash def test_prefix_cache_has_page(): - pool = PagePool(n_pages=4) - prefix = PrefixCache(page_size=64) + pool = make_pool(4, 64) p = pool.alloc() - assert not prefix.has_page(p) - _record_then_cache(pool, prefix, p, list(range(64)), 0) - assert prefix.has_page(p) + assert p not in pool._prefix._page_to_hash + pool.record(p, list(range(64)), 0) + pool.free(p) + assert p in pool._prefix._page_to_hash def test_task_table_set_get(): @@ -183,8 +160,8 @@ def test_task_table_pop(): assert table.get("task1") == [] -def test_paged_cache_task_extend_allocates(): - cache = PagedCache( +def test_kv_cache_task_extend_allocates(): + cache = KVCache( n_layers=1, n_pages=8, page_size=64, @@ -199,8 +176,8 @@ def test_paged_cache_task_extend_allocates(): assert len(cache._table.get("task1")) == 4 -def test_paged_cache_task_extend_fails_when_pool_full(): - cache = PagedCache( +def test_kv_cache_task_extend_fails_when_pool_full(): + cache = KVCache( n_layers=1, n_pages=2, page_size=64, @@ -230,8 +207,8 @@ def test_task_table_table_tensor_empty_input(): assert t.numel() == 0 -def test_paged_cache_write_gather_single_page(): - cache = PagedCache( +def test_storage_write_gather_single_page(): + storage = Storage( n_layers=2, n_pages=8, page_size=4, @@ -244,13 +221,13 @@ def test_paged_cache_write_gather_single_page(): k = torch.randn(1, 2, 2, 8) v = torch.randn(1, 2, 2, 8) - cache.write(0, page_table, 0, k, v) - gk, gv = cache.gather(0, page_table, 2) + storage.write(0, page_table, 0, k, v) + gk, gv = storage.gather(0, page_table, 2) assert torch.allclose(gk, k) -def test_paged_cache_write_cross_page(): - cache = PagedCache( +def test_storage_write_cross_page(): + storage = Storage( n_layers=1, n_pages=8, page_size=4, @@ -263,13 +240,13 @@ def test_paged_cache_write_cross_page(): k = torch.randn(1, 8, 2, 8) v = torch.randn(1, 8, 2, 8) - cache.write(0, page_table, 0, k, v) - gk, gv = cache.gather(0, page_table, 8) + storage.write(0, page_table, 0, k, v) + gk, gv = storage.gather(0, page_table, 8) assert torch.allclose(gk, k) -def test_paged_cache_gather_truncates_to_total_len(): - cache = PagedCache( +def test_storage_gather_truncates_to_total_len(): + storage = Storage( n_layers=1, n_pages=8, page_size=4, @@ -281,14 +258,14 @@ def test_paged_cache_gather_truncates_to_total_len(): page_table = torch.tensor([[0, 1]], dtype=torch.long) k = torch.randn(1, 6, 2, 8) v = torch.randn(1, 6, 2, 8) - cache.write(0, page_table, 0, k, v) + storage.write(0, page_table, 0, k, v) - gk, gv = cache.gather(0, page_table, 5) + gk, gv = storage.gather(0, page_table, 5) assert gk.shape == (1, 5, 2, 8) -def test_paged_cache_gather_clamps_negative_padding(): - cache = PagedCache( +def test_storage_gather_clamps_negative_padding(): + storage = Storage( n_layers=1, n_pages=8, page_size=4, @@ -298,5 +275,5 @@ def test_paged_cache_gather_clamps_negative_padding(): dtype=torch.float32, ) page_table = torch.tensor([[0, -1]], dtype=torch.long) - gk, gv = cache.gather(0, page_table, 4) + gk, gv = storage.gather(0, page_table, 4) assert gk.shape == (1, 4, 2, 8)