297 lines
9.6 KiB
Python
297 lines
9.6 KiB
Python
from collections import OrderedDict
|
|
from typing import Callable, Dict, List, Optional, Tuple
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
|
|
|
|
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
|
start = page_idx * page_size
|
|
end = min(start + page_size, len(token_ids))
|
|
h = 0
|
|
for i in range(start, end):
|
|
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
|
|
return h
|
|
|
|
|
|
class PagePool:
|
|
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
|
|
|
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
|
self._free_mask = (1 << n_pages) - 1
|
|
self._refs: List[int] = [0] * n_pages
|
|
self._lru: OrderedDict[int, None] = OrderedDict()
|
|
self._on_evict = on_evict
|
|
|
|
def alloc(self) -> int:
|
|
if self._free_mask:
|
|
lsb = self._free_mask & -self._free_mask
|
|
idx = lsb.bit_length() - 1
|
|
self._free_mask ^= lsb
|
|
self._refs[idx] = 1
|
|
return idx
|
|
if self._lru:
|
|
idx, _ = self._lru.popitem(last=False)
|
|
if self._on_evict:
|
|
self._on_evict(idx)
|
|
self._refs[idx] = 1
|
|
self._free_mask &= ~(1 << idx)
|
|
return idx
|
|
return -1
|
|
|
|
def free(self, idx: int, keep_cached: bool = False) -> None:
|
|
self._refs[idx] -= 1
|
|
if self._refs[idx] == 0:
|
|
if keep_cached:
|
|
self._lru[idx] = None
|
|
else:
|
|
self._free_mask |= 1 << idx
|
|
|
|
def inc_ref(self, idx: int) -> None:
|
|
self._refs[idx] += 1
|
|
|
|
def touch(self, idx: int) -> None:
|
|
self._lru.move_to_end(idx)
|
|
|
|
def remove_from_lru(self, idx: int) -> None:
|
|
self._lru.pop(idx, None)
|
|
|
|
|
|
class PrefixCache:
|
|
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
|
|
|
def __init__(self, page_size: int):
|
|
self._page_size = page_size
|
|
self._page_to_hash: Dict[int, int] = {}
|
|
self._hash_to_page: Dict[int, int] = {}
|
|
|
|
def on_evict(self, idx: int) -> None:
|
|
h = self._page_to_hash.pop(idx, None)
|
|
if h is not None:
|
|
self._hash_to_page.pop(h, None)
|
|
|
|
def has_page(self, idx: int) -> bool:
|
|
return idx in self._page_to_hash
|
|
|
|
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
|
full_pages = len(token_ids) // self._page_size
|
|
hits: List[int] = []
|
|
for i in range(full_pages):
|
|
h = page_hash(token_ids, i, self._page_size)
|
|
p = self._hash_to_page.get(h)
|
|
if p is None:
|
|
break
|
|
pool.touch(p)
|
|
hits.append(p)
|
|
return hits
|
|
|
|
def record(
|
|
self,
|
|
page_idx: int,
|
|
token_ids: List[int],
|
|
logical_page_idx: int,
|
|
pool: PagePool,
|
|
) -> None:
|
|
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
|
old_h = self._page_to_hash.pop(page_idx, None)
|
|
if old_h is not None:
|
|
self._hash_to_page.pop(old_h, None)
|
|
self._page_to_hash[page_idx] = h
|
|
self._hash_to_page[h] = page_idx
|
|
pool.remove_from_lru(page_idx)
|
|
|
|
|
|
class TaskTable:
|
|
"""Maps task_ids to page tables and cached token counts."""
|
|
|
|
def __init__(self, pool: PagePool, page_size: int):
|
|
self._pool = pool
|
|
self._page_size = page_size
|
|
self._pages: Dict[str, List[int]] = {}
|
|
self._cached: Dict[str, int] = {}
|
|
|
|
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
|
self._pages[task_id] = page_table
|
|
self._cached[task_id] = cached
|
|
|
|
def get(self, task_id: str) -> List[int]:
|
|
return self._pages.get(task_id, [])
|
|
|
|
def get_cached(self, task_id: str) -> int:
|
|
return self._cached.get(task_id, 0)
|
|
|
|
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
|
pages = self._pages.pop(task_id, [])
|
|
cached = self._cached.pop(task_id, 0)
|
|
return pages, cached
|
|
|
|
def extend(self, task_id: str, pos: int) -> bool:
|
|
page_table = self._pages[task_id]
|
|
needed = (pos + 1 + self._page_size - 1) // self._page_size
|
|
while len(page_table) < needed:
|
|
p = self._pool.alloc()
|
|
if p < 0:
|
|
return False
|
|
page_table.append(p)
|
|
return True
|
|
|
|
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
|
states = [self._pages.get(tid, []) for tid in task_ids]
|
|
max_pages = max((len(s) for s in states), default=0)
|
|
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
|
return torch.tensor(rows, dtype=torch.long, device=device)
|
|
|
|
|
|
class PagedCache:
|
|
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
|
|
|
def __init__(
|
|
self,
|
|
n_layers: int,
|
|
n_pages: int,
|
|
page_size: int,
|
|
n_kv_heads: int,
|
|
head_dim: int,
|
|
device: torch.device,
|
|
dtype: torch.dtype,
|
|
):
|
|
self.page_size = page_size
|
|
self._prefix = PrefixCache(page_size)
|
|
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
|
self._table = TaskTable(self._pool, page_size)
|
|
|
|
self.k_cache = torch.empty(
|
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.v_cache = torch.empty(
|
|
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
|
|
def alloc_n(self, n: int) -> List[int]:
|
|
pages: List[int] = []
|
|
for _ in range(n):
|
|
p = self._pool.alloc()
|
|
if p < 0:
|
|
for page in pages:
|
|
self.free(page)
|
|
return []
|
|
pages.append(p)
|
|
return pages
|
|
|
|
def free(self, idx: int) -> None:
|
|
cached = self._prefix.has_page(idx)
|
|
self._pool.free(idx, keep_cached=cached)
|
|
if not cached:
|
|
self._prefix.on_evict(idx)
|
|
|
|
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
|
hits = self._prefix.lookup(prompt_ids, self._pool)
|
|
cached = len(hits) * self.page_size
|
|
for p in hits:
|
|
self._pool.inc_ref(p)
|
|
|
|
remaining = len(prompt_ids) - cached
|
|
n_new = (
|
|
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
|
)
|
|
new_pages: List[int] = []
|
|
if n_new > 0:
|
|
for _ in range(n_new):
|
|
p = self._pool.alloc()
|
|
if p < 0:
|
|
for hp in hits:
|
|
self.free(hp)
|
|
for np in new_pages:
|
|
self.free(np)
|
|
return False
|
|
new_pages.append(p)
|
|
|
|
self._table.set(task_id, hits + new_pages, cached)
|
|
return True
|
|
|
|
def task_free(self, task_id: str) -> None:
|
|
page_table, _ = self._table.pop(task_id)
|
|
for idx in page_table:
|
|
self.free(idx)
|
|
|
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
|
return self._table.extend(task_id, pos)
|
|
|
|
def task_cached(self, task_id: str) -> int:
|
|
return self._table.get_cached(task_id)
|
|
|
|
def task_record_hashes(
|
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
|
) -> None:
|
|
page_table = self._table.get(task_id)
|
|
full_pages = len(prompt_ids) // self.page_size
|
|
for i in range(start_logical_page, full_pages):
|
|
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
|
|
|
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
|
return self._table.table_tensor(task_ids, device)
|
|
|
|
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
|
return CacheView(self, page_table, total_len)
|
|
|
|
def write(
|
|
self,
|
|
layer_id: int,
|
|
page_table: Tensor,
|
|
start_pos: int,
|
|
k: Tensor,
|
|
v: Tensor,
|
|
) -> None:
|
|
seq_len = k.size(1)
|
|
if seq_len == 0:
|
|
return
|
|
page_size = self.page_size
|
|
written = 0
|
|
first_page = start_pos // page_size
|
|
last_page = (start_pos + seq_len - 1) // page_size
|
|
for pi in range(first_page, last_page + 1):
|
|
phys_pages = page_table[:, pi]
|
|
page_start = pi * page_size
|
|
write_start = max(page_start, start_pos)
|
|
write_end = min(page_start + page_size, start_pos + seq_len)
|
|
offset = write_start - page_start
|
|
chunk = write_end - write_start
|
|
self.k_cache[layer_id, phys_pages, offset : offset + chunk] = k[
|
|
:, written : written + chunk
|
|
]
|
|
self.v_cache[layer_id, phys_pages, offset : offset + chunk] = v[
|
|
:, written : written + chunk
|
|
]
|
|
written += chunk
|
|
|
|
def gather(
|
|
self, layer_id: int, page_table: Tensor, total_len: int
|
|
) -> Tuple[Tensor, Tensor]:
|
|
safe = page_table.clamp(min=0)
|
|
k = self.k_cache[layer_id, safe]
|
|
v = self.v_cache[layer_id, safe]
|
|
k = k.flatten(1, 2)
|
|
v = v.flatten(1, 2)
|
|
k = k[:, :total_len]
|
|
v = v[:, :total_len]
|
|
return k, v
|
|
|
|
|
|
class CacheView:
|
|
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
|
|
|
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
|
self._cache = cache
|
|
self._page_table = page_table
|
|
self._total_len = total_len
|
|
|
|
def write(self, layer_id: int, start_pos: int, k: Tensor, v: Tensor) -> None:
|
|
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
|
|
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
|
return self._cache.gather(layer_id, self._page_table, self._total_len)
|