perf: PagedCache 持久前缀缓存 + LRU 逐出
- astrai/inference/cache.py: refcount 归零时保留 hash 映射,页加入 LRU evictable 池 - alloc() 无空闲页时从 LRU 逐出,优先释放 _free_mask - lookup_prefix/inc_ref 触发 _touch 更新 LRU 序 - record_page 设置 pin 标记并从 LRU 移除
This commit is contained in:
parent
133a9de98f
commit
3da428e0e4
|
|
@ -22,14 +22,16 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
|
|
||||||
|
|
||||||
class PagedCache:
|
class PagedCache:
|
||||||
"""Paged KV cache with page-table-indirected read/write.
|
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching.
|
||||||
|
|
||||||
Combines:
|
Combines:
|
||||||
- Page pool (ref-counted alloc/free via bitmask)
|
- Page pool (ref-counted alloc/free via bitmask)
|
||||||
- KV tensor storage (k_cache, v_cache)
|
- KV tensor storage (k_cache, v_cache)
|
||||||
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
|
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
|
||||||
|
- LRU eviction for persistent cross-batch prefix caching
|
||||||
|
|
||||||
Call :meth:`bind` to obtain a batch view for the attention layers.
|
Pages with recorded hashes persist after refcount reaches 0 (pinned).
|
||||||
|
They are evicted via LRU only when alloc() finds no free pages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -57,6 +59,24 @@ class PagedCache:
|
||||||
)
|
)
|
||||||
self._page_to_hash: Dict[int, int] = {}
|
self._page_to_hash: Dict[int, int] = {}
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
|
self._lru: List[int] = []
|
||||||
|
self._pin: List[bool] = [False] * n_pages
|
||||||
|
|
||||||
|
def _touch(self, idx: int) -> None:
|
||||||
|
if self._refs[idx] == 0 and idx in self._lru:
|
||||||
|
self._lru.remove(idx)
|
||||||
|
self._lru.append(idx)
|
||||||
|
|
||||||
|
def _evict_one(self) -> int:
|
||||||
|
while self._lru:
|
||||||
|
idx = self._lru.pop(0)
|
||||||
|
h = self._page_to_hash.pop(idx, None)
|
||||||
|
if h is not None:
|
||||||
|
self._hash_to_page.pop(h, None)
|
||||||
|
self._pin[idx] = False
|
||||||
|
self._refs[idx] = 1
|
||||||
|
return idx
|
||||||
|
return -1
|
||||||
|
|
||||||
def record_page(
|
def record_page(
|
||||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||||
|
|
@ -67,6 +87,9 @@ class PagedCache:
|
||||||
self._hash_to_page.pop(old_h, None)
|
self._hash_to_page.pop(old_h, None)
|
||||||
self._page_to_hash[page_idx] = h
|
self._page_to_hash[page_idx] = h
|
||||||
self._hash_to_page[h] = page_idx
|
self._hash_to_page[h] = page_idx
|
||||||
|
self._pin[page_idx] = True
|
||||||
|
if page_idx in self._lru:
|
||||||
|
self._lru.remove(page_idx)
|
||||||
|
|
||||||
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
|
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
|
||||||
full_pages = len(token_ids) // self.page_size
|
full_pages = len(token_ids) // self.page_size
|
||||||
|
|
@ -76,20 +99,25 @@ class PagedCache:
|
||||||
p = self._hash_to_page.get(h)
|
p = self._hash_to_page.get(h)
|
||||||
if p is None:
|
if p is None:
|
||||||
break
|
break
|
||||||
|
self._touch(p)
|
||||||
hits.append(p)
|
hits.append(p)
|
||||||
return hits
|
return hits
|
||||||
|
|
||||||
def inc_ref(self, idx: int) -> None:
|
def inc_ref(self, idx: int) -> None:
|
||||||
self._refs[idx] += 1
|
self._refs[idx] += 1
|
||||||
|
if self._refs[idx] == 1 and idx in self._lru:
|
||||||
|
self._lru.remove(idx)
|
||||||
|
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
lsb = self._free_mask & -self._free_mask
|
if self._free_mask:
|
||||||
if lsb == 0:
|
lsb = self._free_mask & -self._free_mask
|
||||||
return -1
|
idx = lsb.bit_length() - 1
|
||||||
idx = lsb.bit_length() - 1
|
self._free_mask ^= lsb
|
||||||
self._free_mask ^= lsb
|
self._refs[idx] = 1
|
||||||
self._refs[idx] = 1
|
if idx in self._lru:
|
||||||
return idx
|
self._lru.remove(idx)
|
||||||
|
return idx
|
||||||
|
return self._evict_one()
|
||||||
|
|
||||||
def alloc_n(self, n: int) -> List[int]:
|
def alloc_n(self, n: int) -> List[int]:
|
||||||
pages = [self.alloc() for _ in range(n)]
|
pages = [self.alloc() for _ in range(n)]
|
||||||
|
|
@ -103,10 +131,15 @@ class PagedCache:
|
||||||
def free(self, idx: int) -> None:
|
def free(self, idx: int) -> None:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
self._free_mask |= 1 << idx
|
h = self._page_to_hash.get(idx)
|
||||||
h = self._page_to_hash.pop(idx, None)
|
if h is not None and self._pin[idx]:
|
||||||
if h is not None:
|
self._lru.append(idx)
|
||||||
self._hash_to_page.pop(h, None)
|
else:
|
||||||
|
self._free_mask |= 1 << idx
|
||||||
|
h = self._page_to_hash.pop(idx, None)
|
||||||
|
if h is not None:
|
||||||
|
self._hash_to_page.pop(h, None)
|
||||||
|
self._pin[idx] = False
|
||||||
|
|
||||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||||
return CacheView(self, page_table, total_len)
|
return CacheView(self, page_table, total_len)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue