refactor: PagedCache Facade 模式,提取 PagePool/PrefixCache/TaskTable
- cache.py: 提取 PagePool (位图+LRU)、PrefixCache (前缀哈希)、TaskTable (任务页表) PagedCache 降为 Facade 组合三者 + 张量存储,公开 API 不变 - executor.py: 移除 allocate_pages_for_activation/free_task_pages/get_cached_tokens 三冗余委托方法,去掉 page_size 构造参数(改用 page_cache.page_size) - scheduler.py: 直接调用 self._page_cache.* 代替已移除的 Executor 委托 - 移除 CacheView.__slots__、PagePool.ref_count、PagedCache.alloc/pages_needed/inc_ref PrefixCache.evict 等死/冗余方法
This commit is contained in:
parent
4753958f92
commit
38e18fdfd3
|
|
@ -1,4 +1,5 @@
|
|||
from typing import Dict, List, Tuple
|
||||
from collections import OrderedDict
|
||||
from typing import Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
|
@ -13,8 +14,136 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
|||
return h
|
||||
|
||||
|
||||
class PagePool:
|
||||
"""Bitmask page allocator with ref-counting and LRU eviction."""
|
||||
|
||||
def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._lru: OrderedDict[int, None] = OrderedDict()
|
||||
self._on_evict = on_evict
|
||||
|
||||
def alloc(self) -> int:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
if self._lru:
|
||||
idx, _ = self._lru.popitem(last=False)
|
||||
if self._on_evict:
|
||||
self._on_evict(idx)
|
||||
self._refs[idx] = 1
|
||||
self._free_mask &= ~(1 << idx)
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def free(self, idx: int, keep_cached: bool = False) -> None:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
if keep_cached:
|
||||
self._lru[idx] = None
|
||||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._refs[idx] += 1
|
||||
|
||||
def touch(self, idx: int) -> None:
|
||||
self._lru.move_to_end(idx)
|
||||
|
||||
def remove_from_lru(self, idx: int) -> None:
|
||||
self._lru.pop(idx, None)
|
||||
|
||||
|
||||
class PrefixCache:
|
||||
"""Hash-based prefix matching: maps page hashes to physical page indices."""
|
||||
|
||||
def __init__(self, page_size: int):
|
||||
self._page_size = page_size
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
|
||||
def on_evict(self, idx: int) -> None:
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
|
||||
def has_page(self, idx: int) -> bool:
|
||||
return idx in self._page_to_hash
|
||||
|
||||
def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
|
||||
full_pages = len(token_ids) // self._page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
h = page_hash(token_ids, i, self._page_size)
|
||||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
pool.touch(p)
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def record(
|
||||
self,
|
||||
page_idx: int,
|
||||
token_ids: List[int],
|
||||
logical_page_idx: int,
|
||||
pool: PagePool,
|
||||
) -> None:
|
||||
h = page_hash(token_ids, logical_page_idx, self._page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
pool.remove_from_lru(page_idx)
|
||||
|
||||
|
||||
class TaskTable:
|
||||
"""Maps task_ids to page tables and cached token counts."""
|
||||
|
||||
def __init__(self, pool: PagePool, page_size: int):
|
||||
self._pool = pool
|
||||
self._page_size = page_size
|
||||
self._pages: Dict[str, List[int]] = {}
|
||||
self._cached: Dict[str, int] = {}
|
||||
|
||||
def set(self, task_id: str, page_table: List[int], cached: int) -> None:
|
||||
self._pages[task_id] = page_table
|
||||
self._cached[task_id] = cached
|
||||
|
||||
def get(self, task_id: str) -> List[int]:
|
||||
return self._pages.get(task_id, [])
|
||||
|
||||
def get_cached(self, task_id: str) -> int:
|
||||
return self._cached.get(task_id, 0)
|
||||
|
||||
def pop(self, task_id: str) -> Tuple[List[int], int]:
|
||||
pages = self._pages.pop(task_id, [])
|
||||
cached = self._cached.pop(task_id, 0)
|
||||
return pages, cached
|
||||
|
||||
def extend(self, task_id: str, pos: int) -> bool:
|
||||
page_table = self._pages[task_id]
|
||||
needed = (pos + 1 + self._page_size - 1) // self._page_size
|
||||
while len(page_table) < needed:
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
|
||||
def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
states = [self._pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
|
||||
class PagedCache:
|
||||
"""Paged KV cache: page pool, prefix-cache lookup, LRU eviction, task-page mapping."""
|
||||
"""Facade: paged KV-cache backed by PagePool, PrefixCache, and TaskTable."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -27,8 +156,10 @@ class PagedCache:
|
|||
dtype: torch.dtype,
|
||||
):
|
||||
self.page_size = page_size
|
||||
self._free_mask = (1 << n_pages) - 1
|
||||
self._refs: List[int] = [0] * n_pages
|
||||
self._prefix = PrefixCache(page_size)
|
||||
self._pool = PagePool(n_pages, on_evict=self._prefix.on_evict)
|
||||
self._table = TaskTable(self._pool, page_size)
|
||||
|
||||
self.k_cache = torch.empty(
|
||||
(n_layers, n_pages, page_size, n_kv_heads, head_dim),
|
||||
device=device,
|
||||
|
|
@ -39,160 +170,81 @@ class PagedCache:
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
self._page_to_hash: Dict[int, int] = {}
|
||||
self._hash_to_page: Dict[int, int] = {}
|
||||
self._lru: List[int] = []
|
||||
self._pin: List[bool] = [False] * n_pages
|
||||
self._task_pages: Dict[str, List[int]] = {}
|
||||
self._task_cached: Dict[str, int] = {}
|
||||
|
||||
def pages_needed(self, n_tokens: int) -> int:
|
||||
return (n_tokens + self.page_size - 1) // self.page_size
|
||||
def alloc_n(self, n: int) -> List[int]:
|
||||
pages: List[int] = []
|
||||
for _ in range(n):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for page in pages:
|
||||
self.free(page)
|
||||
return []
|
||||
pages.append(p)
|
||||
return pages
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
cached = self._prefix.has_page(idx)
|
||||
self._pool.free(idx, keep_cached=cached)
|
||||
if not cached:
|
||||
self._prefix.on_evict(idx)
|
||||
|
||||
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||
hit_pages = self.lookup_prefix(prompt_ids)
|
||||
cached_tokens = len(hit_pages) * self.page_size
|
||||
for p in hit_pages:
|
||||
self.inc_ref(p)
|
||||
hits = self._prefix.lookup(prompt_ids, self._pool)
|
||||
cached = len(hits) * self.page_size
|
||||
for p in hits:
|
||||
self._pool.inc_ref(p)
|
||||
|
||||
remaining = len(prompt_ids) - cached_tokens
|
||||
n_new = self.pages_needed(remaining) if remaining > 0 else 0
|
||||
new_pages = self.alloc_n(n_new) if n_new > 0 else []
|
||||
remaining = len(prompt_ids) - cached
|
||||
n_new = (
|
||||
(remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0
|
||||
)
|
||||
new_pages: List[int] = []
|
||||
if n_new > 0:
|
||||
for _ in range(n_new):
|
||||
p = self._pool.alloc()
|
||||
if p < 0:
|
||||
for hp in hits:
|
||||
self.free(hp)
|
||||
for np in new_pages:
|
||||
self.free(np)
|
||||
return False
|
||||
new_pages.append(p)
|
||||
|
||||
if remaining > 0 and not new_pages:
|
||||
for p in hit_pages:
|
||||
self.free(p)
|
||||
return False
|
||||
|
||||
page_table = hit_pages + new_pages
|
||||
self._task_pages[task_id] = page_table
|
||||
self._task_cached[task_id] = cached_tokens
|
||||
self._table.set(task_id, hits + new_pages, cached)
|
||||
return True
|
||||
|
||||
def task_free(self, task_id: str) -> None:
|
||||
page_table = self._task_pages.pop(task_id, None)
|
||||
self._task_cached.pop(task_id, None)
|
||||
if page_table:
|
||||
for idx in page_table:
|
||||
self.free(idx)
|
||||
page_table, _ = self._table.pop(task_id)
|
||||
for idx in page_table:
|
||||
self.free(idx)
|
||||
|
||||
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||
needed = self.pages_needed(pos + 1)
|
||||
page_table = self._task_pages[task_id]
|
||||
while len(page_table) < needed:
|
||||
p = self.alloc()
|
||||
if p < 0:
|
||||
return False
|
||||
page_table.append(p)
|
||||
return True
|
||||
return self._table.extend(task_id, pos)
|
||||
|
||||
def task_cached(self, task_id: str) -> int:
|
||||
return self._task_cached.get(task_id, 0)
|
||||
|
||||
def task_page_table(self, task_id: str) -> List[int]:
|
||||
return self._task_pages.get(task_id, [])
|
||||
|
||||
def task_n_pages(self, task_id: str) -> int:
|
||||
return len(self._task_pages.get(task_id, []))
|
||||
return self._table.get_cached(task_id)
|
||||
|
||||
def task_record_hashes(
|
||||
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||
) -> None:
|
||||
page_table = self._task_pages[task_id]
|
||||
page_table = self._table.get(task_id)
|
||||
full_pages = len(prompt_ids) // self.page_size
|
||||
for i in range(start_logical_page, full_pages):
|
||||
self.record_page(page_table[i], prompt_ids, i)
|
||||
self._prefix.record(page_table[i], prompt_ids, i, self._pool)
|
||||
|
||||
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||
states = [self._task_pages.get(tid, []) for tid in task_ids]
|
||||
max_pages = max((len(s) for s in states), default=0)
|
||||
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||
|
||||
def _touch(self, idx: int) -> None:
|
||||
if self._refs[idx] == 0 and idx in self._lru:
|
||||
self._lru.remove(idx)
|
||||
self._lru.append(idx)
|
||||
|
||||
def _evict_one(self) -> int:
|
||||
while self._lru:
|
||||
idx = self._lru.pop(0)
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
self._pin[idx] = False
|
||||
self._refs[idx] = 1
|
||||
return idx
|
||||
return -1
|
||||
|
||||
def record_page(
|
||||
self, page_idx: int, token_ids: List[int], logical_page_idx: int
|
||||
) -> None:
|
||||
h = page_hash(token_ids, logical_page_idx, self.page_size)
|
||||
old_h = self._page_to_hash.pop(page_idx, None)
|
||||
if old_h is not None:
|
||||
self._hash_to_page.pop(old_h, None)
|
||||
self._page_to_hash[page_idx] = h
|
||||
self._hash_to_page[h] = page_idx
|
||||
self._pin[page_idx] = True
|
||||
if page_idx in self._lru:
|
||||
self._lru.remove(page_idx)
|
||||
|
||||
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
|
||||
full_pages = len(token_ids) // self.page_size
|
||||
hits: List[int] = []
|
||||
for i in range(full_pages):
|
||||
h = page_hash(token_ids, i, self.page_size)
|
||||
p = self._hash_to_page.get(h)
|
||||
if p is None:
|
||||
break
|
||||
self._touch(p)
|
||||
hits.append(p)
|
||||
return hits
|
||||
|
||||
def inc_ref(self, idx: int) -> None:
|
||||
self._refs[idx] += 1
|
||||
if self._refs[idx] == 1 and idx in self._lru:
|
||||
self._lru.remove(idx)
|
||||
|
||||
def alloc(self) -> int:
|
||||
if self._free_mask:
|
||||
lsb = self._free_mask & -self._free_mask
|
||||
idx = lsb.bit_length() - 1
|
||||
self._free_mask ^= lsb
|
||||
self._refs[idx] = 1
|
||||
if idx in self._lru:
|
||||
self._lru.remove(idx)
|
||||
return idx
|
||||
return self._evict_one()
|
||||
|
||||
def alloc_n(self, n: int) -> List[int]:
|
||||
pages = [self.alloc() for _ in range(n)]
|
||||
if any(p < 0 for p in pages):
|
||||
for p in pages:
|
||||
if p >= 0:
|
||||
self.free(p)
|
||||
return []
|
||||
return pages
|
||||
|
||||
def free(self, idx: int) -> None:
|
||||
self._refs[idx] -= 1
|
||||
if self._refs[idx] == 0:
|
||||
h = self._page_to_hash.get(idx)
|
||||
if h is not None and self._pin[idx]:
|
||||
self._lru.append(idx)
|
||||
else:
|
||||
self._free_mask |= 1 << idx
|
||||
h = self._page_to_hash.pop(idx, None)
|
||||
if h is not None:
|
||||
self._hash_to_page.pop(h, None)
|
||||
self._pin[idx] = False
|
||||
return self._table.table_tensor(task_ids, device)
|
||||
|
||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||
return CacheView(self, page_table, total_len)
|
||||
|
||||
def write(
|
||||
self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor
|
||||
self,
|
||||
layer_id: int,
|
||||
page_table: Tensor,
|
||||
start_pos: int,
|
||||
k: Tensor,
|
||||
v: Tensor,
|
||||
) -> None:
|
||||
seq_len = k.size(1)
|
||||
if seq_len == 0:
|
||||
|
|
@ -232,8 +284,6 @@ class PagedCache:
|
|||
class CacheView:
|
||||
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
||||
|
||||
__slots__ = ("_cache", "_page_table", "_total_len")
|
||||
|
||||
def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0):
|
||||
self._cache = cache
|
||||
self._page_table = page_table
|
||||
|
|
|
|||
|
|
@ -13,33 +13,22 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class Executor:
|
||||
"""Model forward passes for prefill and decode; delegates page ops to PagedCache."""
|
||||
"""Model forward passes for prefill and decode phases."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
page_cache: PagedCache,
|
||||
page_size: int = 64,
|
||||
device: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.page_cache = page_cache
|
||||
self.page_size = page_size
|
||||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
def allocate_pages_for_activation(self, task: Task) -> bool:
|
||||
return self.page_cache.task_alloc(task.task_id, task.prompt_ids)
|
||||
|
||||
def free_task_pages(self, task: Task) -> None:
|
||||
self.page_cache.task_free(task.task_id)
|
||||
|
||||
def get_cached_tokens(self, task: Task) -> int:
|
||||
return self.page_cache.task_cached(task.task_id)
|
||||
|
||||
def execute_prefill(
|
||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||
) -> None:
|
||||
|
|
@ -71,7 +60,7 @@ class Executor:
|
|||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||
)
|
||||
|
||||
start_logical_page = start_pos // self.page_size
|
||||
start_logical_page = start_pos // self.page_cache.page_size
|
||||
for t in tasks:
|
||||
self.page_cache.task_record_hashes(
|
||||
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
|
||||
|
|
|
|||
|
|
@ -33,19 +33,16 @@ class InferenceScheduler:
|
|||
self.device = device or next(model.parameters()).device
|
||||
self.dtype = dtype or next(model.parameters()).dtype
|
||||
|
||||
n_kv_heads = config.n_kv_heads
|
||||
head_dim = config.dim // config.n_heads
|
||||
n_layers = config.n_layers
|
||||
n_pages = (
|
||||
max_batch_size * (self.max_seq_len + page_size) + page_size - 1
|
||||
) // page_size
|
||||
|
||||
page_cache = PagedCache(
|
||||
n_layers,
|
||||
self._page_cache = PagedCache(
|
||||
config.n_layers,
|
||||
n_pages,
|
||||
page_size,
|
||||
n_kv_heads,
|
||||
head_dim,
|
||||
config.n_kv_heads,
|
||||
config.dim // config.n_heads,
|
||||
self.device,
|
||||
self.dtype,
|
||||
)
|
||||
|
|
@ -60,8 +57,7 @@ class InferenceScheduler:
|
|||
self._executor = Executor(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
page_cache=page_cache,
|
||||
page_size=page_size,
|
||||
page_cache=self._page_cache,
|
||||
device=self.device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
|
@ -73,7 +69,7 @@ class InferenceScheduler:
|
|||
|
||||
def remove_task(self, task_id: str) -> None:
|
||||
for task in self._task_mgr.remove_task(task_id):
|
||||
self._executor.free_task_pages(task)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
return self._task_mgr.get_stats()
|
||||
|
|
@ -85,7 +81,7 @@ class InferenceScheduler:
|
|||
self._task_mgr.tokenizer.stop_ids
|
||||
)
|
||||
for task in finished:
|
||||
self._executor.free_task_pages(task)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
|
||||
available = self._task_mgr.max_batch_size - len(
|
||||
self._task_mgr.active_tasks
|
||||
|
|
@ -94,7 +90,7 @@ class InferenceScheduler:
|
|||
candidates = self._task_mgr.pull_candidates(available)
|
||||
failed = []
|
||||
for task in candidates:
|
||||
if self._executor.allocate_pages_for_activation(task):
|
||||
if self._page_cache.task_alloc(task.task_id, task.prompt_ids):
|
||||
self._task_mgr.activate(task)
|
||||
else:
|
||||
failed.append(task)
|
||||
|
|
@ -114,7 +110,10 @@ class InferenceScheduler:
|
|||
|
||||
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||
for t in to_prefill:
|
||||
key = (len(t.prompt_ids), self._executor.get_cached_tokens(t))
|
||||
key = (
|
||||
len(t.prompt_ids),
|
||||
self._page_cache.task_cached(t.task_id),
|
||||
)
|
||||
groups.setdefault(key, []).append(t)
|
||||
|
||||
for (prompt_len, start_pos), group in groups.items():
|
||||
|
|
|
|||
Loading…
Reference in New Issue