feat: 推理引擎前缀缓存(KV cache 复用)

- cache.py: 新增模块级 page_hash() 多项式滚动哈希函数;PagedCache 新增
  record_page/lookup_prefix/inc_ref,free() 自动清理哈希映射
- scheduler.py: Task 新增 _prefix_cached_tokens;_refill_active_batch 先查
  缓存命中页(inc_ref)再分配剩余页;合并 _execute_prefill 为单一方法,
  按 (prompt_len, start_pos) 分组批量执行全量/部分 prefill;
  _record_page_hashes 注册完整页哈希;修复 device/dtype 默认值从硬编码
  改为 None(自动检测模型设备)
- test: mock model 补充 dtype/device 适配自动检测
This commit is contained in:
ViperEkura 2026-05-09 23:53:23 +08:00
parent ca4e6b907c
commit 3583c46b66
3 changed files with 96 additions and 40 deletions

View File

@ -4,7 +4,7 @@ Provides:
- PagedCache: paged KV cache combining page pool and tensor storage. - PagedCache: paged KV cache combining page pool and tensor storage.
""" """
from typing import List, Tuple from typing import Dict, List, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
@ -12,12 +12,22 @@ from torch import Tensor
STOP = object() STOP = object()
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
    """Return a 64-bit key identifying logical page ``page_idx`` of ``token_ids``.

    The key is computed over every token from the start of the sequence up to
    the end of the requested page, not only over the page's own tokens. A
    cached KV page is only reusable when the whole preceding context is the
    same, so two pages with identical tokens but different prefixes (or at
    different logical positions) must not share a key.

    Args:
        token_ids: Full token sequence the page belongs to.
        page_idx: Zero-based logical page index inside ``token_ids``.
        page_size: Number of tokens per page.

    Returns:
        A non-negative integer below ``2**64``. If the sequence ends inside
        the requested page, only the tokens that exist are hashed.

    Note:
        Cost is linear in the prefix length rather than in ``page_size``.
        Distinct prefixes can still collide (this is a hash, not an identity
        check); callers that need certainty must compare tokens themselves.
    """
    mask = 0xFFFFFFFFFFFFFFFF
    # Large odd multiplier: with a small one such as 31, sequences like
    # [1, 0] and [0, 31] trivially produce the same value.
    multiplier = 0x9E3779B97F4A7C15
    end = max(0, min((page_idx + 1) * page_size, len(token_ids)))
    # Non-zero seed so that leading zero tokens still change the result and
    # prefixes of different lengths do not all collapse to 0.
    h = 0xCBF29CE484222325
    for tok in token_ids[:end]:
        h = (h * multiplier + tok) & mask
    return h
class PagedCache: class PagedCache:
"""Paged KV cache with page-table-indirected read/write. """Paged KV cache with page-table-indirected read/write.
Combines: Combines:
- Page pool (ref-counted alloc/free via bitmask) - Page pool (ref-counted alloc/free via bitmask)
- KV tensor storage (k_cache, v_cache) - KV tensor storage (k_cache, v_cache)
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
Call :meth:`bind` to obtain a batch view for the attention layers. Call :meth:`bind` to obtain a batch view for the attention layers.
""" """
@ -45,6 +55,32 @@ class PagedCache:
device=device, device=device,
dtype=dtype, dtype=dtype,
) )
self._page_to_hash: Dict[int, int] = {}
self._hash_to_page: Dict[int, int] = {}
def record_page(
    self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
    """Register physical page ``page_idx`` in the prefix-cache lookup tables.

    The page is keyed by ``page_hash(token_ids, logical_page_idx,
    self.page_size)``. The two dictionaries are kept as exact inverses of each
    other: a physical page has at most one hash and a hash resolves to at most
    one physical page.

    Args:
        page_idx: Physical page index being registered.
        token_ids: Token sequence whose logical page is stored in the page.
        logical_page_idx: Zero-based logical page index inside ``token_ids``.
    """
    h = page_hash(token_ids, logical_page_idx, self.page_size)

    # Forget the hash this physical page was registered under before, but
    # only remove the forward entry if it still points at this page; it may
    # have been taken over by another page in the meantime.
    old_h = self._page_to_hash.pop(page_idx, None)
    if old_h is not None and self._hash_to_page.get(old_h) == page_idx:
        del self._hash_to_page[old_h]

    # If a different page currently owns this hash, unlink its reverse entry.
    # Otherwise releasing that page later would delete the forward mapping
    # that now belongs to ``page_idx``.
    displaced = self._hash_to_page.get(h)
    if displaced is not None and displaced != page_idx:
        self._page_to_hash.pop(displaced, None)

    self._page_to_hash[page_idx] = h
    self._hash_to_page[h] = page_idx
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
    """Return the physical pages already registered for the head of ``token_ids``.

    Only whole pages are considered (``len(token_ids) // self.page_size``).
    Logical pages are probed in order starting from page 0 and the scan stops
    at the first page whose hash is not registered. Reference counts are not
    modified here.

    Args:
        token_ids: Token sequence to match against registered pages.

    Returns:
        Physical page indices for the matched leading pages, in logical
        order; empty when the first page is not registered.
    """
    # NOTE(review): when len(token_ids) is an exact multiple of page_size the
    # result can cover the entire sequence; callers must handle that case.
    matched: List[int] = []
    n_full = len(token_ids) // self.page_size
    logical = 0
    while logical < n_full:
        key = page_hash(token_ids, logical, self.page_size)
        physical = self._hash_to_page.get(key)
        if physical is None:
            break
        matched.append(physical)
        logical += 1
    return matched
def inc_ref(self, idx: int) -> None:
self._refs[idx] += 1
def alloc(self) -> int: def alloc(self) -> int:
lsb = self._free_mask & -self._free_mask lsb = self._free_mask & -self._free_mask
@ -68,6 +104,9 @@ class PagedCache:
self._refs[idx] -= 1 self._refs[idx] -= 1
if self._refs[idx] == 0: if self._refs[idx] == 0:
self._free_mask |= 1 << idx self._free_mask |= 1 << idx
h = self._page_to_hash.pop(idx, None)
if h is not None:
self._hash_to_page.pop(h, None)
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
return CacheView(self, page_table, total_len) return CacheView(self, page_table, total_len)

View File

@ -5,7 +5,7 @@ import threading
import time import time
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
@ -53,6 +53,7 @@ class Task:
self.output_tokens: int = 0 self.output_tokens: int = 0
self.page_table: List[int] = [] self.page_table: List[int] = []
self.n_pages: int = 0 self.n_pages: int = 0
self._prefix_cached_tokens: int = 0
self.arrival_time = time.time() self.arrival_time = time.time()
self.finish_time: Optional[float] = None self.finish_time: Optional[float] = None
self.stream_callback = stream_callback self.stream_callback = stream_callback
@ -88,8 +89,8 @@ class InferenceScheduler:
max_seq_len: Optional[int] = None, max_seq_len: Optional[int] = None,
max_prompt_len: int = 512, max_prompt_len: int = 512,
page_size: int = 64, page_size: int = 64,
device: str = "cuda", device: Optional[str] = None,
dtype: torch.dtype = torch.bfloat16, dtype: Optional[torch.dtype] = None,
): ):
config = model.config config = model.config
@ -180,6 +181,11 @@ class InferenceScheduler:
for idx in indices: for idx in indices:
self.page_cache.free(idx) self.page_cache.free(idx)
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
full_pages = len(task.prompt_ids) // self.page_size
for i in range(start_logical_page, full_pages):
self.page_cache.record_page(task.page_table[i], task.prompt_ids, i)
def _remove_finished_tasks(self) -> None: def _remove_finished_tasks(self) -> None:
finished = [] finished = []
for task in self.active_tasks: for task in self.active_tasks:
@ -214,12 +220,25 @@ class InferenceScheduler:
failed: List[Task] = [] failed: List[Task] = []
for task in to_add: for task in to_add:
prompt_len = len(task.prompt_ids) prompt_len = len(task.prompt_ids)
n_pages = self._n_pages_for(prompt_len)
task.page_table = self.page_cache.alloc_n(n_pages) hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
if not task.page_table: cached_tokens = len(hit_pages) * self.page_size
for p in hit_pages:
self.page_cache.inc_ref(p)
remaining = prompt_len - cached_tokens
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
if remaining > 0 and not new_pages:
for p in hit_pages:
self.page_cache.free(p)
failed.append(task) failed.append(task)
continue continue
task.page_table = hit_pages + new_pages
task.n_pages = len(task.page_table) task.n_pages = len(task.page_table)
task._prefix_cached_tokens = cached_tokens
task.status = TaskStatus.RUNNING task.status = TaskStatus.RUNNING
self.active_tasks.append(task) self.active_tasks.append(task)
@ -227,42 +246,20 @@ class InferenceScheduler:
with self._lock: with self._lock:
self.waiting_queue[:0] = failed self.waiting_queue[:0] = failed
def _execute_prefill(self) -> None: def _execute_prefill(
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] self, tasks: List[Task], prompt_len: int, start_pos: int = 0
if not to_prefill: ) -> None:
return
for t in to_prefill:
prompt_len = len(t.prompt_ids)
t.input_tokens = prompt_len
t.output_tokens = 0
groups: Dict[int, List[Task]] = {}
for t in to_prefill:
groups.setdefault(len(t.prompt_ids), []).append(t)
for prompt_len, group in groups.items():
self._execute_prefill_batch(group, prompt_len)
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
tasks = sorted(tasks, key=lambda t: t.task_id) tasks = sorted(tasks, key=lambda t: t.task_id)
batch_sz = len(tasks) batch_sz = len(tasks)
input_ids = torch.zeros( seq_len = prompt_len - start_pos
batch_sz, input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device)
prompt_len, input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
dtype=torch.long,
device=self.device,
)
input_mask = torch.ones(
batch_sz,
prompt_len,
dtype=torch.bool,
device=self.device,
)
for i, t in enumerate(tasks): for i, t in enumerate(tasks):
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device) input_ids[i] = torch.tensor(
t.prompt_ids[start_pos:prompt_len], device=self.device
)
page_tables = self._make_page_table_tensor(tasks) page_tables = self._make_page_table_tensor(tasks)
@ -270,10 +267,14 @@ class InferenceScheduler:
self.model( self.model(
input_ids, input_ids,
input_mask=input_mask, input_mask=input_mask,
start_pos=0, start_pos=start_pos,
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
) )
start_logical_page = start_pos // self.page_size
for t in tasks:
self._record_page_hashes(t, start_logical_page=start_logical_page)
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
if not tasks: if not tasks:
return return
@ -349,7 +350,19 @@ class InferenceScheduler:
self._task_event.wait(timeout=1.0) self._task_event.wait(timeout=1.0)
continue continue
self._execute_prefill() to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
if to_prefill:
for t in to_prefill:
t.input_tokens = len(t.prompt_ids)
groups: Dict[Tuple[int, int], List[Task]] = {}
for t in to_prefill:
key = (len(t.prompt_ids), t._prefix_cached_tokens)
groups.setdefault(key, []).append(t)
for (prompt_len, start_pos), group in groups.items():
if start_pos < prompt_len:
self._execute_prefill(group, prompt_len, start_pos)
pos_groups: Dict[int, List[Task]] = {} pos_groups: Dict[int, List[Task]] = {}
for t in self.active_tasks: for t in self.active_tasks:

View File

@ -5,6 +5,7 @@ import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
import torch
from astrai.inference.scheduler import InferenceScheduler from astrai.inference.scheduler import InferenceScheduler
@ -19,6 +20,9 @@ def mock_model_and_tokenizer():
mock_model.config.dim = 128 mock_model.config.dim = 128
mock_model.config.n_layers = 2 mock_model.config.n_layers = 2
mock_model.config.max_len = 100 mock_model.config.max_len = 100
mock_model.parameters.return_value = iter(
[MagicMock(dtype=torch.float32, device=torch.device("cpu"))]
)
mock_tokenizer = MagicMock() mock_tokenizer = MagicMock()
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5] mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]