refactor: 重构 cache 和 inference 参数体系,分离存储与分配
- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名 - PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool - 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView - Allocator.inc_ref 移除 LRU 防止竞争,Storage.write 增加负页防御 - Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全 - server.py uvicorn.run 改为传 app 对象修复导入错误 - benchmark.py 适配 KVCache 新 API
This commit is contained in:
parent
18fe6e9339
commit
205b40bd28
|
|
@ -3,7 +3,7 @@
|
|||
Layers:
|
||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||
"""
|
||||
|
||||
|
|
@ -22,12 +22,14 @@ from astrai.inference.api import (
|
|||
)
|
||||
from astrai.inference.core import (
|
||||
STOP,
|
||||
CacheView,
|
||||
Allocator,
|
||||
Executor,
|
||||
InferenceScheduler,
|
||||
PagedCache,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
Task,
|
||||
TaskManager,
|
||||
TaskStatus,
|
||||
|
|
@ -35,7 +37,6 @@ from astrai.inference.core import (
|
|||
page_hash,
|
||||
)
|
||||
from astrai.inference.engine import (
|
||||
GenerationParams,
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
|
|
@ -52,7 +53,6 @@ __all__ = [
|
|||
# Engine / Requests
|
||||
"InferenceEngine",
|
||||
"GenerationRequest",
|
||||
"GenerationParams",
|
||||
# Core scheduler
|
||||
"InferenceScheduler",
|
||||
"Executor",
|
||||
|
|
@ -61,10 +61,12 @@ __all__ = [
|
|||
"TaskManager",
|
||||
"TaskStatus",
|
||||
# Core cache
|
||||
"CacheView",
|
||||
"PagedCache",
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
# Sampling (Strategy pattern)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from astrai.inference.engine import GenerationParams, InferenceEngine
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
|
||||
|
||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||
|
|
@ -143,13 +143,13 @@ class ProtocolHandler(ABC):
|
|||
prompt_tokens=self._count_prompt_tokens(),
|
||||
)
|
||||
|
||||
params = GenerationParams(
|
||||
agen = self.engine.generate_async(
|
||||
prompt=self.build_prompt(),
|
||||
max_tokens=self.request.max_tokens,
|
||||
temperature=self.request.temperature,
|
||||
top_p=self.request.top_p,
|
||||
top_k=self.request.top_k,
|
||||
)
|
||||
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
|
||||
|
||||
if self.request.stream:
|
||||
return self._handle_stream(agen, ctx)
|
||||
|
|
|
|||
|
|
@ -160,8 +160,7 @@ def run_server(
|
|||
"max_batch_size": max_batch_size,
|
||||
}
|
||||
uvicorn.run(
|
||||
"astrai.inference.server:app",
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
reload=reload,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""Inference core: cache, executor, scheduler, task management."""
|
||||
|
||||
from astrai.inference.core.cache import (
|
||||
CacheView,
|
||||
PagedCache,
|
||||
Allocator,
|
||||
KVCache,
|
||||
KvcacheView,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
|
|
@ -13,10 +15,12 @@ from astrai.inference.core.scheduler import InferenceScheduler
|
|||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
|
||||
__all__ = [
|
||||
"CacheView",
|
||||
"PagedCache",
|
||||
"Allocator",
|
||||
"KVCache",
|
||||
"KvcacheView",
|
||||
"PagePool",
|
||||
"PrefixCache",
|
||||
"Storage",
|
||||
"TaskTable",
|
||||
"page_hash",
|
||||
"Executor",
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
|
|
@ -14,16 +15,18 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
|||
return h
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
||||
class Allocator:
|
||||
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
|
||||
|
||||
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
||||
def __init__(self, n_pages: int):
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||
self._on_evict = on_evict
|
||||
self.on_evict: Optional[Callable[[int], None]] = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def alloc(self) -> int:
|
||||
with self._lock:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
|
|
@ -32,14 +35,15 @@ class PagePool:
|
|||
return idx
|
||||
if self._lru:
|
||||
idx, _ = self._lru.popitem(last=False)
|
||||
if self._on_evict:
|
||||
self._on_evict(idx)
|
||||
if self.on_evict:
|
||||
self.on_evict(idx)
|
||||
self._refs[idx] = 1
|
||||
self._free_mask &= ~(1 << idx)
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
if keep_cached:
|
||||
|
|
@ -48,14 +52,18 @@ class PagePool:
|
|||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._refs[idx] += 1
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
def ref_count(self, idx: int) -> int:
|
||||
with self._lock:
|
||||
return self._refs[idx]
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
def remove_from_lru(self, idx: int) -> None:
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
|
||||
class PrefixCache:
|
||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||
|
|
@ -64,16 +72,20 @@ class PrefixCache:
|
|||
self._page_size = page_size
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def on_evict(self, idx: int) -> None:
|
||||
def evict(self, idx: int) -> None:
|
||||
with self._lock:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
|
||||
def has_page(self, idx: int) -> bool:
|
||||
with self._lock:
|
||||
return idx in self._page_to_hash
|
||||
|
||||
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
with self._lock:
|
||||
full_pages = len(token_ids) // self._page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
|
|
@ -81,24 +93,59 @@ class PrefixCache:
|
|||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
pool.touch(p)
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self,
|
||||
page_idx: int,
|
||||
token_ids: List[int],
|
||||
logical_page_idx: int,
|
||||
pool: PagePool,
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
with self._lock:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
pool.remove_from_lru(page_idx)
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
|
||||
|
||||
def __init__(self, allocator: Allocator, prefix: PrefixCache):
|
||||
self._alloc = allocator
|
||||
self._prefix = prefix
|
||||
self._alloc.on_evict = prefix.evict
|
||||
|
||||
@property
|
||||
def allocator(self) -> Allocator:
|
||||
return self._alloc
|
||||
|
||||
@property
|
||||
def prefix(self) -> PrefixCache:
|
||||
return self._prefix
|
||||
|
||||
def alloc(self) -> int:
|
||||
return self._alloc.alloc()
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
keep = self._prefix.has_page(idx)
|
||||
self._alloc.free(idx, keep_cached=keep)
|
||||
if not keep:
|
||||
self._prefix.evict(idx)
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._alloc.inc_ref(idx)
|
||||
|
||||
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||
hits = self._prefix.lookup(token_ids)
|
||||
for p in hits:
|
||||
self._alloc.touch(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||
|
||||
|
||||
class TaskTable:
|
||||
|
|
@ -108,34 +155,41 @@ class TaskTable:
|
|||
self._page_size = page_size
|
||||
self._pages: Dict[str, List[int]] = {}
|
||||
self._cached: Dict[str, int] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
with self._lock:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
||||
def get(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.get(task_id, [])
|
||||
|
||||
def get_cached(self, task_id: str) -> int:
|
||||
with self._lock:
|
||||
return self._cached.get(task_id, 0)
|
||||
|
||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||
with self._lock:
|
||||
pages = self._pages.pop(task_id, [])
|
||||
cached = self._cached.pop(task_id, 0)
|
||||
return pages, cached
|
||||
|
||||
def get_ref(self, task_id: str) -> List[int]:
|
||||
with self._lock:
|
||||
return self._pages.setdefault(task_id, [])
|
||||
|
||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
with self._lock:
|
||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class PagedCache:
|
||||
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
||||
class Storage:
|
||||
"""KV-cache tensor storage with paged write/gather."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -148,10 +202,6 @@ class PagedCache:
|
|||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._prefix = PrefixCache(page_size)
|
||||
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
||||
self._table = TaskTable(page_size)
|
||||
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
|
|
@ -163,80 +213,6 @@ class PagedCache:
|
|||
dtype=dtype,
|
||||
)
|
||||
|
||||
def alloc_n(self, n: int) -> List[int]:
|
||||
pages: List[int] = []
|
||||
for _ in range(n):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for page in pages:
|
||||
self.free(page)
|
||||
return []
|
||||
pages.append(p)
|
||||
return pages
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
cached = self._prefix.has_page(idx)
|
||||
self._pool.free(idx, keep_cached=cached)
|
||||
if not cached:
|
||||
self._prefix.on_evict(idx)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hits = self._prefix.lookup(prompt_ids, self._pool)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self.free(hp)
|
||||
for np in new_pages:
|
||||
self.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._table.get(task_id)
|
||||
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||
return CacheView(self, page_table, total_len)
|
||||
|
||||
def write(
|
||||
self,
|
||||
layer_id: int,
|
||||
|
|
@ -259,6 +235,9 @@ class PagedCache:
|
|||
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||
offset = write_start - page_start
|
||||
chunk = write_end - write_start
|
||||
if (phys_pages < 0).any():
|
||||
written += chunk
|
||||
continue
|
||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||
:, written : written + chunk
|
||||
]
|
||||
|
|
@ -280,17 +259,95 @@ class PagedCache:
|
|||
return k, v
|
||||
|
||||
|
||||
class CacheView:
|
||||
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
||||
class KvcacheView:
|
||||
"""Bundles Storage + page_table + total_len for attention layers."""
|
||||
|
||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||
self._cache = cache
|
||||
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
|
||||
self._storage = storage
|
||||
self._page_table = page_table
|
||||
self._total_len = total_len
|
||||
|
||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||
start_pos = self._total_len - k.size(1)
|
||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||
|
||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||
return self._storage.gather(layer_id, self._page_table, self._total_len)
|
||||
|
||||
|
||||
class KVCache:
|
||||
"""Facade: page management + KV-cache I/O for continuous batching."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
n_pages: int,
|
||||
page_size: int,
|
||||
n_kv_heads: int,
|
||||
head_dim: int,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||
self._table = TaskTable(page_size)
|
||||
self._storage = Storage(
|
||||
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
|
||||
)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hits = self._pool.lookup(prompt_ids)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self._pool.free(hp)
|
||||
for np in new_pages:
|
||||
self._pool.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self._pool.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._table.get(task_id)
|
||||
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self._pool.record(page_table[i], prompt_ids, i)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
|
||||
return KvcacheView(self._storage, page_table, total_len)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import PagedCache
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.task import Task
|
||||
from astrai.inference.sample import sample
|
||||
from astrai.model.automodel import AutoModel
|
||||
|
|
@ -19,7 +19,7 @@ class Executor:
|
|||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
page_cache: PagedCache,
|
||||
page_cache: KVCache,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference.core.cache import PagedCache
|
||||
from astrai.inference.core.cache import KVCache
|
||||
from astrai.inference.core.executor import Executor
|
||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||
from astrai.model.automodel import AutoModel
|
||||
|
|
@ -37,7 +37,7 @@ class InferenceScheduler:
|
|||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||
) // page_size
|
||||
|
||||
self._page_cache = PagedCache(
|
||||
self._page_cache = KVCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class Task:
|
|||
self,
|
||||
task_id: str,
|
||||
prompt_ids: List[int],
|
||||
max_tokens: int = 1024,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
|
|
@ -54,7 +54,7 @@ class Task:
|
|||
return self.input_tokens + len(self.output_ids)
|
||||
|
||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||
if self.output_tokens >= self.max_tokens:
|
||||
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
|
||||
return True
|
||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||
return True
|
||||
|
|
@ -88,7 +88,7 @@ class TaskManager:
|
|||
def add_task(
|
||||
self,
|
||||
prompt: str,
|
||||
max_tokens: int = 1024,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
|
|
@ -104,6 +104,9 @@ class TaskManager:
|
|||
stream_callback(STOP)
|
||||
return task_id
|
||||
|
||||
if max_tokens is None:
|
||||
max_tokens = self.max_seq_len - len(prompt_ids)
|
||||
else:
|
||||
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||
|
||||
task = Task(
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import asyncio
|
||||
import gc
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
|
@ -14,6 +13,17 @@ from astrai.inference.core.task import STOP
|
|||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
def _validate_sampling_params(
|
||||
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||
):
|
||||
if not (isinstance(top_k, int) and top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerateResult:
|
||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||
|
||||
|
|
@ -58,24 +68,6 @@ class GenerateResult:
|
|||
return self.results.copy()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GenerationParams:
|
||||
"""Immutable value object for sampling hyperparameters."""
|
||||
|
||||
top_k: int = 50
|
||||
top_p: float = 1.0
|
||||
temperature: float = 1.0
|
||||
max_tokens: int = 1024
|
||||
|
||||
def __post_init__(self):
|
||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
||||
raise ValueError("top_k must be a non-negative integer")
|
||||
if not (0.0 <= self.top_p <= 1.0):
|
||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
||||
raise ValueError("temperature must be a non-negative number")
|
||||
|
||||
|
||||
class GenerationRequest:
|
||||
"""Request parameters for text generation."""
|
||||
|
||||
|
|
@ -85,34 +77,18 @@ class GenerationRequest:
|
|||
top_k: int = 50,
|
||||
top_p: float = 1.0,
|
||||
temperature: float = 1.0,
|
||||
max_len: int = 1024,
|
||||
max_tokens: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
):
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
|
||||
self.messages = messages
|
||||
self.params = GenerationParams(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
max_tokens=max_len,
|
||||
)
|
||||
self.top_k = top_k
|
||||
self.top_p = top_p
|
||||
self.temperature = temperature
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
|
||||
@property
|
||||
def top_k(self) -> int:
|
||||
return self.params.top_k
|
||||
|
||||
@property
|
||||
def top_p(self) -> float:
|
||||
return self.params.top_p
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.params.temperature
|
||||
|
||||
@property
|
||||
def max_len(self) -> int:
|
||||
return self.params.max_tokens
|
||||
|
||||
|
||||
class InferenceEngine:
|
||||
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||
|
|
@ -150,37 +126,36 @@ class InferenceEngine:
|
|||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
stream: bool = False,
|
||||
max_tokens: int = 1024,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> Union[Generator, str, List[str]]:
|
||||
params = GenerationParams(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||
)
|
||||
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
is_batch = isinstance(prompt, list)
|
||||
prompts = prompt if is_batch else [prompt]
|
||||
|
||||
if stream:
|
||||
return self._generate_streaming(prompts, is_batch, params)
|
||||
return self._generate_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
else:
|
||||
return self._generate_non_streaming(prompts, is_batch, params)
|
||||
return self._generate_non_streaming(
|
||||
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
def generate_async(
|
||||
self,
|
||||
prompt: str,
|
||||
params: Optional[GenerationParams] = None,
|
||||
max_tokens: int = 1024,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 1.0,
|
||||
top_k: int = 50,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
if params is None:
|
||||
params = GenerationParams(
|
||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
||||
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||
sync_gen = self._generate_streaming(
|
||||
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
sync_gen = self._generate_streaming([prompt], False, params)
|
||||
|
||||
async def _agen():
|
||||
loop = asyncio.get_event_loop()
|
||||
|
|
@ -206,14 +181,19 @@ class InferenceEngine:
|
|||
return self.generate(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
max_tokens=request.params.max_tokens,
|
||||
temperature=request.params.temperature,
|
||||
top_p=request.params.top_p,
|
||||
top_k=request.params.top_k,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
)
|
||||
|
||||
def _submit_tasks(
|
||||
self, prompts: List[str], params: GenerationParams
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Tuple[GenerateResult, List[str]]:
|
||||
n = len(prompts)
|
||||
result = GenerateResult(count=n)
|
||||
|
|
@ -222,10 +202,10 @@ class InferenceEngine:
|
|||
cb = self._make_callback(result, i)
|
||||
task_id = self.scheduler.add_task(
|
||||
prompt=p,
|
||||
max_tokens=params.max_tokens,
|
||||
temperature=params.temperature,
|
||||
top_p=params.top_p,
|
||||
top_k=params.top_k,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream_callback=cb,
|
||||
)
|
||||
task_ids.append(task_id)
|
||||
|
|
@ -239,9 +219,17 @@ class InferenceEngine:
|
|||
return cb
|
||||
|
||||
def _generate_streaming(
|
||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Generator:
|
||||
result, task_ids = self._submit_tasks(prompts, params)
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
n = len(prompts)
|
||||
remaining = n
|
||||
finished = [False] * n
|
||||
|
|
@ -267,9 +255,17 @@ class InferenceEngine:
|
|||
return gen()
|
||||
|
||||
def _generate_non_streaming(
|
||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
||||
self,
|
||||
prompts: List[str],
|
||||
is_batch: bool,
|
||||
max_tokens: Optional[int],
|
||||
temperature: float,
|
||||
top_p: float,
|
||||
top_k: int,
|
||||
) -> Union[str, List[str]]:
|
||||
result, task_ids = self._submit_tasks(prompts, params)
|
||||
result, task_ids = self._submit_tasks(
|
||||
prompts, max_tokens, temperature, top_p, top_k
|
||||
)
|
||||
|
||||
result.wait_completion()
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import CacheView
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
|
||||
|
||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||
|
|
@ -147,7 +147,7 @@ class GQA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
is_causal = attn_mask is None
|
||||
|
||||
|
|
@ -227,7 +227,7 @@ class MLA(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attn_mask: Tensor = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
bsz, seq_len, _ = x.size()
|
||||
is_causal = attn_mask is None
|
||||
|
|
@ -306,7 +306,7 @@ class DecoderBlock(nn.Module):
|
|||
x: Tensor,
|
||||
rotary_emb: Tensor,
|
||||
attention_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
) -> Tensor:
|
||||
attn_output = self.attention(
|
||||
self.input_norm(x),
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.inference.core.cache import CacheView
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
DecoderBlock,
|
||||
|
|
@ -122,7 +122,7 @@ class Transformer(AutoModel):
|
|||
self,
|
||||
input_ids: Tensor,
|
||||
input_mask: Optional[Tensor] = None,
|
||||
paged_cache: Optional[CacheView] = None,
|
||||
paged_cache: Optional[KvcacheView] = None,
|
||||
position_ids: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
assert input_ids.ndim == 2
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
|
||||
"""Benchmark Transformer with KVCache"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict
|
||||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict
|
|||
import torch
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
from astrai.inference import PagedCache
|
||||
from astrai.inference import KVCache
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ class GenerationBenchmark:
|
|||
self.model.eval()
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||
self._page_cache = PagedCache(
|
||||
self._page_cache = KVCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
|
|
@ -130,7 +130,12 @@ class GenerationBenchmark:
|
|||
)
|
||||
|
||||
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
||||
pages = self._page_cache.alloc_n(n_pages * batch_size)
|
||||
total = n_pages * batch_size
|
||||
pages = []
|
||||
for _ in range(total):
|
||||
p = self._page_cache._pool.alloc()
|
||||
assert p >= 0, "OOM"
|
||||
pages.append(p)
|
||||
page_table = torch.tensor(
|
||||
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
||||
dtype=torch.long,
|
||||
|
|
@ -176,7 +181,7 @@ class GenerationBenchmark:
|
|||
total_time += trial_time
|
||||
|
||||
for idx in pages:
|
||||
self._page_cache.free(idx)
|
||||
self._page_cache._pool.free(idx)
|
||||
|
||||
print(
|
||||
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||
|
|
@ -225,7 +230,7 @@ if __name__ == "__main__":
|
|||
benchmark = GenerationBenchmark(config)
|
||||
|
||||
print("=" * 80)
|
||||
print("Running Transformer Generation Benchmark (PagedCache)")
|
||||
print("Running Transformer Generation Benchmark (KVCache)")
|
||||
print("=" * 80)
|
||||
|
||||
prefill_result = benchmark.run_prefill_benchmark(
|
||||
|
|
|
|||
|
|
@ -3,14 +3,20 @@
|
|||
import torch
|
||||
|
||||
from astrai.inference import (
|
||||
PagedCache,
|
||||
Allocator,
|
||||
KVCache,
|
||||
PagePool,
|
||||
PrefixCache,
|
||||
Storage,
|
||||
TaskTable,
|
||||
page_hash,
|
||||
)
|
||||
|
||||
|
||||
def make_pool(n_pages: int, page_size: int) -> PagePool:
|
||||
return PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||
|
||||
|
||||
def test_page_hash_full_page():
|
||||
token_ids = list(range(256))
|
||||
h = page_hash(token_ids, 0, 64)
|
||||
|
|
@ -24,7 +30,7 @@ def test_page_hash_different_page_differs():
|
|||
|
||||
|
||||
def test_page_pool_alloc_free_cycle():
|
||||
pool = PagePool(n_pages=4)
|
||||
pool = make_pool(4, 64)
|
||||
a = pool.alloc()
|
||||
b = pool.alloc()
|
||||
assert a != b
|
||||
|
|
@ -35,130 +41,101 @@ def test_page_pool_alloc_free_cycle():
|
|||
|
||||
|
||||
def test_page_pool_alloc_when_full():
|
||||
pool = PagePool(n_pages=2)
|
||||
pool = make_pool(2, 64)
|
||||
pool.alloc()
|
||||
pool.alloc()
|
||||
assert pool.alloc() == -1
|
||||
|
||||
|
||||
def test_page_pool_lru_eviction():
|
||||
evicted = []
|
||||
|
||||
def on_evict(idx):
|
||||
evicted.append(idx)
|
||||
|
||||
pool = PagePool(n_pages=2, on_evict=on_evict)
|
||||
pool = make_pool(2, 64)
|
||||
p0 = pool.alloc()
|
||||
p1 = pool.alloc()
|
||||
pool.free(p0, keep_cached=True)
|
||||
pool.free(p1, keep_cached=True)
|
||||
pool.record(p0, list(range(64)), 0)
|
||||
pool.record(p1, list(range(64, 128)), 0)
|
||||
pool.free(p0)
|
||||
pool.free(p1)
|
||||
pool.alloc()
|
||||
assert len(evicted) == 1
|
||||
assert evicted[0] == p0
|
||||
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
|
||||
|
||||
|
||||
def test_page_pool_inc_ref_and_free():
|
||||
pool = PagePool(n_pages=2)
|
||||
pool = make_pool(2, 64)
|
||||
p = pool.alloc()
|
||||
pool.inc_ref(p)
|
||||
assert pool._refs[p] == 2
|
||||
assert pool._alloc._refs[p] == 2
|
||||
pool.free(p)
|
||||
assert pool._refs[p] == 1
|
||||
assert pool._alloc._refs[p] == 1
|
||||
pool.free(p)
|
||||
assert pool._refs[p] == 0
|
||||
|
||||
|
||||
def test_page_pool_touch_moves_to_end():
|
||||
pool = PagePool(n_pages=4)
|
||||
p0 = pool.alloc()
|
||||
p1 = pool.alloc()
|
||||
p2 = pool.alloc()
|
||||
pool.free(p0, keep_cached=True)
|
||||
pool.free(p1, keep_cached=True)
|
||||
pool.free(p2, keep_cached=True)
|
||||
assert next(iter(pool._lru)) == p0
|
||||
pool.touch(p0)
|
||||
assert next(reversed(pool._lru)) == p0
|
||||
|
||||
|
||||
def test_page_pool_remove_from_lru():
|
||||
pool = PagePool(n_pages=4)
|
||||
p0 = pool.alloc()
|
||||
pool.free(p0, keep_cached=True)
|
||||
assert p0 in pool._lru
|
||||
pool.remove_from_lru(p0)
|
||||
assert p0 not in pool._lru
|
||||
assert pool._alloc._refs[p] == 0
|
||||
|
||||
|
||||
def test_page_pool_keep_cached_realloc():
|
||||
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
||||
pool = PagePool(n_pages=3)
|
||||
pool = make_pool(3, 64)
|
||||
p0 = pool.alloc()
|
||||
p1 = pool.alloc()
|
||||
p2 = pool.alloc()
|
||||
pool.free(p0, keep_cached=True)
|
||||
pool.free(p1, keep_cached=True)
|
||||
pool.free(p2, keep_cached=True)
|
||||
for p in (p0, p1, p2):
|
||||
pool.record(p, [p] * 64, 0)
|
||||
pool.free(p0)
|
||||
pool.free(p1)
|
||||
pool.free(p2)
|
||||
assert pool.alloc() == p0
|
||||
|
||||
|
||||
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
|
||||
"""Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU."""
|
||||
prefix.record(page, token_ids, logical_idx, pool)
|
||||
pool.free(page, keep_cached=True)
|
||||
|
||||
|
||||
def test_prefix_cache_lookup_returns_hits():
|
||||
token_ids = list(range(256))
|
||||
pool = PagePool(n_pages=16)
|
||||
prefix = PrefixCache(page_size=64)
|
||||
pool = make_pool(16, 64)
|
||||
pages = [pool.alloc() for _ in range(4)]
|
||||
for i, p in enumerate(pages):
|
||||
_record_then_cache(pool, prefix, p, token_ids, i)
|
||||
hits = prefix.lookup(token_ids, pool)
|
||||
pool.record(p, token_ids, i)
|
||||
pool.free(p)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert hits == pages
|
||||
|
||||
|
||||
def test_prefix_cache_lookup_stops_at_first_miss():
|
||||
token_ids = list(range(256))
|
||||
pool = PagePool(n_pages=16)
|
||||
prefix = PrefixCache(page_size=64)
|
||||
pool = make_pool(16, 64)
|
||||
p0 = pool.alloc()
|
||||
_record_then_cache(pool, prefix, p0, token_ids, 0)
|
||||
pool.record(p0, token_ids, 0)
|
||||
pool.free(p0)
|
||||
p1 = pool.alloc()
|
||||
_record_then_cache(pool, prefix, p1, [99] * 64, 1)
|
||||
hits = prefix.lookup(token_ids, pool)
|
||||
pool.record(p1, [99] * 64, 1)
|
||||
pool.free(p1)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert len(hits) == 1
|
||||
assert hits[0] == p0
|
||||
|
||||
|
||||
def test_prefix_cache_ignores_partial_last_page():
|
||||
token_ids = list(range(100))
|
||||
pool = PagePool(n_pages=16)
|
||||
prefix = PrefixCache(page_size=64)
|
||||
pool = make_pool(16, 64)
|
||||
p = pool.alloc()
|
||||
_record_then_cache(pool, prefix, p, token_ids, 0)
|
||||
hits = prefix.lookup(token_ids, pool)
|
||||
pool.record(p, token_ids, 0)
|
||||
pool.free(p)
|
||||
hits = pool.lookup(token_ids)
|
||||
assert len(hits) == 1
|
||||
|
||||
|
||||
def test_prefix_cache_on_evict_clears_mappings():
|
||||
pool = PagePool(n_pages=4)
|
||||
prefix = PrefixCache(page_size=64)
|
||||
pool = make_pool(4, 64)
|
||||
p = pool.alloc()
|
||||
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
||||
assert prefix.has_page(p)
|
||||
prefix.on_evict(p)
|
||||
assert not prefix.has_page(p)
|
||||
pool.record(p, list(range(64)), 0)
|
||||
pool.free(p)
|
||||
assert p in pool._prefix._page_to_hash
|
||||
pool._prefix.evict(p)
|
||||
assert p not in pool._prefix._page_to_hash
|
||||
|
||||
|
||||
def test_prefix_cache_has_page():
|
||||
pool = PagePool(n_pages=4)
|
||||
prefix = PrefixCache(page_size=64)
|
||||
pool = make_pool(4, 64)
|
||||
p = pool.alloc()
|
||||
assert not prefix.has_page(p)
|
||||
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
||||
assert prefix.has_page(p)
|
||||
assert p not in pool._prefix._page_to_hash
|
||||
pool.record(p, list(range(64)), 0)
|
||||
pool.free(p)
|
||||
assert p in pool._prefix._page_to_hash
|
||||
|
||||
|
||||
def test_task_table_set_get():
|
||||
|
|
@ -183,8 +160,8 @@ def test_task_table_pop():
|
|||
assert table.get("task1") == []
|
||||
|
||||
|
||||
def test_paged_cache_task_extend_allocates():
|
||||
cache = PagedCache(
|
||||
def test_kv_cache_task_extend_allocates():
|
||||
cache = KVCache(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=64,
|
||||
|
|
@ -199,8 +176,8 @@ def test_paged_cache_task_extend_allocates():
|
|||
assert len(cache._table.get("task1")) == 4
|
||||
|
||||
|
||||
def test_paged_cache_task_extend_fails_when_pool_full():
|
||||
cache = PagedCache(
|
||||
def test_kv_cache_task_extend_fails_when_pool_full():
|
||||
cache = KVCache(
|
||||
n_layers=1,
|
||||
n_pages=2,
|
||||
page_size=64,
|
||||
|
|
@ -230,8 +207,8 @@ def test_task_table_table_tensor_empty_input():
|
|||
assert t.numel() == 0
|
||||
|
||||
|
||||
def test_paged_cache_write_gather_single_page():
|
||||
cache = PagedCache(
|
||||
def test_storage_write_gather_single_page():
|
||||
storage = Storage(
|
||||
n_layers=2,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
|
|
@ -244,13 +221,13 @@ def test_paged_cache_write_gather_single_page():
|
|||
k = torch.randn(1, 2, 2, 8)
|
||||
v = torch.randn(1, 2, 2, 8)
|
||||
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
gk, gv = cache.gather(0, page_table, 2)
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
gk, gv = storage.gather(0, page_table, 2)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
||||
def test_paged_cache_write_cross_page():
|
||||
cache = PagedCache(
|
||||
def test_storage_write_cross_page():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
|
|
@ -263,13 +240,13 @@ def test_paged_cache_write_cross_page():
|
|||
k = torch.randn(1, 8, 2, 8)
|
||||
v = torch.randn(1, 8, 2, 8)
|
||||
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
gk, gv = cache.gather(0, page_table, 8)
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
gk, gv = storage.gather(0, page_table, 8)
|
||||
assert torch.allclose(gk, k)
|
||||
|
||||
|
||||
def test_paged_cache_gather_truncates_to_total_len():
|
||||
cache = PagedCache(
|
||||
def test_storage_gather_truncates_to_total_len():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
|
|
@ -281,14 +258,14 @@ def test_paged_cache_gather_truncates_to_total_len():
|
|||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||
k = torch.randn(1, 6, 2, 8)
|
||||
v = torch.randn(1, 6, 2, 8)
|
||||
cache.write(0, page_table, 0, k, v)
|
||||
storage.write(0, page_table, 0, k, v)
|
||||
|
||||
gk, gv = cache.gather(0, page_table, 5)
|
||||
gk, gv = storage.gather(0, page_table, 5)
|
||||
assert gk.shape == (1, 5, 2, 8)
|
||||
|
||||
|
||||
def test_paged_cache_gather_clamps_negative_padding():
|
||||
cache = PagedCache(
|
||||
def test_storage_gather_clamps_negative_padding():
|
||||
storage = Storage(
|
||||
n_layers=1,
|
||||
n_pages=8,
|
||||
page_size=4,
|
||||
|
|
@ -298,5 +275,5 @@ def test_paged_cache_gather_clamps_negative_padding():
|
|||
dtype=torch.float32,
|
||||
)
|
||||
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
||||
gk, gv = cache.gather(0, page_table, 4)
|
||||
gk, gv = storage.gather(0, page_table, 4)
|
||||
assert gk.shape == (1, 4, 2, 8)
|
||||
|
|
|
|||
Loading…
Reference in New Issue