perf: PagedCache 持久前缀缓存 + LRU 逐出

- astrai/inference/cache.py: refcount 归零时保留 hash 映射,页加入 LRU evictable 池
- alloc() 优先从 _free_mask 分配空闲页,仅当没有空闲页时才从 LRU 逐出
- lookup_prefix/inc_ref 触发 _touch 更新 LRU 序
- record_page 设置 pin 标记并从 LRU 移除
This commit is contained in:
ViperEkura 2026-05-10 18:05:11 +08:00
parent 133a9de98f
commit 3da428e0e4
1 changed file with 46 additions and 13 deletions

View File

@ -22,14 +22,16 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
class PagedCache: class PagedCache:
"""Paged KV cache with page-table-indirected read/write. """Paged KV cache with page-table-indirected read/write and persistent prefix caching.
Combines: Combines:
- Page pool (ref-counted alloc/free via bitmask) - Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache) - KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx) - Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Call :meth:`bind` to obtain a batch view for the attention layers. Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
""" """
def __init__( def __init__(
@ -57,6 +59,24 @@ class PagedCache:
) )
self._page_to_hash: Dict[int, int] = {} self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages
def _touch(self, idx: int) -> None:
    """Mark an idle cached page as most recently used.

    Moves ``idx`` to the tail of ``self._lru``. Eviction pops from the head
    of that list, so the tail is the position evicted last. Pages that are
    still referenced (``self._refs[idx] != 0``) or that are not currently in
    the LRU list are left untouched.

    Args:
        idx: Physical page index.
    """
    # NOTE(review): the original indentation was lost in the diff rendering;
    # the append is taken to belong inside the guard, so a page is only
    # re-queued when it was already evictable -- confirm against the repo.
    if self._refs[idx] != 0:
        return
    try:
        # Single scan: remove() both tests membership and deletes the entry.
        self._lru.remove(idx)
    except ValueError:
        return
    self._lru.append(idx)
def _evict_one(self) -> int:
    """Evict the least recently used cached page and hand it to the caller.

    Pops the head of ``self._lru``, drops the page's entries from both hash
    maps, clears its pin flag and sets its refcount to 1, so the page is
    returned already owned by the caller.

    Returns:
        The physical index of the evicted page, or ``-1`` when the LRU list
        is empty and nothing can be evicted.
    """
    # The original wrapped this in ``while self._lru:`` but returned
    # unconditionally on the first pass; a guard clause states the same thing.
    if not self._lru:
        return -1

    victim = self._lru.pop(0)

    # Forget the prefix-cache mapping in both directions, if one was recorded.
    stale_hash = self._page_to_hash.pop(victim, None)
    if stale_hash is not None:
        self._hash_to_page.pop(stale_hash, None)

    self._pin[victim] = False
    self._refs[victim] = 1
    return victim
def record_page( def record_page(
self, page_idx: int, token_ids: List[int], logical_page_idx: int self, page_idx: int, token_ids: List[int], logical_page_idx: int
@ -67,6 +87,9 @@ class PagedCache:
self._hash_to_page.pop(old_h, None) self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx self._hash_to_page[h] = page_idx
self._pin[page_idx] = True
if page_idx in self._lru:
self._lru.remove(page_idx)
def lookup_prefix(self, token_ids: List[int]) -> List[int]: def lookup_prefix(self, token_ids: List[int]) -> List[int]:
full_pages = len(token_ids) // self.page_size full_pages = len(token_ids) // self.page_size
@ -76,20 +99,25 @@ class PagedCache:
p = self._hash_to_page.get(h) p = self._hash_to_page.get(h)
if p is None: if p is None:
break break
self._touch(p)
hits.append(p) hits.append(p)
return hits return hits
def inc_ref(self, idx: int) -> None: def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1 self._refs[idx] += 1
if self._refs[idx] == 1 and idx in self._lru:
self._lru.remove(idx)
def alloc(self) -> int: def alloc(self) -> int:
if self._free_mask:
lsb = self._free_mask & -self._free_mask lsb = self._free_mask & -self._free_mask
if lsb == 0:
return -1
idx = lsb.bit_length() - 1 idx = lsb.bit_length() - 1
self._free_mask ^= lsb self._free_mask ^= lsb
self._refs[idx] = 1 self._refs[idx] = 1
if idx in self._lru:
self._lru.remove(idx)
return idx return idx
return self._evict_one()
def alloc_n(self, n: int) -> List[int]: def alloc_n(self, n: int) -> List[int]:
pages = [self.alloc() for _ in range(n)] pages = [self.alloc() for _ in range(n)]
@ -103,10 +131,15 @@ class PagedCache:
def free(self, idx: int) -> None: def free(self, idx: int) -> None:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
h = self._page_to_hash.get(idx)
if h is not None and self._pin[idx]:
self._lru.append(idx)
else:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
h = self._page_to_hash.pop(idx, None) h = self._page_to_hash.pop(idx, None)
if h is not None: if h is not None:
self._hash_to_page.pop(h, None) self._hash_to_page.pop(h, None)
self._pin[idx] = False
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len) return CacheView(self, page_table, total_len)