refactor: 页状态移入 PagedCache,Task 纯化为域对象

- PagedCache 增 task_alloc/task_free/task_extend/task_cached/task_record_hashes/make_table_tensor
- Task 移除 page_table/n_pages/_prefix_cached_tokens/_pages_freed
- Executor 移除 _PageState,页操作全部委托 PagedCache
- CacheView.gather 截断逻辑下沉到 PagedCache.gather
- 各类补充单行职责 docstring
This commit is contained in:
ViperEkura 2026-05-11 14:42:39 +08:00
parent 73d6cc0f26
commit 4753958f92
4 changed files with 98 additions and 90 deletions

View File

@ -1,9 +1,3 @@
"""Page-based KV cache with page-table-indirected read/write.
Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import torch import torch
@ -20,17 +14,7 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
class PagedCache: class PagedCache:
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching. """Paged KV cache: page pool, prefix-cache lookup, LRU eviction, task-page mapping."""
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
- LRU eviction for persistent cross-batch prefix caching
Pages with recorded hashes persist after refcount reaches 0 (pinned).
They are evicted via LRU only when alloc() finds no free pages.
"""
def __init__( def __init__(
self, self,
@ -59,6 +43,71 @@ class PagedCache:
self._hash_to_page: Dict[int, int] = {} self._hash_to_page: Dict[int, int] = {}
self._lru: List[int] = [] self._lru: List[int] = []
self._pin: List[bool] = [False] * n_pages self._pin: List[bool] = [False] * n_pages
self._task_pages: Dict[str, List[int]] = {}
self._task_cached: Dict[str, int] = {}
def pages_needed(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
hit_pages = self.lookup_prefix(prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.inc_ref(p)
remaining = len(prompt_ids) - cached_tokens
n_new = self.pages_needed(remaining) if remaining > 0 else 0
new_pages = self.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.free(p)
return False
page_table = hit_pages + new_pages
self._task_pages[task_id] = page_table
self._task_cached[task_id] = cached_tokens
return True
def task_free(self, task_id: str) -> None:
page_table = self._task_pages.pop(task_id, None)
self._task_cached.pop(task_id, None)
if page_table:
for idx in page_table:
self.free(idx)
def task_extend(self, task_id: str, pos: int) -> bool:
needed = self.pages_needed(pos + 1)
page_table = self._task_pages[task_id]
while len(page_table) < needed:
p = self.alloc()
if p < 0:
return False
page_table.append(p)
return True
def task_cached(self, task_id: str) -> int:
return self._task_cached.get(task_id, 0)
def task_page_table(self, task_id: str) -> List[int]:
return self._task_pages.get(task_id, [])
def task_n_pages(self, task_id: str) -> int:
return len(self._task_pages.get(task_id, []))
def task_record_hashes(
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
page_table = self._task_pages[task_id]
full_pages = len(prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.record_page(page_table[i], prompt_ids, i)
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
states = [self._task_pages.get(tid, []) for tid in task_ids]
max_pages = max((len(s) for s in states), default=0)
rows = [s + [-1] * (max_pages - len(s)) for s in states]
return torch.tensor(rows, dtype=torch.long, device=device)
def _touch(self, idx: int) -> None: def _touch(self, idx: int) -> None:
if self._refs[idx] == 0 and idx in self._lru: if self._refs[idx] == 0 and idx in self._lru:
@ -167,23 +216,21 @@ class PagedCache:
] ]
written += chunk written += chunk
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]: def gather(
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages. self, layer_id: int, page_table: Tensor, total_len: int
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len. ) -> Tuple[Tensor, Tensor]:
safe = page_table.clamp(min=0) safe = page_table.clamp(min=0)
k = self.k_cache[layer_id, safe] k = self.k_cache[layer_id, safe]
v = self.v_cache[layer_id, safe] v = self.v_cache[layer_id, safe]
k = k.flatten(1, 2) k = k.flatten(1, 2)
v = v.flatten(1, 2) v = v.flatten(1, 2)
k = k[:, :total_len]
v = v[:, :total_len]
return k, v return k, v
class CacheView: class CacheView:
"""Per-batch view that bundles PagedCache + page_table + total_len. """Bundles PagedCache + page_table + total_len for attention layers."""
Attention layers receive this as ``paged_cache`` and only see
``write()`` / ``gather()``, never raw page tables or length params.
"""
__slots__ = ("_cache", "_page_table", "_total_len") __slots__ = ("_cache", "_page_table", "_total_len")
@ -196,8 +243,4 @@ class CacheView:
self._cache.write(layer_id, self._page_table, start_pos, k, v) self._cache.write(layer_id, self._page_table, start_pos, k, v)
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]: def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
k, v = self._cache.gather(layer_id, self._page_table) return self._cache.gather(layer_id, self._page_table, self._total_len)
if self._total_len:
k = k[:, : self._total_len]
v = v[:, : self._total_len]
return k, v

View File

@ -2,7 +2,6 @@ import logging
from typing import List, Optional from typing import List, Optional
import torch import torch
from torch import Tensor
from astrai.inference.cache import PagedCache from astrai.inference.cache import PagedCache
from astrai.inference.sample import sample from astrai.inference.sample import sample
@ -14,6 +13,8 @@ logger = logging.getLogger(__name__)
class Executor: class Executor:
"""Model forward passes for prefill and decode; delegates page ops to PagedCache."""
def __init__( def __init__(
self, self,
model: AutoModel, model: AutoModel,
@ -31,34 +32,13 @@ class Executor:
self.dtype = dtype or next(model.parameters()).dtype self.dtype = dtype or next(model.parameters()).dtype
def allocate_pages_for_activation(self, task: Task) -> bool: def allocate_pages_for_activation(self, task: Task) -> bool:
prompt_len = len(task.prompt_ids) return self.page_cache.task_alloc(task.task_id, task.prompt_ids)
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
return False
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
return True
def free_task_pages(self, task: Task) -> None: def free_task_pages(self, task: Task) -> None:
if task._pages_freed: self.page_cache.task_free(task.task_id)
return
for idx in task.page_table: def get_cached_tokens(self, task: Task) -> int:
self.page_cache.free(idx) return self.page_cache.task_cached(task.task_id)
task.page_table.clear()
task.n_pages = 0
task._pages_freed = True
def execute_prefill( def execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0 self, tasks: List[Task], prompt_len: int, start_pos: int = 0
@ -80,7 +60,8 @@ class Executor:
t.prompt_ids[start_pos:prompt_len], device=self.device t.prompt_ids[start_pos:prompt_len], device=self.device
) )
page_tables = self._make_page_table_tensor(tasks) task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
with torch.inference_mode(): with torch.inference_mode():
self.model( self.model(
@ -92,7 +73,9 @@ class Executor:
start_logical_page = start_pos // self.page_size start_logical_page = start_pos // self.page_size
for t in tasks: for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page) self.page_cache.task_record_hashes(
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
)
def execute_decode(self, tasks: List[Task], start_pos: int) -> None: def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks: if not tasks:
@ -102,7 +85,7 @@ class Executor:
valid: List[Task] = [] valid: List[Task] = []
for t in tasks: for t in tasks:
if self._maybe_alloc_page(t, start_pos): if self.page_cache.task_extend(t.task_id, start_pos):
valid.append(t) valid.append(t)
else: else:
t.status = TaskStatus.ABORTED t.status = TaskStatus.ABORTED
@ -123,7 +106,8 @@ class Executor:
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device) active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
page_tables = self._make_page_table_tensor(tasks) task_ids = [t.task_id for t in tasks]
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
total_len = start_pos + 1 total_len = start_pos + 1
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
@ -150,7 +134,7 @@ class Executor:
t.output_ids.append(ntok) t.output_ids.append(ntok)
t.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens pos = t.input_tokens + t.output_tokens
self._maybe_alloc_page(t, pos) self.page_cache.task_extend(t.task_id, pos)
if t.stream_callback: if t.stream_callback:
t.stream_callback(self.tokenizer.decode([ntok])) t.stream_callback(self.tokenizer.decode([ntok]))
@ -158,26 +142,3 @@ class Executor:
if t.is_finished(self.tokenizer.stop_ids): if t.is_finished(self.tokenizer.stop_ids):
if t.stream_callback: if t.stream_callback:
t.stream_callback(STOP) t.stream_callback(STOP)
def _n_pages_for(self, n_tokens: int) -> int:
return (n_tokens + self.page_size - 1) // self.page_size
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
max_pages = max(t.n_pages for t in tasks)
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
return torch.tensor(rows, dtype=torch.long, device=self.device)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
needed = self._n_pages_for(pos + 1)
while task.n_pages < needed:
p = self.page_cache.alloc()
if p < 0:
return False
task.page_table.append(p)
task.n_pages += 1
return True

View File

@ -14,6 +14,8 @@ logger = logging.getLogger(__name__)
class InferenceScheduler: class InferenceScheduler:
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
def __init__( def __init__(
self, self,
model: AutoModel, model: AutoModel,
@ -112,7 +114,7 @@ class InferenceScheduler:
groups: Dict[Tuple[int, int], List[Task]] = {} groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill: for t in to_prefill:
key = (len(t.prompt_ids), t._prefix_cached_tokens) key = (len(t.prompt_ids), self._executor.get_cached_tokens(t))
groups.setdefault(key, []).append(t) groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items(): for (prompt_len, start_pos), group in groups.items():

View File

@ -13,6 +13,8 @@ STOP = object()
class TaskStatus(Enum): class TaskStatus(Enum):
"""Task lifecycle states."""
PENDING = "pending" PENDING = "pending"
RUNNING = "running" RUNNING = "running"
FINISHED = "finished" FINISHED = "finished"
@ -20,6 +22,8 @@ class TaskStatus(Enum):
class Task: class Task:
"""Single generation request: prompt, sampling params, output state."""
def __init__( def __init__(
self, self,
task_id: str, task_id: str,
@ -41,13 +45,9 @@ class Task:
self.output_ids: List[int] = [] self.output_ids: List[int] = []
self.input_tokens: int = 0 self.input_tokens: int = 0
self.output_tokens: int = 0 self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
self._pages_freed: bool = False
@property @property
def next_pos(self) -> int: def next_pos(self) -> int:
@ -62,6 +62,8 @@ class Task:
class TaskManager: class TaskManager:
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
def __init__( def __init__(
self, self,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,