refactor: 重构 cache 和 inference 参数体系,分离存储与分配
- 合并 GenerationRequest/GenerationParams,统一 max_tokens 参数名
- PagePool/PrefixCache 分离为 Allocator + PrefixCache + PagePool
- 拆分 KV 存储为独立 Storage 类,PagedCache → KVCache,CacheView → KvcacheView
- Allocator.inc_ref 时将页面移出 LRU 以防止竞争,Storage.write 增加对负页索引的防御
- Allocator/PrefixCache/TaskTable 加 threading.Lock 保证线程安全
- server.py uvicorn.run 改为传 app 对象,修复导入错误
- benchmark.py 适配 KVCache 新 API
This commit is contained in:
parent
18fe6e9339
commit
205b40bd28
|
|
@ -3,7 +3,7 @@
|
||||||
Layers:
|
Layers:
|
||||||
- core/: Core inference loop (cache, executor, scheduler, task)
|
- core/: Core inference loop (cache, executor, scheduler, task)
|
||||||
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
- api/: HTTP protocol handlers (OpenAI, Anthropic)
|
||||||
- engine.py: Facade (InferenceEngine), Value Object (GenerationParams, GenerationRequest)
|
- engine.py: Facade (InferenceEngine), Value Object (GenerationRequest)
|
||||||
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
- sample.py: Strategy pattern (TemperatureStrategy, TopKStrategy, TopPStrategy)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -22,12 +22,14 @@ from astrai.inference.api import (
|
||||||
)
|
)
|
||||||
from astrai.inference.core import (
|
from astrai.inference.core import (
|
||||||
STOP,
|
STOP,
|
||||||
CacheView,
|
Allocator,
|
||||||
Executor,
|
Executor,
|
||||||
InferenceScheduler,
|
InferenceScheduler,
|
||||||
PagedCache,
|
KVCache,
|
||||||
|
KvcacheView,
|
||||||
PagePool,
|
PagePool,
|
||||||
PrefixCache,
|
PrefixCache,
|
||||||
|
Storage,
|
||||||
Task,
|
Task,
|
||||||
TaskManager,
|
TaskManager,
|
||||||
TaskStatus,
|
TaskStatus,
|
||||||
|
|
@ -35,7 +37,6 @@ from astrai.inference.core import (
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
from astrai.inference.engine import (
|
from astrai.inference.engine import (
|
||||||
GenerationParams,
|
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
InferenceEngine,
|
InferenceEngine,
|
||||||
)
|
)
|
||||||
|
|
@ -52,7 +53,6 @@ __all__ = [
|
||||||
# Engine / Requests
|
# Engine / Requests
|
||||||
"InferenceEngine",
|
"InferenceEngine",
|
||||||
"GenerationRequest",
|
"GenerationRequest",
|
||||||
"GenerationParams",
|
|
||||||
# Core scheduler
|
# Core scheduler
|
||||||
"InferenceScheduler",
|
"InferenceScheduler",
|
||||||
"Executor",
|
"Executor",
|
||||||
|
|
@ -61,10 +61,12 @@ __all__ = [
|
||||||
"TaskManager",
|
"TaskManager",
|
||||||
"TaskStatus",
|
"TaskStatus",
|
||||||
# Core cache
|
# Core cache
|
||||||
"CacheView",
|
"Allocator",
|
||||||
"PagedCache",
|
"KVCache",
|
||||||
|
"KvcacheView",
|
||||||
"PagePool",
|
"PagePool",
|
||||||
"PrefixCache",
|
"PrefixCache",
|
||||||
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
# Sampling (Strategy pattern)
|
# Sampling (Strategy pattern)
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from astrai.inference.engine import GenerationParams, InferenceEngine
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
|
||||||
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
def _sse_event(data: Dict[str, Any], event: Optional[str] = None) -> str:
|
||||||
|
|
@ -143,13 +143,13 @@ class ProtocolHandler(ABC):
|
||||||
prompt_tokens=self._count_prompt_tokens(),
|
prompt_tokens=self._count_prompt_tokens(),
|
||||||
)
|
)
|
||||||
|
|
||||||
params = GenerationParams(
|
agen = self.engine.generate_async(
|
||||||
|
prompt=self.build_prompt(),
|
||||||
max_tokens=self.request.max_tokens,
|
max_tokens=self.request.max_tokens,
|
||||||
temperature=self.request.temperature,
|
temperature=self.request.temperature,
|
||||||
top_p=self.request.top_p,
|
top_p=self.request.top_p,
|
||||||
top_k=self.request.top_k,
|
top_k=self.request.top_k,
|
||||||
)
|
)
|
||||||
agen = self.engine.generate_async(prompt=self.build_prompt(), params=params)
|
|
||||||
|
|
||||||
if self.request.stream:
|
if self.request.stream:
|
||||||
return self._handle_stream(agen, ctx)
|
return self._handle_stream(agen, ctx)
|
||||||
|
|
|
||||||
|
|
@ -160,8 +160,7 @@ def run_server(
|
||||||
"max_batch_size": max_batch_size,
|
"max_batch_size": max_batch_size,
|
||||||
}
|
}
|
||||||
uvicorn.run(
|
uvicorn.run(
|
||||||
"astrai.inference.server:app",
|
app,
|
||||||
host=host,
|
host=host,
|
||||||
port=port,
|
port=port,
|
||||||
reload=reload,
|
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
"""Inference core: cache, executor, scheduler, task management."""
|
"""Inference core: cache, executor, scheduler, task management."""
|
||||||
|
|
||||||
from astrai.inference.core.cache import (
|
from astrai.inference.core.cache import (
|
||||||
CacheView,
|
Allocator,
|
||||||
PagedCache,
|
KVCache,
|
||||||
|
KvcacheView,
|
||||||
PagePool,
|
PagePool,
|
||||||
PrefixCache,
|
PrefixCache,
|
||||||
|
Storage,
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
|
|
@ -13,10 +15,12 @@ from astrai.inference.core.scheduler import InferenceScheduler
|
||||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CacheView",
|
"Allocator",
|
||||||
"PagedCache",
|
"KVCache",
|
||||||
|
"KvcacheView",
|
||||||
"PagePool",
|
"PagePool",
|
||||||
"PrefixCache",
|
"PrefixCache",
|
||||||
|
"Storage",
|
||||||
"TaskTable",
|
"TaskTable",
|
||||||
"page_hash",
|
"page_hash",
|
||||||
"Executor",
|
"Executor",
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import threading
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Callable, Dict, List, Optional, Tuple
|
from typing import Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -14,16 +15,18 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
class PagePool:
|
class Allocator:
|
||||||
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
"""Bitmask-based page allocator with ref-counting and LRU eviction."""
|
||||||
|
|
||||||
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
def __init__(self, n_pages: int):
|
||||||
self._free_mask = (1 << n_pages) - 1
|
self._free_mask = (1 << n_pages) - 1
|
||||||
self._refs: List[int] = [0] * n_pages
|
self._refs: List[int] = [0] * n_pages
|
||||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||||
self._on_evict = on_evict
|
self.on_evict: Optional[Callable[[int], None]] = None
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
|
with self._lock:
|
||||||
if self._free_mask:
|
if self._free_mask:
|
||||||
lsb = self._free_mask & -self._free_mask
|
lsb = self._free_mask & -self._free_mask
|
||||||
idx = lsb.bit_length() - 1
|
idx = lsb.bit_length() - 1
|
||||||
|
|
@ -32,14 +35,15 @@ class PagePool:
|
||||||
return idx
|
return idx
|
||||||
if self._lru:
|
if self._lru:
|
||||||
idx, _ = self._lru.popitem(last=False)
|
idx, _ = self._lru.popitem(last=False)
|
||||||
if self._on_evict:
|
if self.on_evict:
|
||||||
self._on_evict(idx)
|
self.on_evict(idx)
|
||||||
self._refs[idx] = 1
|
self._refs[idx] = 1
|
||||||
self._free_mask &= ~(1 << idx)
|
self._free_mask &= ~(1 << idx)
|
||||||
return idx
|
return idx
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||||
|
with self._lock:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
if keep_cached:
|
if keep_cached:
|
||||||
|
|
@ -48,14 +52,18 @@ class PagePool:
|
||||||
self._free_mask |= 1 << idx
|
self._free_mask |= 1 << idx
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int) -> None:
|
||||||
|
with self._lock:
|
||||||
self._refs[idx] += 1
|
self._refs[idx] += 1
|
||||||
|
self._lru.pop(idx, None)
|
||||||
|
|
||||||
|
def ref_count(self, idx: int) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._refs[idx]
|
||||||
|
|
||||||
def touch(self, idx: int) -> None:
|
def touch(self, idx: int) -> None:
|
||||||
|
with self._lock:
|
||||||
self._lru.move_to_end(idx)
|
self._lru.move_to_end(idx)
|
||||||
|
|
||||||
def remove_from_lru(self, idx: int) -> None:
|
|
||||||
self._lru.pop(idx, None)
|
|
||||||
|
|
||||||
|
|
||||||
class PrefixCache:
|
class PrefixCache:
|
||||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||||
|
|
@ -64,16 +72,20 @@ class PrefixCache:
|
||||||
self._page_size = page_size
|
self._page_size = page_size
|
||||||
self._page_to_hash: Dict[int, int] = {}
|
self._page_to_hash: Dict[int, int] = {}
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def on_evict(self, idx: int) -> None:
|
def evict(self, idx: int) -> None:
|
||||||
|
with self._lock:
|
||||||
h = self._page_to_hash.pop(idx, None)
|
h = self._page_to_hash.pop(idx, None)
|
||||||
if h is not None:
|
if h is not None:
|
||||||
self._hash_to_page.pop(h, None)
|
self._hash_to_page.pop(h, None)
|
||||||
|
|
||||||
def has_page(self, idx: int) -> bool:
|
def has_page(self, idx: int) -> bool:
|
||||||
|
with self._lock:
|
||||||
return idx in self._page_to_hash
|
return idx in self._page_to_hash
|
||||||
|
|
||||||
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
full_pages = len(token_ids) // self._page_size
|
full_pages = len(token_ids) // self._page_size
|
||||||
hits: List[int] = []
|
hits: List[int] = []
|
||||||
for i in range(full_pages):
|
for i in range(full_pages):
|
||||||
|
|
@ -81,24 +93,59 @@ class PrefixCache:
|
||||||
p = self._hash_to_page.get(h)
|
p = self._hash_to_page.get(h)
|
||||||
if p is None:
|
if p is None:
|
||||||
break
|
break
|
||||||
pool.touch(p)
|
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def record(
|
def record(
|
||||||
self,
|
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||||
page_idx: int,
|
|
||||||
token_ids: List[int],
|
|
||||||
logical_page_idx: int,
|
|
||||||
pool: PagePool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
|
with self._lock:
|
||||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||||
old_h = self._page_to_hash.pop(page_idx, None)
|
old_h = self._page_to_hash.pop(page_idx, None)
|
||||||
if old_h is not None:
|
if old_h is not None:
|
||||||
self._hash_to_page.pop(old_h, None)
|
self._hash_to_page.pop(old_h, None)
|
||||||
self._page_to_hash[page_idx] = h
|
self._page_to_hash[page_idx] = h
|
||||||
self._hash_to_page[h] = page_idx
|
self._hash_to_page[h] = page_idx
|
||||||
pool.remove_from_lru(page_idx)
|
|
||||||
|
|
||||||
|
class PagePool:
|
||||||
|
"""Orchestrates allocator (page management) and PrefixCache (content addressing)."""
|
||||||
|
|
||||||
|
def __init__(self, allocator: Allocator, prefix: PrefixCache):
|
||||||
|
self._alloc = allocator
|
||||||
|
self._prefix = prefix
|
||||||
|
self._alloc.on_evict = prefix.evict
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allocator(self) -> Allocator:
|
||||||
|
return self._alloc
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefix(self) -> PrefixCache:
|
||||||
|
return self._prefix
|
||||||
|
|
||||||
|
def alloc(self) -> int:
|
||||||
|
return self._alloc.alloc()
|
||||||
|
|
||||||
|
def free(self, idx: int) -> None:
|
||||||
|
keep = self._prefix.has_page(idx)
|
||||||
|
self._alloc.free(idx, keep_cached=keep)
|
||||||
|
if not keep:
|
||||||
|
self._prefix.evict(idx)
|
||||||
|
|
||||||
|
def inc_ref(self, idx: int) -> None:
|
||||||
|
self._alloc.inc_ref(idx)
|
||||||
|
|
||||||
|
def lookup(self, token_ids: List[int]) -> List[int]:
|
||||||
|
hits = self._prefix.lookup(token_ids)
|
||||||
|
for p in hits:
|
||||||
|
self._alloc.touch(p)
|
||||||
|
return hits
|
||||||
|
|
||||||
|
def record(
|
||||||
|
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||||
|
) -> None:
|
||||||
|
self._prefix.record(page_idx, token_ids, logical_page_idx)
|
||||||
|
|
||||||
|
|
||||||
class TaskTable:
|
class TaskTable:
|
||||||
|
|
@ -108,34 +155,41 @@ class TaskTable:
|
||||||
self._page_size = page_size
|
self._page_size = page_size
|
||||||
self._pages: Dict[str, List[int]] = {}
|
self._pages: Dict[str, List[int]] = {}
|
||||||
self._cached: Dict[str, int] = {}
|
self._cached: Dict[str, int] = {}
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||||
|
with self._lock:
|
||||||
self._pages[task_id] = page_table
|
self._pages[task_id] = page_table
|
||||||
self._cached[task_id] = cached
|
self._cached[task_id] = cached
|
||||||
|
|
||||||
def get(self, task_id: str) -> List[int]:
|
def get(self, task_id: str) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
return self._pages.get(task_id, [])
|
return self._pages.get(task_id, [])
|
||||||
|
|
||||||
def get_cached(self, task_id: str) -> int:
|
def get_cached(self, task_id: str) -> int:
|
||||||
|
with self._lock:
|
||||||
return self._cached.get(task_id, 0)
|
return self._cached.get(task_id, 0)
|
||||||
|
|
||||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||||
|
with self._lock:
|
||||||
pages = self._pages.pop(task_id, [])
|
pages = self._pages.pop(task_id, [])
|
||||||
cached = self._cached.pop(task_id, 0)
|
cached = self._cached.pop(task_id, 0)
|
||||||
return pages, cached
|
return pages, cached
|
||||||
|
|
||||||
def get_ref(self, task_id: str) -> List[int]:
|
def get_ref(self, task_id: str) -> List[int]:
|
||||||
|
with self._lock:
|
||||||
return self._pages.setdefault(task_id, [])
|
return self._pages.setdefault(task_id, [])
|
||||||
|
|
||||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
with self._lock:
|
||||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||||
max_pages = max((len(s) for s in states), default=0)
|
max_pages = max((len(s) for s in states), default=0)
|
||||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||||
|
|
||||||
|
|
||||||
class PagedCache:
|
class Storage:
|
||||||
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
"""KV-cache tensor storage with paged write/gather."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -148,10 +202,6 @@ class PagedCache:
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
):
|
):
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
self._prefix = PrefixCache(page_size)
|
|
||||||
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
|
||||||
self._table = TaskTable(page_size)
|
|
||||||
|
|
||||||
self.k_cache = torch.empty(
|
self.k_cache = torch.empty(
|
||||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||||
device=device,
|
device=device,
|
||||||
|
|
@ -163,80 +213,6 @@ class PagedCache:
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
def alloc_n(self, n: int) -> List[int]:
|
|
||||||
pages: List[int] = []
|
|
||||||
for _ in range(n):
|
|
||||||
p = self._pool.alloc()
|
|
||||||
if p < 0:
|
|
||||||
for page in pages:
|
|
||||||
self.free(page)
|
|
||||||
return []
|
|
||||||
pages.append(p)
|
|
||||||
return pages
|
|
||||||
|
|
||||||
def free(self, idx: int) -> None:
|
|
||||||
cached = self._prefix.has_page(idx)
|
|
||||||
self._pool.free(idx, keep_cached=cached)
|
|
||||||
if not cached:
|
|
||||||
self._prefix.on_evict(idx)
|
|
||||||
|
|
||||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
|
||||||
hits = self._prefix.lookup(prompt_ids, self._pool)
|
|
||||||
cached = len(hits) * self.page_size
|
|
||||||
for p in hits:
|
|
||||||
self._pool.inc_ref(p)
|
|
||||||
|
|
||||||
remaining = len(prompt_ids) - cached
|
|
||||||
n_new = (
|
|
||||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
|
||||||
)
|
|
||||||
new_pages: List[int] = []
|
|
||||||
if n_new > 0:
|
|
||||||
for _ in range(n_new):
|
|
||||||
p = self._pool.alloc()
|
|
||||||
if p < 0:
|
|
||||||
for hp in hits:
|
|
||||||
self.free(hp)
|
|
||||||
for np in new_pages:
|
|
||||||
self.free(np)
|
|
||||||
return False
|
|
||||||
new_pages.append(p)
|
|
||||||
|
|
||||||
self._table.set(task_id, hits + new_pages, cached)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def task_free(self, task_id: str) -> None:
|
|
||||||
page_table, _ = self._table.pop(task_id)
|
|
||||||
for idx in page_table:
|
|
||||||
self.free(idx)
|
|
||||||
|
|
||||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
|
||||||
page_table = self._table.get(task_id)
|
|
||||||
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
|
||||||
while len(page_table) < needed:
|
|
||||||
p = self._pool.alloc()
|
|
||||||
if p < 0:
|
|
||||||
return False
|
|
||||||
page_table.append(p)
|
|
||||||
return True
|
|
||||||
|
|
||||||
def task_cached(self, task_id: str) -> int:
|
|
||||||
return self._table.get_cached(task_id)
|
|
||||||
|
|
||||||
def task_record_hashes(
|
|
||||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
|
||||||
) -> None:
|
|
||||||
page_table = self._table.get(task_id)
|
|
||||||
full_pages = len(prompt_ids) // self.page_size
|
|
||||||
for i in range(start_logical_page, full_pages):
|
|
||||||
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
|
||||||
|
|
||||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
|
||||||
return self._table.table_tensor(task_ids, device)
|
|
||||||
|
|
||||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
|
||||||
return CacheView(self, page_table, total_len)
|
|
||||||
|
|
||||||
def write(
|
def write(
|
||||||
self,
|
self,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
|
|
@ -259,6 +235,9 @@ class PagedCache:
|
||||||
write_end = min(page_start + page_size, start_pos + seq_len)
|
write_end = min(page_start + page_size, start_pos + seq_len)
|
||||||
offset = write_start - page_start
|
offset = write_start - page_start
|
||||||
chunk = write_end - write_start
|
chunk = write_end - write_start
|
||||||
|
if (phys_pages < 0).any():
|
||||||
|
written += chunk
|
||||||
|
continue
|
||||||
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
||||||
:, written : written + chunk
|
:, written : written + chunk
|
||||||
]
|
]
|
||||||
|
|
@ -280,17 +259,95 @@ class PagedCache:
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
class CacheView:
|
class KvcacheView:
|
||||||
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
"""Bundles Storage + page_table + total_len for attention layers."""
|
||||||
|
|
||||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
def __init__(self, storage: Storage, page_table: Tensor, total_len: int = 0):
|
||||||
self._cache = cache
|
self._storage = storage
|
||||||
self._page_table = page_table
|
self._page_table = page_table
|
||||||
self._total_len = total_len
|
self._total_len = total_len
|
||||||
|
|
||||||
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
def write(self, layer_id: int, k: Tensor, v: Tensor) -> None:
|
||||||
start_pos = self._total_len - k.size(1)
|
start_pos = self._total_len - k.size(1)
|
||||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
self._storage.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
return self._storage.gather(layer_id, self._page_table, self._total_len)
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache:
|
||||||
|
"""Facade: page management + KV-cache I/O for continuous batching."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
n_layers: int,
|
||||||
|
n_pages: int,
|
||||||
|
page_size: int,
|
||||||
|
n_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
self.page_size = page_size
|
||||||
|
self._pool = PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||||
|
self._table = TaskTable(page_size)
|
||||||
|
self._storage = Storage(
|
||||||
|
n_layers, n_pages, page_size, n_kv_heads, head_dim, device, dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||||
|
hits = self._pool.lookup(prompt_ids)
|
||||||
|
cached = len(hits) * self.page_size
|
||||||
|
for p in hits:
|
||||||
|
self._pool.inc_ref(p)
|
||||||
|
|
||||||
|
remaining = len(prompt_ids) - cached
|
||||||
|
n_new = (
|
||||||
|
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||||
|
)
|
||||||
|
new_pages: List[int] = []
|
||||||
|
if n_new > 0:
|
||||||
|
for _ in range(n_new):
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
for hp in hits:
|
||||||
|
self._pool.free(hp)
|
||||||
|
for np in new_pages:
|
||||||
|
self._pool.free(np)
|
||||||
|
return False
|
||||||
|
new_pages.append(p)
|
||||||
|
|
||||||
|
self._table.set(task_id, hits + new_pages, cached)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_free(self, task_id: str) -> None:
|
||||||
|
page_table, _ = self._table.pop(task_id)
|
||||||
|
for idx in page_table:
|
||||||
|
self._pool.free(idx)
|
||||||
|
|
||||||
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||||
|
page_table = self._table.get(task_id)
|
||||||
|
needed = (pos + 1 + self.page_size - 1) // self.page_size
|
||||||
|
while len(page_table) < needed:
|
||||||
|
p = self._pool.alloc()
|
||||||
|
if p < 0:
|
||||||
|
return False
|
||||||
|
page_table.append(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_cached(self, task_id: str) -> int:
|
||||||
|
return self._table.get_cached(task_id)
|
||||||
|
|
||||||
|
def task_record_hashes(
|
||||||
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
|
) -> None:
|
||||||
|
page_table = self._table.get(task_id)
|
||||||
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
|
for i in range(start_logical_page, full_pages):
|
||||||
|
self._pool.record(page_table[i], prompt_ids, i)
|
||||||
|
|
||||||
|
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
return self._table.table_tensor(task_ids, device)
|
||||||
|
|
||||||
|
def bind(self, page_table: Tensor, total_len: int = 0) -> KvcacheView:
|
||||||
|
return KvcacheView(self._storage, page_table, total_len)
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.core.cache import PagedCache
|
from astrai.inference.core.cache import KVCache
|
||||||
from astrai.inference.core.task import Task
|
from astrai.inference.core.task import Task
|
||||||
from astrai.inference.sample import sample
|
from astrai.inference.sample import sample
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
|
|
@ -19,7 +19,7 @@ class Executor:
|
||||||
self,
|
self,
|
||||||
model: AutoModel,
|
model: AutoModel,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
page_cache: PagedCache,
|
page_cache: KVCache,
|
||||||
device: Optional[str] = None,
|
device: Optional[str] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference.core.cache import PagedCache
|
from astrai.inference.core.cache import KVCache
|
||||||
from astrai.inference.core.executor import Executor
|
from astrai.inference.core.executor import Executor
|
||||||
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
from astrai.inference.core.task import STOP, Task, TaskManager, TaskStatus
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
|
|
@ -37,7 +37,7 @@ class InferenceScheduler:
|
||||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||||
) // page_size
|
) // page_size
|
||||||
|
|
||||||
self._page_cache = PagedCache(
|
self._page_cache = KVCache(
|
||||||
config.n_layers,
|
config.n_layers,
|
||||||
n_pages,
|
n_pages,
|
||||||
page_size,
|
page_size,
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class Task:
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
prompt_ids: List[int],
|
prompt_ids: List[int],
|
||||||
max_tokens: int = 1024,
|
max_tokens: Optional[int] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
|
@ -54,7 +54,7 @@ class Task:
|
||||||
return self.input_tokens + len(self.output_ids)
|
return self.input_tokens + len(self.output_ids)
|
||||||
|
|
||||||
def is_finished(self, stop_ids: List[int]) -> bool:
|
def is_finished(self, stop_ids: List[int]) -> bool:
|
||||||
if self.output_tokens >= self.max_tokens:
|
if self.max_tokens is not None and self.output_tokens >= self.max_tokens:
|
||||||
return True
|
return True
|
||||||
if self.output_ids and self.output_ids[-1] in stop_ids:
|
if self.output_ids and self.output_ids[-1] in stop_ids:
|
||||||
return True
|
return True
|
||||||
|
|
@ -88,7 +88,7 @@ class TaskManager:
|
||||||
def add_task(
|
def add_task(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
max_tokens: int = 1024,
|
max_tokens: Optional[int] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
|
|
@ -104,6 +104,9 @@ class TaskManager:
|
||||||
stream_callback(STOP)
|
stream_callback(STOP)
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
|
if max_tokens is None:
|
||||||
|
max_tokens = self.max_seq_len - len(prompt_ids)
|
||||||
|
else:
|
||||||
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
max_tokens = min(max_tokens, self.max_seq_len - len(prompt_ids))
|
||||||
|
|
||||||
task = Task(
|
task = Task(
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,6 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import gc
|
import gc
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
from typing import Any, AsyncGenerator, Dict, Generator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -14,6 +13,17 @@ from astrai.inference.core.task import STOP
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_sampling_params(
|
||||||
|
top_k: int, top_p: float, temperature: float, max_tokens: Optional[int] = None
|
||||||
|
):
|
||||||
|
if not (isinstance(top_k, int) and top_k >= 0):
|
||||||
|
raise ValueError("top_k must be a non-negative integer")
|
||||||
|
if not (0.0 <= top_p <= 1.0):
|
||||||
|
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
||||||
|
if not (isinstance(temperature, (int, float)) and temperature >= 0):
|
||||||
|
raise ValueError("temperature must be a non-negative number")
|
||||||
|
|
||||||
|
|
||||||
class GenerateResult:
|
class GenerateResult:
|
||||||
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
"""Thread-safe token accumulator for streaming and non-streaming modes."""
|
||||||
|
|
||||||
|
|
@ -58,24 +68,6 @@ class GenerateResult:
|
||||||
return self.results.copy()
|
return self.results.copy()
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class GenerationParams:
|
|
||||||
"""Immutable value object for sampling hyperparameters."""
|
|
||||||
|
|
||||||
top_k: int = 50
|
|
||||||
top_p: float = 1.0
|
|
||||||
temperature: float = 1.0
|
|
||||||
max_tokens: int = 1024
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if not (isinstance(self.top_k, int) and self.top_k >= 0):
|
|
||||||
raise ValueError("top_k must be a non-negative integer")
|
|
||||||
if not (0.0 <= self.top_p <= 1.0):
|
|
||||||
raise ValueError("top_p must be a float between 0.0 and 1.0")
|
|
||||||
if not (isinstance(self.temperature, (int, float)) and self.temperature >= 0):
|
|
||||||
raise ValueError("temperature must be a non-negative number")
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationRequest:
|
class GenerationRequest:
|
||||||
"""Request parameters for text generation."""
|
"""Request parameters for text generation."""
|
||||||
|
|
||||||
|
|
@ -85,34 +77,18 @@ class GenerationRequest:
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
max_len: int = 1024,
|
max_tokens: Optional[int] = None,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
):
|
):
|
||||||
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
|
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
self.params = GenerationParams(
|
self.top_k = top_k
|
||||||
top_k=top_k,
|
self.top_p = top_p
|
||||||
top_p=top_p,
|
self.temperature = temperature
|
||||||
temperature=temperature,
|
self.max_tokens = max_tokens
|
||||||
max_tokens=max_len,
|
|
||||||
)
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
|
|
||||||
@property
|
|
||||||
def top_k(self) -> int:
|
|
||||||
return self.params.top_k
|
|
||||||
|
|
||||||
@property
|
|
||||||
def top_p(self) -> float:
|
|
||||||
return self.params.top_p
|
|
||||||
|
|
||||||
@property
|
|
||||||
def temperature(self) -> float:
|
|
||||||
return self.params.temperature
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_len(self) -> int:
|
|
||||||
return self.params.max_tokens
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceEngine:
|
class InferenceEngine:
|
||||||
"""Unified inference engine backed by continuous-batching scheduler."""
|
"""Unified inference engine backed by continuous-batching scheduler."""
|
||||||
|
|
@ -150,37 +126,36 @@ class InferenceEngine:
|
||||||
self,
|
self,
|
||||||
prompt: Union[str, List[str]],
|
prompt: Union[str, List[str]],
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
max_tokens: int = 1024,
|
max_tokens: Optional[int] = None,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> Union[Generator, str, List[str]]:
|
) -> Union[Generator, str, List[str]]:
|
||||||
params = GenerationParams(
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
is_batch = isinstance(prompt, list)
|
is_batch = isinstance(prompt, list)
|
||||||
prompts = prompt if is_batch else [prompt]
|
prompts = prompt if is_batch else [prompt]
|
||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._generate_streaming(prompts, is_batch, params)
|
return self._generate_streaming(
|
||||||
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self._generate_non_streaming(prompts, is_batch, params)
|
return self._generate_non_streaming(
|
||||||
|
prompts, is_batch, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
def generate_async(
|
def generate_async(
|
||||||
self,
|
self,
|
||||||
prompt: str,
|
prompt: str,
|
||||||
params: Optional[GenerationParams] = None,
|
max_tokens: Optional[int] = None,
|
||||||
max_tokens: int = 1024,
|
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_p: float = 1.0,
|
top_p: float = 1.0,
|
||||||
top_k: int = 50,
|
top_k: int = 50,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
if params is None:
|
_validate_sampling_params(top_k, top_p, temperature, max_tokens)
|
||||||
params = GenerationParams(
|
sync_gen = self._generate_streaming(
|
||||||
top_k=top_k, top_p=top_p, temperature=temperature, max_tokens=max_tokens
|
[prompt], False, max_tokens, temperature, top_p, top_k
|
||||||
)
|
)
|
||||||
sync_gen = self._generate_streaming([prompt], False, params)
|
|
||||||
|
|
||||||
async def _agen():
|
async def _agen():
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
@ -206,14 +181,19 @@ class InferenceEngine:
|
||||||
return self.generate(
|
return self.generate(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=request.stream,
|
stream=request.stream,
|
||||||
max_tokens=request.params.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.params.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.params.top_p,
|
top_p=request.top_p,
|
||||||
top_k=request.params.top_k,
|
top_k=request.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _submit_tasks(
|
def _submit_tasks(
|
||||||
self, prompts: List[str], params: GenerationParams
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
) -> Tuple[GenerateResult, List[str]]:
|
) -> Tuple[GenerateResult, List[str]]:
|
||||||
n = len(prompts)
|
n = len(prompts)
|
||||||
result = GenerateResult(count=n)
|
result = GenerateResult(count=n)
|
||||||
|
|
@ -222,10 +202,10 @@ class InferenceEngine:
|
||||||
cb = self._make_callback(result, i)
|
cb = self._make_callback(result, i)
|
||||||
task_id = self.scheduler.add_task(
|
task_id = self.scheduler.add_task(
|
||||||
prompt=p,
|
prompt=p,
|
||||||
max_tokens=params.max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=params.temperature,
|
temperature=temperature,
|
||||||
top_p=params.top_p,
|
top_p=top_p,
|
||||||
top_k=params.top_k,
|
top_k=top_k,
|
||||||
stream_callback=cb,
|
stream_callback=cb,
|
||||||
)
|
)
|
||||||
task_ids.append(task_id)
|
task_ids.append(task_id)
|
||||||
|
|
@ -239,9 +219,17 @@ class InferenceEngine:
|
||||||
return cb
|
return cb
|
||||||
|
|
||||||
def _generate_streaming(
|
def _generate_streaming(
|
||||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
is_batch: bool,
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
result, task_ids = self._submit_tasks(prompts, params)
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
n = len(prompts)
|
n = len(prompts)
|
||||||
remaining = n
|
remaining = n
|
||||||
finished = [False] * n
|
finished = [False] * n
|
||||||
|
|
@ -267,9 +255,17 @@ class InferenceEngine:
|
||||||
return gen()
|
return gen()
|
||||||
|
|
||||||
def _generate_non_streaming(
|
def _generate_non_streaming(
|
||||||
self, prompts: List[str], is_batch: bool, params: GenerationParams
|
self,
|
||||||
|
prompts: List[str],
|
||||||
|
is_batch: bool,
|
||||||
|
max_tokens: Optional[int],
|
||||||
|
temperature: float,
|
||||||
|
top_p: float,
|
||||||
|
top_k: int,
|
||||||
) -> Union[str, List[str]]:
|
) -> Union[str, List[str]]:
|
||||||
result, task_ids = self._submit_tasks(prompts, params)
|
result, task_ids = self._submit_tasks(
|
||||||
|
prompts, max_tokens, temperature, top_p, top_k
|
||||||
|
)
|
||||||
|
|
||||||
result.wait_completion()
|
result.wait_completion()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.inference.core.cache import CacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
|
|
@ -147,7 +147,7 @@ class GQA(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tensor,
|
rotary_emb: Tensor,
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
is_causal = attn_mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
||||||
|
|
@ -227,7 +227,7 @@ class MLA(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tensor,
|
rotary_emb: Tensor,
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
bsz, seq_len, _ = x.size()
|
bsz, seq_len, _ = x.size()
|
||||||
is_causal = attn_mask is None
|
is_causal = attn_mask is None
|
||||||
|
|
@ -306,7 +306,7 @@ class DecoderBlock(nn.Module):
|
||||||
x: Tensor,
|
x: Tensor,
|
||||||
rotary_emb: Tensor,
|
rotary_emb: Tensor,
|
||||||
attention_mask: Optional[Tensor] = None,
|
attention_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
attn_output = self.attention(
|
attn_output = self.attention(
|
||||||
self.input_norm(x),
|
self.input_norm(x),
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from astrai.inference.core.cache import CacheView
|
from astrai.inference.core.cache import KvcacheView
|
||||||
from astrai.model.automodel import AutoModel
|
from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.module import (
|
from astrai.model.module import (
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
|
|
@ -122,7 +122,7 @@ class Transformer(AutoModel):
|
||||||
self,
|
self,
|
||||||
input_ids: Tensor,
|
input_ids: Tensor,
|
||||||
input_mask: Optional[Tensor] = None,
|
input_mask: Optional[Tensor] = None,
|
||||||
paged_cache: Optional[CacheView] = None,
|
paged_cache: Optional[KvcacheView] = None,
|
||||||
position_ids: Optional[Tensor] = None,
|
position_ids: Optional[Tensor] = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
assert input_ids.ndim == 2
|
assert input_ids.ndim == 2
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""Benchmark Transformer with PagedCache (replaces old persistent_key_values)."""
|
"""Benchmark Transformer with KVCache"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import ModelConfig
|
from astrai.config import ModelConfig
|
||||||
from astrai.inference import PagedCache
|
from astrai.inference import KVCache
|
||||||
from astrai.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,7 +33,7 @@ class GenerationBenchmark:
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
head_dim = config.dim // config.n_heads
|
head_dim = config.dim // config.n_heads
|
||||||
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
n_pages = (config.max_len * 4 + page_size - 1) // page_size
|
||||||
self._page_cache = PagedCache(
|
self._page_cache = KVCache(
|
||||||
config.n_layers,
|
config.n_layers,
|
||||||
n_pages,
|
n_pages,
|
||||||
page_size,
|
page_size,
|
||||||
|
|
@ -130,7 +130,12 @@ class GenerationBenchmark:
|
||||||
)
|
)
|
||||||
|
|
||||||
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
n_pages = (prompt_length + gen_length + page_size - 1) // page_size
|
||||||
pages = self._page_cache.alloc_n(n_pages * batch_size)
|
total = n_pages * batch_size
|
||||||
|
pages = []
|
||||||
|
for _ in range(total):
|
||||||
|
p = self._page_cache._pool.alloc()
|
||||||
|
assert p >= 0, "OOM"
|
||||||
|
pages.append(p)
|
||||||
page_table = torch.tensor(
|
page_table = torch.tensor(
|
||||||
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
[pages[i * n_pages : (i + 1) * n_pages] for i in range(batch_size)],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
|
|
@ -176,7 +181,7 @@ class GenerationBenchmark:
|
||||||
total_time += trial_time
|
total_time += trial_time
|
||||||
|
|
||||||
for idx in pages:
|
for idx in pages:
|
||||||
self._page_cache.free(idx)
|
self._page_cache._pool.free(idx)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
f" Trial {trial + 1}/{num_trials}: {gen_length} tokens in {trial_time:.3f}s "
|
||||||
|
|
@ -225,7 +230,7 @@ if __name__ == "__main__":
|
||||||
benchmark = GenerationBenchmark(config)
|
benchmark = GenerationBenchmark(config)
|
||||||
|
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
print("Running Transformer Generation Benchmark (PagedCache)")
|
print("Running Transformer Generation Benchmark (KVCache)")
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
|
|
||||||
prefill_result = benchmark.run_prefill_benchmark(
|
prefill_result = benchmark.run_prefill_benchmark(
|
||||||
|
|
|
||||||
|
|
@ -3,14 +3,20 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.inference import (
|
from astrai.inference import (
|
||||||
PagedCache,
|
Allocator,
|
||||||
|
KVCache,
|
||||||
PagePool,
|
PagePool,
|
||||||
PrefixCache,
|
PrefixCache,
|
||||||
|
Storage,
|
||||||
TaskTable,
|
TaskTable,
|
||||||
page_hash,
|
page_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_pool(n_pages: int, page_size: int) -> PagePool:
|
||||||
|
return PagePool(Allocator(n_pages), PrefixCache(page_size))
|
||||||
|
|
||||||
|
|
||||||
def test_page_hash_full_page():
|
def test_page_hash_full_page():
|
||||||
token_ids = list(range(256))
|
token_ids = list(range(256))
|
||||||
h = page_hash(token_ids, 0, 64)
|
h = page_hash(token_ids, 0, 64)
|
||||||
|
|
@ -24,7 +30,7 @@ def test_page_hash_different_page_differs():
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_alloc_free_cycle():
|
def test_page_pool_alloc_free_cycle():
|
||||||
pool = PagePool(n_pages=4)
|
pool = make_pool(4, 64)
|
||||||
a = pool.alloc()
|
a = pool.alloc()
|
||||||
b = pool.alloc()
|
b = pool.alloc()
|
||||||
assert a != b
|
assert a != b
|
||||||
|
|
@ -35,130 +41,101 @@ def test_page_pool_alloc_free_cycle():
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_alloc_when_full():
|
def test_page_pool_alloc_when_full():
|
||||||
pool = PagePool(n_pages=2)
|
pool = make_pool(2, 64)
|
||||||
pool.alloc()
|
pool.alloc()
|
||||||
pool.alloc()
|
pool.alloc()
|
||||||
assert pool.alloc() == -1
|
assert pool.alloc() == -1
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_lru_eviction():
|
def test_page_pool_lru_eviction():
|
||||||
evicted = []
|
pool = make_pool(2, 64)
|
||||||
|
|
||||||
def on_evict(idx):
|
|
||||||
evicted.append(idx)
|
|
||||||
|
|
||||||
pool = PagePool(n_pages=2, on_evict=on_evict)
|
|
||||||
p0 = pool.alloc()
|
p0 = pool.alloc()
|
||||||
p1 = pool.alloc()
|
p1 = pool.alloc()
|
||||||
pool.free(p0, keep_cached=True)
|
pool.record(p0, list(range(64)), 0)
|
||||||
pool.free(p1, keep_cached=True)
|
pool.record(p1, list(range(64, 128)), 0)
|
||||||
|
pool.free(p0)
|
||||||
|
pool.free(p1)
|
||||||
pool.alloc()
|
pool.alloc()
|
||||||
assert len(evicted) == 1
|
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
|
||||||
assert evicted[0] == p0
|
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_inc_ref_and_free():
|
def test_page_pool_inc_ref_and_free():
|
||||||
pool = PagePool(n_pages=2)
|
pool = make_pool(2, 64)
|
||||||
p = pool.alloc()
|
p = pool.alloc()
|
||||||
pool.inc_ref(p)
|
pool.inc_ref(p)
|
||||||
assert pool._refs[p] == 2
|
assert pool._alloc._refs[p] == 2
|
||||||
pool.free(p)
|
pool.free(p)
|
||||||
assert pool._refs[p] == 1
|
assert pool._alloc._refs[p] == 1
|
||||||
pool.free(p)
|
pool.free(p)
|
||||||
assert pool._refs[p] == 0
|
assert pool._alloc._refs[p] == 0
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_touch_moves_to_end():
|
|
||||||
pool = PagePool(n_pages=4)
|
|
||||||
p0 = pool.alloc()
|
|
||||||
p1 = pool.alloc()
|
|
||||||
p2 = pool.alloc()
|
|
||||||
pool.free(p0, keep_cached=True)
|
|
||||||
pool.free(p1, keep_cached=True)
|
|
||||||
pool.free(p2, keep_cached=True)
|
|
||||||
assert next(iter(pool._lru)) == p0
|
|
||||||
pool.touch(p0)
|
|
||||||
assert next(reversed(pool._lru)) == p0
|
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_remove_from_lru():
|
|
||||||
pool = PagePool(n_pages=4)
|
|
||||||
p0 = pool.alloc()
|
|
||||||
pool.free(p0, keep_cached=True)
|
|
||||||
assert p0 in pool._lru
|
|
||||||
pool.remove_from_lru(p0)
|
|
||||||
assert p0 not in pool._lru
|
|
||||||
|
|
||||||
|
|
||||||
def test_page_pool_keep_cached_realloc():
|
def test_page_pool_keep_cached_realloc():
|
||||||
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
||||||
pool = PagePool(n_pages=3)
|
pool = make_pool(3, 64)
|
||||||
p0 = pool.alloc()
|
p0 = pool.alloc()
|
||||||
p1 = pool.alloc()
|
p1 = pool.alloc()
|
||||||
p2 = pool.alloc()
|
p2 = pool.alloc()
|
||||||
pool.free(p0, keep_cached=True)
|
for p in (p0, p1, p2):
|
||||||
pool.free(p1, keep_cached=True)
|
pool.record(p, [p] * 64, 0)
|
||||||
pool.free(p2, keep_cached=True)
|
pool.free(p0)
|
||||||
|
pool.free(p1)
|
||||||
|
pool.free(p2)
|
||||||
assert pool.alloc() == p0
|
assert pool.alloc() == p0
|
||||||
|
|
||||||
|
|
||||||
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
|
|
||||||
"""Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU."""
|
|
||||||
prefix.record(page, token_ids, logical_idx, pool)
|
|
||||||
pool.free(page, keep_cached=True)
|
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_lookup_returns_hits():
|
def test_prefix_cache_lookup_returns_hits():
|
||||||
token_ids = list(range(256))
|
token_ids = list(range(256))
|
||||||
pool = PagePool(n_pages=16)
|
pool = make_pool(16, 64)
|
||||||
prefix = PrefixCache(page_size=64)
|
|
||||||
pages = [pool.alloc() for _ in range(4)]
|
pages = [pool.alloc() for _ in range(4)]
|
||||||
for i, p in enumerate(pages):
|
for i, p in enumerate(pages):
|
||||||
_record_then_cache(pool, prefix, p, token_ids, i)
|
pool.record(p, token_ids, i)
|
||||||
hits = prefix.lookup(token_ids, pool)
|
pool.free(p)
|
||||||
|
hits = pool.lookup(token_ids)
|
||||||
assert hits == pages
|
assert hits == pages
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_lookup_stops_at_first_miss():
|
def test_prefix_cache_lookup_stops_at_first_miss():
|
||||||
token_ids = list(range(256))
|
token_ids = list(range(256))
|
||||||
pool = PagePool(n_pages=16)
|
pool = make_pool(16, 64)
|
||||||
prefix = PrefixCache(page_size=64)
|
|
||||||
p0 = pool.alloc()
|
p0 = pool.alloc()
|
||||||
_record_then_cache(pool, prefix, p0, token_ids, 0)
|
pool.record(p0, token_ids, 0)
|
||||||
|
pool.free(p0)
|
||||||
p1 = pool.alloc()
|
p1 = pool.alloc()
|
||||||
_record_then_cache(pool, prefix, p1, [99] * 64, 1)
|
pool.record(p1, [99] * 64, 1)
|
||||||
hits = prefix.lookup(token_ids, pool)
|
pool.free(p1)
|
||||||
|
hits = pool.lookup(token_ids)
|
||||||
assert len(hits) == 1
|
assert len(hits) == 1
|
||||||
assert hits[0] == p0
|
assert hits[0] == p0
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_ignores_partial_last_page():
|
def test_prefix_cache_ignores_partial_last_page():
|
||||||
token_ids = list(range(100))
|
token_ids = list(range(100))
|
||||||
pool = PagePool(n_pages=16)
|
pool = make_pool(16, 64)
|
||||||
prefix = PrefixCache(page_size=64)
|
|
||||||
p = pool.alloc()
|
p = pool.alloc()
|
||||||
_record_then_cache(pool, prefix, p, token_ids, 0)
|
pool.record(p, token_ids, 0)
|
||||||
hits = prefix.lookup(token_ids, pool)
|
pool.free(p)
|
||||||
|
hits = pool.lookup(token_ids)
|
||||||
assert len(hits) == 1
|
assert len(hits) == 1
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_on_evict_clears_mappings():
|
def test_prefix_cache_on_evict_clears_mappings():
|
||||||
pool = PagePool(n_pages=4)
|
pool = make_pool(4, 64)
|
||||||
prefix = PrefixCache(page_size=64)
|
|
||||||
p = pool.alloc()
|
p = pool.alloc()
|
||||||
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
pool.record(p, list(range(64)), 0)
|
||||||
assert prefix.has_page(p)
|
pool.free(p)
|
||||||
prefix.on_evict(p)
|
assert p in pool._prefix._page_to_hash
|
||||||
assert not prefix.has_page(p)
|
pool._prefix.evict(p)
|
||||||
|
assert p not in pool._prefix._page_to_hash
|
||||||
|
|
||||||
|
|
||||||
def test_prefix_cache_has_page():
|
def test_prefix_cache_has_page():
|
||||||
pool = PagePool(n_pages=4)
|
pool = make_pool(4, 64)
|
||||||
prefix = PrefixCache(page_size=64)
|
|
||||||
p = pool.alloc()
|
p = pool.alloc()
|
||||||
assert not prefix.has_page(p)
|
assert p not in pool._prefix._page_to_hash
|
||||||
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
pool.record(p, list(range(64)), 0)
|
||||||
assert prefix.has_page(p)
|
pool.free(p)
|
||||||
|
assert p in pool._prefix._page_to_hash
|
||||||
|
|
||||||
|
|
||||||
def test_task_table_set_get():
|
def test_task_table_set_get():
|
||||||
|
|
@ -183,8 +160,8 @@ def test_task_table_pop():
|
||||||
assert table.get("task1") == []
|
assert table.get("task1") == []
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_task_extend_allocates():
|
def test_kv_cache_task_extend_allocates():
|
||||||
cache = PagedCache(
|
cache = KVCache(
|
||||||
n_layers=1,
|
n_layers=1,
|
||||||
n_pages=8,
|
n_pages=8,
|
||||||
page_size=64,
|
page_size=64,
|
||||||
|
|
@ -199,8 +176,8 @@ def test_paged_cache_task_extend_allocates():
|
||||||
assert len(cache._table.get("task1")) == 4
|
assert len(cache._table.get("task1")) == 4
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_task_extend_fails_when_pool_full():
|
def test_kv_cache_task_extend_fails_when_pool_full():
|
||||||
cache = PagedCache(
|
cache = KVCache(
|
||||||
n_layers=1,
|
n_layers=1,
|
||||||
n_pages=2,
|
n_pages=2,
|
||||||
page_size=64,
|
page_size=64,
|
||||||
|
|
@ -230,8 +207,8 @@ def test_task_table_table_tensor_empty_input():
|
||||||
assert t.numel() == 0
|
assert t.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_write_gather_single_page():
|
def test_storage_write_gather_single_page():
|
||||||
cache = PagedCache(
|
storage = Storage(
|
||||||
n_layers=2,
|
n_layers=2,
|
||||||
n_pages=8,
|
n_pages=8,
|
||||||
page_size=4,
|
page_size=4,
|
||||||
|
|
@ -244,13 +221,13 @@ def test_paged_cache_write_gather_single_page():
|
||||||
k = torch.randn(1, 2, 2, 8)
|
k = torch.randn(1, 2, 2, 8)
|
||||||
v = torch.randn(1, 2, 2, 8)
|
v = torch.randn(1, 2, 2, 8)
|
||||||
|
|
||||||
cache.write(0, page_table, 0, k, v)
|
storage.write(0, page_table, 0, k, v)
|
||||||
gk, gv = cache.gather(0, page_table, 2)
|
gk, gv = storage.gather(0, page_table, 2)
|
||||||
assert torch.allclose(gk, k)
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_write_cross_page():
|
def test_storage_write_cross_page():
|
||||||
cache = PagedCache(
|
storage = Storage(
|
||||||
n_layers=1,
|
n_layers=1,
|
||||||
n_pages=8,
|
n_pages=8,
|
||||||
page_size=4,
|
page_size=4,
|
||||||
|
|
@ -263,13 +240,13 @@ def test_paged_cache_write_cross_page():
|
||||||
k = torch.randn(1, 8, 2, 8)
|
k = torch.randn(1, 8, 2, 8)
|
||||||
v = torch.randn(1, 8, 2, 8)
|
v = torch.randn(1, 8, 2, 8)
|
||||||
|
|
||||||
cache.write(0, page_table, 0, k, v)
|
storage.write(0, page_table, 0, k, v)
|
||||||
gk, gv = cache.gather(0, page_table, 8)
|
gk, gv = storage.gather(0, page_table, 8)
|
||||||
assert torch.allclose(gk, k)
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_gather_truncates_to_total_len():
|
def test_storage_gather_truncates_to_total_len():
|
||||||
cache = PagedCache(
|
storage = Storage(
|
||||||
n_layers=1,
|
n_layers=1,
|
||||||
n_pages=8,
|
n_pages=8,
|
||||||
page_size=4,
|
page_size=4,
|
||||||
|
|
@ -281,14 +258,14 @@ def test_paged_cache_gather_truncates_to_total_len():
|
||||||
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||||
k = torch.randn(1, 6, 2, 8)
|
k = torch.randn(1, 6, 2, 8)
|
||||||
v = torch.randn(1, 6, 2, 8)
|
v = torch.randn(1, 6, 2, 8)
|
||||||
cache.write(0, page_table, 0, k, v)
|
storage.write(0, page_table, 0, k, v)
|
||||||
|
|
||||||
gk, gv = cache.gather(0, page_table, 5)
|
gk, gv = storage.gather(0, page_table, 5)
|
||||||
assert gk.shape == (1, 5, 2, 8)
|
assert gk.shape == (1, 5, 2, 8)
|
||||||
|
|
||||||
|
|
||||||
def test_paged_cache_gather_clamps_negative_padding():
|
def test_storage_gather_clamps_negative_padding():
|
||||||
cache = PagedCache(
|
storage = Storage(
|
||||||
n_layers=1,
|
n_layers=1,
|
||||||
n_pages=8,
|
n_pages=8,
|
||||||
page_size=4,
|
page_size=4,
|
||||||
|
|
@ -298,5 +275,5 @@ def test_paged_cache_gather_clamps_negative_padding():
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
||||||
gk, gv = cache.gather(0, page_table, 4)
|
gk, gv = storage.gather(0, page_table, 4)
|
||||||
assert gk.shape == (1, 4, 2, 8)
|
assert gk.shape == (1, 4, 2, 8)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue