refactor: 页状态移入 PagedCache,Task 纯化为域对象
- PagedCache 增 task_alloc/task_free/task_extend/task_cached/task_record_hashes/make_table_tensor - Task 移除 page_table/n_pages/_prefix_cached_tokens/_pages_freed - Executor 移除 _PageState,页操作全部委托 PagedCache - CacheView.gather 截断逻辑下沉到 PagedCache.gather - 各类补充单行职责 docstring
This commit is contained in:
parent
73d6cc0f26
commit
4753958f92
|
|
@ -1,9 +1,3 @@
|
||||||
"""Page-based KV cache with page-table-indirected read/write.
|
|
||||||
|
|
||||||
Provides:
|
|
||||||
- PagedCache: paged KV cache combining page pool and tensor storage.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -20,17 +14,7 @@ def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
|
||||||
|
|
||||||
|
|
||||||
class PagedCache:
|
class PagedCache:
|
||||||
"""Paged KV cache with page-table-indirected read/write and persistent prefix caching.
|
"""Paged KV cache: page pool, prefix-cache lookup, LRU eviction, task-page mapping."""
|
||||||
|
|
||||||
Combines:
|
|
||||||
- Page pool (ref-counted alloc/free via bitmask)
|
|
||||||
- KV tensor storage (k_cache, v_cache)
|
|
||||||
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
|
|
||||||
- LRU eviction for persistent cross-batch prefix caching
|
|
||||||
|
|
||||||
Pages with recorded hashes persist after refcount reaches 0 (pinned).
|
|
||||||
They are evicted via LRU only when alloc() finds no free pages.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -59,6 +43,71 @@ class PagedCache:
|
||||||
self._hash_to_page: Dict[int, int] = {}
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
self._lru: List[int] = []
|
self._lru: List[int] = []
|
||||||
self._pin: List[bool] = [False] * n_pages
|
self._pin: List[bool] = [False] * n_pages
|
||||||
|
self._task_pages: Dict[str, List[int]] = {}
|
||||||
|
self._task_cached: Dict[str, int] = {}
|
||||||
|
|
||||||
|
def pages_needed(self, n_tokens: int) -> int:
|
||||||
|
return (n_tokens + self.page_size - 1) // self.page_size
|
||||||
|
|
||||||
|
def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
|
||||||
|
hit_pages = self.lookup_prefix(prompt_ids)
|
||||||
|
cached_tokens = len(hit_pages) * self.page_size
|
||||||
|
for p in hit_pages:
|
||||||
|
self.inc_ref(p)
|
||||||
|
|
||||||
|
remaining = len(prompt_ids) - cached_tokens
|
||||||
|
n_new = self.pages_needed(remaining) if remaining > 0 else 0
|
||||||
|
new_pages = self.alloc_n(n_new) if n_new > 0 else []
|
||||||
|
|
||||||
|
if remaining > 0 and not new_pages:
|
||||||
|
for p in hit_pages:
|
||||||
|
self.free(p)
|
||||||
|
return False
|
||||||
|
|
||||||
|
page_table = hit_pages + new_pages
|
||||||
|
self._task_pages[task_id] = page_table
|
||||||
|
self._task_cached[task_id] = cached_tokens
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_free(self, task_id: str) -> None:
|
||||||
|
page_table = self._task_pages.pop(task_id, None)
|
||||||
|
self._task_cached.pop(task_id, None)
|
||||||
|
if page_table:
|
||||||
|
for idx in page_table:
|
||||||
|
self.free(idx)
|
||||||
|
|
||||||
|
def task_extend(self, task_id: str, pos: int) -> bool:
|
||||||
|
needed = self.pages_needed(pos + 1)
|
||||||
|
page_table = self._task_pages[task_id]
|
||||||
|
while len(page_table) < needed:
|
||||||
|
p = self.alloc()
|
||||||
|
if p < 0:
|
||||||
|
return False
|
||||||
|
page_table.append(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def task_cached(self, task_id: str) -> int:
|
||||||
|
return self._task_cached.get(task_id, 0)
|
||||||
|
|
||||||
|
def task_page_table(self, task_id: str) -> List[int]:
|
||||||
|
return self._task_pages.get(task_id, [])
|
||||||
|
|
||||||
|
def task_n_pages(self, task_id: str) -> int:
|
||||||
|
return len(self._task_pages.get(task_id, []))
|
||||||
|
|
||||||
|
def task_record_hashes(
|
||||||
|
self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
|
||||||
|
) -> None:
|
||||||
|
page_table = self._task_pages[task_id]
|
||||||
|
full_pages = len(prompt_ids) // self.page_size
|
||||||
|
for i in range(start_logical_page, full_pages):
|
||||||
|
self.record_page(page_table[i], prompt_ids, i)
|
||||||
|
|
||||||
|
def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
|
||||||
|
states = [self._task_pages.get(tid, []) for tid in task_ids]
|
||||||
|
max_pages = max((len(s) for s in states), default=0)
|
||||||
|
rows = [s + [-1] * (max_pages - len(s)) for s in states]
|
||||||
|
return torch.tensor(rows, dtype=torch.long, device=device)
|
||||||
|
|
||||||
def _touch(self, idx: int) -> None:
|
def _touch(self, idx: int) -> None:
|
||||||
if self._refs[idx] == 0 and idx in self._lru:
|
if self._refs[idx] == 0 and idx in self._lru:
|
||||||
|
|
@ -167,23 +216,21 @@ class PagedCache:
|
||||||
]
|
]
|
||||||
written += chunk
|
written += chunk
|
||||||
|
|
||||||
def gather(self, layer_id: int, page_table: Tensor) -> Tuple[Tensor, Tensor]:
|
def gather(
|
||||||
# page_table: [batch, max_pages] with -1 padding for tasks with fewer pages.
|
self, layer_id: int, page_table: Tensor, total_len: int
|
||||||
# clamp(min=0) maps -1 to page 0 (irrelevant data) — truncated by CacheView total_len.
|
) -> Tuple[Tensor, Tensor]:
|
||||||
safe = page_table.clamp(min=0)
|
safe = page_table.clamp(min=0)
|
||||||
k = self.k_cache[layer_id, safe]
|
k = self.k_cache[layer_id, safe]
|
||||||
v = self.v_cache[layer_id, safe]
|
v = self.v_cache[layer_id, safe]
|
||||||
k = k.flatten(1, 2)
|
k = k.flatten(1, 2)
|
||||||
v = v.flatten(1, 2)
|
v = v.flatten(1, 2)
|
||||||
|
k = k[:, :total_len]
|
||||||
|
v = v[:, :total_len]
|
||||||
return k, v
|
return k, v
|
||||||
|
|
||||||
|
|
||||||
class CacheView:
|
class CacheView:
|
||||||
"""Per-batch view that bundles PagedCache + page_table + total_len.
|
"""Bundles PagedCache + page_table + total_len for attention layers."""
|
||||||
|
|
||||||
Attention layers receive this as ``paged_cache`` and only see
|
|
||||||
``write()`` / ``gather()``, never raw page tables or length params.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("_cache", "_page_table", "_total_len")
|
__slots__ = ("_cache", "_page_table", "_total_len")
|
||||||
|
|
||||||
|
|
@ -196,8 +243,4 @@ class CacheView:
|
||||||
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
self._cache.write(layer_id, self._page_table, start_pos, k, v)
|
||||||
|
|
||||||
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def gather(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
k, v = self._cache.gather(layer_id, self._page_table)
|
return self._cache.gather(layer_id, self._page_table, self._total_len)
|
||||||
if self._total_len:
|
|
||||||
k = k[:, : self._total_len]
|
|
||||||
v = v[:, : self._total_len]
|
|
||||||
return k, v
|
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
|
||||||
|
|
||||||
from astrai.inference.cache import PagedCache
|
from astrai.inference.cache import PagedCache
|
||||||
from astrai.inference.sample import sample
|
from astrai.inference.sample import sample
|
||||||
|
|
@ -14,6 +13,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Executor:
|
class Executor:
|
||||||
|
"""Model forward passes for prefill and decode; delegates page ops to PagedCache."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: AutoModel,
|
model: AutoModel,
|
||||||
|
|
@ -31,34 +32,13 @@ class Executor:
|
||||||
self.dtype = dtype or next(model.parameters()).dtype
|
self.dtype = dtype or next(model.parameters()).dtype
|
||||||
|
|
||||||
def allocate_pages_for_activation(self, task: Task) -> bool:
|
def allocate_pages_for_activation(self, task: Task) -> bool:
|
||||||
prompt_len = len(task.prompt_ids)
|
return self.page_cache.task_alloc(task.task_id, task.prompt_ids)
|
||||||
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
|
|
||||||
cached_tokens = len(hit_pages) * self.page_size
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.inc_ref(p)
|
|
||||||
|
|
||||||
remaining = prompt_len - cached_tokens
|
|
||||||
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
|
|
||||||
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
|
|
||||||
|
|
||||||
if remaining > 0 and not new_pages:
|
|
||||||
for p in hit_pages:
|
|
||||||
self.page_cache.free(p)
|
|
||||||
return False
|
|
||||||
|
|
||||||
task.page_table = hit_pages + new_pages
|
|
||||||
task.n_pages = len(task.page_table)
|
|
||||||
task._prefix_cached_tokens = cached_tokens
|
|
||||||
return True
|
|
||||||
|
|
||||||
def free_task_pages(self, task: Task) -> None:
|
def free_task_pages(self, task: Task) -> None:
|
||||||
if task._pages_freed:
|
self.page_cache.task_free(task.task_id)
|
||||||
return
|
|
||||||
for idx in task.page_table:
|
def get_cached_tokens(self, task: Task) -> int:
|
||||||
self.page_cache.free(idx)
|
return self.page_cache.task_cached(task.task_id)
|
||||||
task.page_table.clear()
|
|
||||||
task.n_pages = 0
|
|
||||||
task._pages_freed = True
|
|
||||||
|
|
||||||
def execute_prefill(
|
def execute_prefill(
|
||||||
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||||
|
|
@ -80,7 +60,8 @@ class Executor:
|
||||||
t.prompt_ids[start_pos:prompt_len], device=self.device
|
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
self.model(
|
self.model(
|
||||||
|
|
@ -92,7 +73,9 @@ class Executor:
|
||||||
|
|
||||||
start_logical_page = start_pos // self.page_size
|
start_logical_page = start_pos // self.page_size
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
self._record_page_hashes(t, start_logical_page=start_logical_page)
|
self.page_cache.task_record_hashes(
|
||||||
|
t.task_id, t.prompt_ids, start_logical_page=start_logical_page
|
||||||
|
)
|
||||||
|
|
||||||
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
def execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
|
|
@ -102,7 +85,7 @@ class Executor:
|
||||||
|
|
||||||
valid: List[Task] = []
|
valid: List[Task] = []
|
||||||
for t in tasks:
|
for t in tasks:
|
||||||
if self._maybe_alloc_page(t, start_pos):
|
if self.page_cache.task_extend(t.task_id, start_pos):
|
||||||
valid.append(t)
|
valid.append(t)
|
||||||
else:
|
else:
|
||||||
t.status = TaskStatus.ABORTED
|
t.status = TaskStatus.ABORTED
|
||||||
|
|
@ -123,7 +106,8 @@ class Executor:
|
||||||
|
|
||||||
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
active_mask = torch.ones((batch_sz, 1), dtype=torch.bool, device=self.device)
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
task_ids = [t.task_id for t in tasks]
|
||||||
|
page_tables = self.page_cache.make_table_tensor(task_ids, self.device)
|
||||||
total_len = start_pos + 1
|
total_len = start_pos + 1
|
||||||
|
|
||||||
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
temperatures = torch.tensor([t.temperature for t in tasks], device=self.device)
|
||||||
|
|
@ -150,7 +134,7 @@ class Executor:
|
||||||
t.output_ids.append(ntok)
|
t.output_ids.append(ntok)
|
||||||
t.output_tokens += 1
|
t.output_tokens += 1
|
||||||
pos = t.input_tokens + t.output_tokens
|
pos = t.input_tokens + t.output_tokens
|
||||||
self._maybe_alloc_page(t, pos)
|
self.page_cache.task_extend(t.task_id, pos)
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback(self.tokenizer.decode([ntok]))
|
t.stream_callback(self.tokenizer.decode([ntok]))
|
||||||
|
|
||||||
|
|
@ -158,26 +142,3 @@ class Executor:
|
||||||
if t.is_finished(self.tokenizer.stop_ids):
|
if t.is_finished(self.tokenizer.stop_ids):
|
||||||
if t.stream_callback:
|
if t.stream_callback:
|
||||||
t.stream_callback(STOP)
|
t.stream_callback(STOP)
|
||||||
|
|
||||||
def _n_pages_for(self, n_tokens: int) -> int:
|
|
||||||
return (n_tokens + self.page_size - 1) // self.page_size
|
|
||||||
|
|
||||||
def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor:
|
|
||||||
max_pages = max(t.n_pages for t in tasks)
|
|
||||||
rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks]
|
|
||||||
return torch.tensor(rows, dtype=torch.long, device=self.device)
|
|
||||||
|
|
||||||
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
|
|
||||||
full_pages = len(task.prompt_ids) // self.page_size
|
|
||||||
for i in range(start_logical_page, full_pages):
|
|
||||||
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
|
|
||||||
|
|
||||||
def _maybe_alloc_page(self, task: Task, pos: int) -> bool:
|
|
||||||
needed = self._n_pages_for(pos + 1)
|
|
||||||
while task.n_pages < needed:
|
|
||||||
p = self.page_cache.alloc()
|
|
||||||
if p < 0:
|
|
||||||
return False
|
|
||||||
task.page_table.append(p)
|
|
||||||
task.n_pages += 1
|
|
||||||
return True
|
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,8 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class InferenceScheduler:
|
class InferenceScheduler:
|
||||||
|
"""Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: AutoModel,
|
model: AutoModel,
|
||||||
|
|
@ -112,7 +114,7 @@ class InferenceScheduler:
|
||||||
|
|
||||||
groups: Dict[Tuple[int, int], List[Task]] = {}
|
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||||
for t in to_prefill:
|
for t in to_prefill:
|
||||||
key = (len(t.prompt_ids), t._prefix_cached_tokens)
|
key = (len(t.prompt_ids), self._executor.get_cached_tokens(t))
|
||||||
groups.setdefault(key, []).append(t)
|
groups.setdefault(key, []).append(t)
|
||||||
|
|
||||||
for (prompt_len, start_pos), group in groups.items():
|
for (prompt_len, start_pos), group in groups.items():
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,8 @@ STOP = object()
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(Enum):
|
class TaskStatus(Enum):
|
||||||
|
"""Task lifecycle states."""
|
||||||
|
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
FINISHED = "finished"
|
FINISHED = "finished"
|
||||||
|
|
@ -20,6 +22,8 @@ class TaskStatus(Enum):
|
||||||
|
|
||||||
|
|
||||||
class Task:
|
class Task:
|
||||||
|
"""Single generation request: prompt, sampling params, output state."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
|
|
@ -41,13 +45,9 @@ class Task:
|
||||||
self.output_ids: List[int] = []
|
self.output_ids: List[int] = []
|
||||||
self.input_tokens: int = 0
|
self.input_tokens: int = 0
|
||||||
self.output_tokens: int = 0
|
self.output_tokens: int = 0
|
||||||
self.page_table: List[int] = []
|
|
||||||
self.n_pages: int = 0
|
|
||||||
self._prefix_cached_tokens: int = 0
|
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
self.stream_callback = stream_callback
|
self.stream_callback = stream_callback
|
||||||
self._pages_freed: bool = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def next_pos(self) -> int:
|
def next_pos(self) -> int:
|
||||||
|
|
@ -62,6 +62,8 @@ class Task:
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
|
"""Thread-safe task queues and lifecycle transitions (no page ops)."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tokenizer: AutoTokenizer,
|
tokenizer: AutoTokenizer,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue