From 3da428e0e4f91e7db9aa02781be9a2613919c12d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 10 May 2026 18:05:11 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20PagedCache=20=E6=8C=81=E4=B9=85?= =?UTF-8?q?=E5=89=8D=E7=BC=80=E7=BC=93=E5=AD=98=20+=20LRU=20=E9=80=90?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - astrai/inference/cache.py: refcount 归零时保留 hash 映射,页加入 LRU evictable 池 - alloc() 优先使用 _free_mask 中的空闲页,无空闲页时才从 LRU 逐出 - lookup_prefix 触发 _touch 更新 LRU 序;inc_ref 将重新被引用的页从 LRU 移除 - record_page 设置 pin 标记并从 LRU 移除 --- astrai/inference/cache.py | 59 ++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 1b17d14..f548bc8 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -22,14 +22,16 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: class PagedCache: - """Paged KV cache with page-table-indirected read/write. + """Paged KV cache with page-table-indirected read/write and persistent prefix caching. Combines: - Page pool (ref-counted alloc/free via bitmask) - KV tensor storage (k_cache, v_cache) - Prefix-cache hash lookup (page_content_hash -> physical_page_idx) + - LRU eviction for persistent cross-batch prefix caching - Call :meth:`bind` to obtain a batch view for the attention layers. + Pages with recorded hashes persist after refcount reaches 0 (pinned). + They are evicted via LRU only when alloc() finds no free pages. 
""" def __init__( @@ -57,6 +59,24 @@ class PagedCache: ) self._page_to_hash: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {} + self._lru: List[int] = [] + self._pin: List[bool] = [False] * n_pages + + def _touch(self, idx: int) -> None: + if self._refs[idx] == 0 and idx in self._lru: + self._lru.remove(idx) + self._lru.append(idx) + + def _evict_one(self) -> int: + while self._lru: + idx = self._lru.pop(0) + h = self._page_to_hash.pop(idx, None) + if h is not None: + self._hash_to_page.pop(h, None) + self._pin[idx] = False + self._refs[idx] = 1 + return idx + return -1 def record_page( self, page_idx: int, token_ids: List[int], logical_page_idx: int @@ -67,6 +87,9 @@ class PagedCache: self._hash_to_page.pop(old_h, None) self._page_to_hash[page_idx] = h self._hash_to_page[h] = page_idx + self._pin[page_idx] = True + if page_idx in self._lru: + self._lru.remove(page_idx) def lookup_prefix(self, token_ids: List[int]) -> List[int]: full_pages = len(token_ids) // self.page_size @@ -76,20 +99,25 @@ class PagedCache: p = self._hash_to_page.get(h) if p is None: break + self._touch(p) hits.append(p) return hits def inc_ref(self, idx: int) -> None: self._refs[idx] += 1 + if self._refs[idx] == 1 and idx in self._lru: + self._lru.remove(idx) def alloc(self) -> int: - lsb = self._free_mask & -self._free_mask - if lsb == 0: - return -1 - idx = lsb.bit_length() - 1 - self._free_mask ^= lsb - self._refs[idx] = 1 - return idx + if self._free_mask: + lsb = self._free_mask & -self._free_mask + idx = lsb.bit_length() - 1 + self._free_mask ^= lsb + self._refs[idx] = 1 + if idx in self._lru: + self._lru.remove(idx) + return idx + return self._evict_one() def alloc_n(self, n: int) -> List[int]: pages = [self.alloc() for _ in range(n)] @@ -103,10 +131,15 @@ class PagedCache: def free(self, idx: int) -> None: self._refs[idx] -= 1 if self._refs[idx] == 0: - self._free_mask |= 1 << idx - h = self._page_to_hash.pop(idx, None) - if h is not None: - self._hash_to_page.pop(h, None) + 
h = self._page_to_hash.get(idx) + if h is not None and self._pin[idx]: + self._lru.append(idx) + else: + self._free_mask |= 1 << idx + h = self._page_to_hash.pop(idx, None) + if h is not None: + self._hash_to_page.pop(h, None) + self._pin[idx] = False def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": return CacheView(self, page_table, total_len)