refactor: 页状态移入 PagedCache,Task 纯化为域对象

- PagedCache 增 task_alloc/task_free/task_extend/task_cached/task_record_hashes/make_table_tensor
- Task 移除 page_table/n_pages/_prefix_cached_tokens/_pages_freed
- Executor 移除 _PageState,页操作全部委托 PagedCache
- CacheView.gather 截断逻辑下沉到 PagedCache.gather
- 各类补充单行职责 docstring
This commit is contained in:
ViperEkura 2026-05-11 14:42:39 +08:00
parent 73d6cc0f26
commit 4753958f92
4 changed files with 98 additions and 90 deletions

View File

@ -1,9 +1,3 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import Dict, List, Tuple
import torch
@ -20,17 +14,7 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
class PagedCache:
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
"""
"""Paged KV cache: page pool, prefix-cache lookup, LRU eviction, task-page mapping."""
def __init__(
self,
@ -59,6 +43,71 @@ class PagedCache:
self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages
self._task_pages: Dict[str, List[int]] = {}
self._task_cached: Dict[str, int] = {}
def pages_needed(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hit_pages = self.lookup_prefix(prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.inc_ref(p)
remaining = len(prompt_ids) - cached_tokens
n_new = self.pages_needed(remaining) if remaining > 0 else 0
new_pages = self.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.free(p)
return False
page_table = hit_pages + new_pages
self._task_pages[task_id] = page_table
self._task_cached[task_id] = cached_tokens
return True
def task_free(self, task_id: str) -> None:
page_table = self._task_pages.pop(task_id, None)
self._task_cached.pop(task_id, None)
if page_table:
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
needed = self.pages_needed(pos + 1)
page_table = self._task_pages[task_id]
while len(page_table) < needed:
p = self.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._task_cached.get(task_id, 0)
def task_page_table(self, task_id: str) -> List[int]:
return self._task_pages.get(task_id, [])
def task_n_pages(self, task_id: str) -> int:
return len(self._task_pages.get(task_id, []))
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._task_pages[task_id]
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.record_page(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._task_pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
def _touch(self, idx: int) -> None:
if self._refs[idx] == 0 and idx in self._lru:
@ -167,23 +216,21 @@ class PagedCache:
]
written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
def gather(
self, layer_id: int, page_table: Tensor, total_len: int
) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2)
v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v
class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len.
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
"""Bundles PagedCache + page_table + total_len for attention layers."""
__slots__ = ("_cache", "_page_table", "_total_len")
@ -196,8 +243,4 @@ class CacheView:
self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v
return self._cache.gather(layer_id, self._page_table, self._total_len)

View File

@ -2,7 +2,6 @@ import logging
from typing import List, Optional
import torch
from torch import Tensor
from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample
@ -14,6 +13,8 @@ logger = logging.getLogger(__name__)
class Executor:
"""Model forward passes for prefill and decode; delegates page ops to PagedCache."""
def __init__(
self,
model: AutoModel,
@ -31,34 +32,13 @@ class Executor:
self.dtype = dtype or next(model.parameters()).dtype
def allocate_pages_for_activation(self, task: Task) -> bool:
prompt_len = len(task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
return False
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
return True
return self.page_cache.task_alloc(task.task_id, task.prompt_ids)
def free_task_pages(self, task: Task) -> None:
if task._pages_freed:
return
for idx in task.page_table:
self.page_cache.free(idx)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
self.page_cache.task_free(task.task_id)
def get_cached_tokens(self, task: Task) -> int:
return self.page_cache.task_cached(task.task_id)
def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
@ -80,7 +60,8 @@ class Executor:
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode():
self.model(
@ -92,7 +73,9 @@ class Executor:
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
@ -102,7 +85,7 @@ class Executor:
valid: List[Task] = []
for t in tasks:
if self._maybe_alloc_page(t, start_pos):
if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t)
else:
t.status = TaskStatus.ABORTED
@ -123,7 +106,8 @@ class Executor:
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks)
task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
@ -150,7 +134,7 @@ class Executor:
t.output_ids.append(ntok)
t.output_tokens += 1
pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos)
self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok]))
@ -158,26 +142,3 @@ class Executor:
if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback:
t.stream_callback(STOP)
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
return False
task.page_table.append(p)
task.n_pages += 1
return True

View File

@ -14,6 +14,8 @@ logger = logging.getLogger(__name__)
class InferenceScheduler:
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
def __init__(
self,
model: AutoModel,
@ -112,7 +114,7 @@ class InferenceScheduler:
groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill:
key = (len(t.prompt_ids), t._prefix_cached_tokens)
key = (len(t.prompt_ids), self._executor.get_cached_tokens(t))
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():

View File

@ -13,6 +13,8 @@ STOP = object()
class TaskStatus(Enum):
"""Task lifecycle states."""
PENDING = "pending"
RUNNING = "running"
FINISHED = "finished"
@ -20,6 +22,8 @@ class TaskStatus(Enum):
class Task:
"""Single generation request: prompt, sampling params, output state."""
def __init__(
self,
task_id: str,
@ -41,13 +45,9 @@ class Task:
self.output_ids: List[int] = []
self.input_tokens: int = 0
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
self._pages_freed: bool = False
@property
def next_pos(self) -> int:
@ -62,6 +62,8 @@ class Task:
class TaskManager:
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
def __init__(
self,
tokenizer: AutoTokenizer,