feat: 推理引擎前缀缓存(KV cache 复用)

- cache.py: 新增模块级 page_hash() 多项式滚动哈希函数;PagedCache 新增
  record_page/lookup_prefix/inc_ref,free() 自动清理哈希映射
- scheduler.py: Task 新增 _prefix_cached_tokens;_refill_active_batch 先查
  缓存命中页(inc_ref)再分配剩余页;合并 _execute_prefill 为单一方法,
  按 (prompt_len, start_pos) 分组批量执行全量/部分 prefill;
  _record_page_hashes 注册完整页哈希;修复 device/dtype 默认值从硬编码
  改为 None(自动检测模型设备)
- test: mock model 补充 dtype/device 适配自动检测
This commit is contained in:
ViperEkura 2026-05-09 23:53:23 +08:00
parent ca4e6b907c
commit 3583c46b66
3 changed files with 96 additions and 40 deletions

View File

@ -4,7 +4,7 @@ Provides:
- PagedCache: paged KV cache combining page pool and tensor storage.
"""
from typing import List, Tuple
from typing import Dict, List, Tuple
import torch
from torch import Tensor
@ -12,12 +12,22 @@ from torch import Tensor
STOP = object()
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
start = page_idx * page_size
end = min(start + page_size, len(token_ids))
h = 0
for i in range(start, end):
h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF
return h
class PagedCache:
"""Paged KV cache with page-table-indirected read/write.
Combines:
- Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
Call :meth:`bind` to obtain a batch view for the attention layers.
"""
@ -45,6 +55,32 @@ class PagedCache:
device=device,
dtype=dtype,
)
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
def record_page(
self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
h = page_hash(token_ids, logical_page_idx, self.page_size)
old_h = self._page_to_hash.pop(page_idx, None)
if old_h is not None:
self._hash_to_page.pop(old_h, None)
self._page_to_hash[page_idx] = h
self._hash_to_page[h] = page_idx
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
full_pages = len(token_ids) // self.page_size
hits: List[int] = []
for i in range(full_pages):
h = page_hash(token_ids, i, self.page_size)
p = self._hash_to_page.get(h)
if p is None:
break
hits.append(p)
return hits
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask
@ -68,6 +104,9 @@ class PagedCache:
self._refs[idx] -= 1
if self._refs[idx] == 0:
self._free_mask |= 1 << idx
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len)

View File

@ -5,7 +5,7 @@ import threading
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple
import torch
from torch import Tensor
@ -53,6 +53,7 @@ class Task:
self.output_tokens: int = 0
self.page_table: List[int] = []
self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time()
self.finish_time: Optional[float] = None
self.stream_callback = stream_callback
@ -88,8 +89,8 @@ class InferenceScheduler:
max_seq_len: Optional[int] = None,
max_prompt_len: int = 512,
page_size: int = 64,
device: str = "cuda",
dtype: torch.dtype = torch.bfloat16,
device: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
):
config = model.config
@ -180,6 +181,11 @@ class InferenceScheduler:
for idx in indices:
self.page_cache.free(idx)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _remove_finished_tasks(self) -> None:
finished = []
for task in self.active_tasks:
@ -214,12 +220,25 @@ class InferenceScheduler:
failed: List[Task] = []
for task in to_add:
prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages)
if not task.page_table:
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task)
continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING
self.active_tasks.append(task)
@ -227,42 +246,20 @@ class InferenceScheduler:
with self._lock:
self.waiting_queue[:0] = failed
def _execute_prefill(self) -> None:
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if not to_prefill:
return
for t in to_prefill:
prompt_len = len(t.prompt_ids)
t.input_tokens = prompt_len
t.output_tokens = 0
groups: Dict[int, List[Task]] = {}
for t in to_prefill:
groups.setdefault(len(t.prompt_ids), []).append(t)
for prompt_len, group in groups.items():
self._execute_prefill_batch(group, prompt_len)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
def _execute_prefill(
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks)
input_ids = torch.zeros(
batch_sz,
prompt_len,
dtype=torch.long,
device=self.device,
)
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
)
seq_len = prompt_len - start_pos
input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device)
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks)
@ -270,10 +267,14 @@ class InferenceScheduler:
self.model(
input_ids,
input_mask=input_mask,
start_pos=0,
start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
)
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks:
return
@ -349,7 +350,19 @@ class InferenceScheduler:
self._task_event.wait(timeout=1.0)
continue
self._execute_prefill()
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if to_prefill:
for t in to_prefill:
t.input_tokens = len(t.prompt_ids)
groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill:
key = (len(t.prompt_ids), t._prefix_cached_tokens)
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():
if start_pos < prompt_len:
self._execute_prefill(group, prompt_len, start_pos)
pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks:

View File

@ -5,6 +5,7 @@ import time
from unittest.mock import MagicMock, patch
import pytest
import torch
from astrai.inference.scheduler import InferenceScheduler
@ -19,6 +20,9 @@ def mock_model_and_tokenizer():
mock_model.config.dim = 128
mock_model.config.n_layers = 2
mock_model.config.max_len = 100
mock_model.parameters.return_value = iter(
[MagicMock(dtype=torch.float32, device=torch.device("cpu"))]
)
mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]