feat: 推理引擎前缀缓存(KV cache 复用)
- cache.py: 新增模块级 page_hash() 多项式滚动哈希函数;PagedCache 新增 record_page/lookup_prefix/inc_ref,free() 自动清理哈希映射
- scheduler.py: Task 新增 _prefix_cached_tokens;_refill_active_batch 先查缓存命中页(inc_ref)再分配剩余页;合并 _execute_prefill 为单一方法,按 (prompt_len, start_pos) 分组批量执行全量/部分 prefill;_record_page_hashes 注册完整页哈希;修复 device/dtype 默认值从硬编码改为 None(自动检测模型设备)
- test: mock model 补充 dtype/device 适配自动检测
This commit is contained in:
parent
ca4e6b907c
commit
3583c46b66
|
|
@ -4,7 +4,7 @@ Provides:
|
||||||
- PagedCache: paged KV cache combining page pool and tensor storage.
|
- PagedCache: paged KV cache combining page pool and tensor storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
@ -12,12 +12,22 @@ from torch import Tensor
|
||||||
STOP = object()
|
STOP = object()
|
||||||
|
|
||||||
|
|
||||||
|
def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int:
    """Return a 64-bit key identifying logical page ``page_idx`` of a sequence.

    The key covers every token from the start of the sequence up to the end
    of the requested page, not only the page's own tokens.  Two pages with
    identical content therefore get different keys when they sit behind
    different prefixes or at different logical positions.  This matters for
    prefix caching because a cached page is only reusable when everything
    that precedes it is identical too (assumes the stored K/V depend on the
    preceding tokens and positions, as in causal attention -- TODO confirm
    against the model).

    Args:
        token_ids: Full token sequence of the prompt.
        page_idx: Zero-based logical page index within ``token_ids``.
        page_size: Number of tokens per page.

    Returns:
        An integer in ``[0, 2**64)``.  ``0`` is returned when the requested
        page starts at or beyond the end of ``token_ids`` (empty page).

    NOTE(review): this is a non-cryptographic hash; callers that treat equal
    keys as equal content accept a (small) collision risk.
    """
    mask = 0xFFFFFFFFFFFFFFFF
    # FNV-1a 64-bit parameters, applied per token instead of per byte.  The
    # previous multiplier (31) is smaller than any realistic token id range,
    # which made distinct pages collide trivially, e.g. [1, 31] vs [2, 0].
    fnv_offset = 0xCBF29CE484222325
    fnv_prime = 0x100000001B3

    start = page_idx * page_size
    end = min(start + page_size, len(token_ids))
    if end <= start:
        return 0

    h = fnv_offset
    # Fold in the whole prefix [0, end) so the key is prefix-dependent.
    for tok in token_ids[:end]:
        h = ((h ^ tok) * fnv_prime) & mask
    return h
|
||||||
|
|
||||||
|
|
||||||
class PagedCache:
|
class PagedCache:
|
||||||
"""Paged KV cache with page-table-indirected read/write.
|
"""Paged KV cache with page-table-indirected read/write.
|
||||||
|
|
||||||
Combines:
|
Combines:
|
||||||
- Page pool (ref-counted alloc/free via bitmask)
|
- Page pool (ref-counted alloc/free via bitmask)
|
||||||
- KV tensor storage (k_cache, v_cache)
|
- KV tensor storage (k_cache, v_cache)
|
||||||
|
- Prefix-cache hash lookup (page_content_hash -> physical_page_idx)
|
||||||
|
|
||||||
Call :meth:`bind` to obtain a batch view for the attention layers.
|
Call :meth:`bind` to obtain a batch view for the attention layers.
|
||||||
"""
|
"""
|
||||||
|
|
@ -45,6 +55,32 @@ class PagedCache:
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
self._page_to_hash: Dict[int, int] = {}
|
||||||
|
self._hash_to_page: Dict[int, int] = {}
|
||||||
|
|
||||||
|
def record_page(
    self, page_idx: int, token_ids: List[int], logical_page_idx: int
) -> None:
    """Register physical page ``page_idx`` in the prefix-cache lookup tables.

    The page is keyed by ``page_hash(token_ids, logical_page_idx,
    self.page_size)``.  Both directions of the mapping are kept consistent
    (one hash <-> one physical page) so that releasing a page can never
    evict a lookup entry that belongs to a different, still-live page.

    Args:
        page_idx: Physical page index that now holds the data.
        token_ids: Full token sequence the page was filled from.
        logical_page_idx: Position of the page within ``token_ids``.
    """
    h = page_hash(token_ids, logical_page_idx, self.page_size)

    # Forget the hash this physical page was previously registered under,
    # but only drop the forward entry if it still points at this page.
    old_h = self._page_to_hash.pop(page_idx, None)
    if old_h is not None and self._hash_to_page.get(old_h) == page_idx:
        del self._hash_to_page[old_h]

    # If another physical page currently owns this hash, unlink its reverse
    # entry before taking over.  Otherwise freeing that page later would
    # follow the stale reverse entry and delete the mapping created below.
    prev_owner = self._hash_to_page.get(h)
    if prev_owner is not None and prev_owner != page_idx:
        self._page_to_hash.pop(prev_owner, None)

    self._page_to_hash[page_idx] = h
    self._hash_to_page[h] = page_idx
|
||||||
|
|
||||||
|
def lookup_prefix(self, token_ids: List[int]) -> List[int]:
    """Return the physical pages cached for the leading full pages of a prompt.

    Walks the full pages of ``token_ids`` in order (a trailing partial page
    is never considered) and looks each one up by its ``page_hash``.  The
    walk stops at the first page with no registered hash.

    Args:
        token_ids: Token sequence to match against the cache.

    Returns:
        Physical page indices for the matched prefix, in logical order.
        Empty when the first page is not cached.
    """
    matched: List[int] = []
    n_full = len(token_ids) // self.page_size
    for logical_idx in range(n_full):
        key = page_hash(token_ids, logical_idx, self.page_size)
        physical = self._hash_to_page.get(key)
        if physical is None:
            # First miss ends the reusable prefix.
            return matched
        matched.append(physical)
    return matched
|
||||||
|
|
||||||
|
def inc_ref(self, idx: int) -> None:
|
||||||
|
self._refs[idx] += 1
|
||||||
|
|
||||||
def alloc(self) -> int:
|
def alloc(self) -> int:
|
||||||
lsb = self._free_mask & -self._free_mask
|
lsb = self._free_mask & -self._free_mask
|
||||||
|
|
@ -68,6 +104,9 @@ class PagedCache:
|
||||||
self._refs[idx] -= 1
|
self._refs[idx] -= 1
|
||||||
if self._refs[idx] == 0:
|
if self._refs[idx] == 0:
|
||||||
self._free_mask |= 1 << idx
|
self._free_mask |= 1 << idx
|
||||||
|
h = self._page_to_hash.pop(idx, None)
|
||||||
|
if h is not None:
|
||||||
|
self._hash_to_page.pop(h, None)
|
||||||
|
|
||||||
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView":
|
||||||
return CacheView(self, page_table, total_len)
|
return CacheView(self, page_table, total_len)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import threading
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Callable, Dict, List, Optional
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
@ -53,6 +53,7 @@ class Task:
|
||||||
self.output_tokens: int = 0
|
self.output_tokens: int = 0
|
||||||
self.page_table: List[int] = []
|
self.page_table: List[int] = []
|
||||||
self.n_pages: int = 0
|
self.n_pages: int = 0
|
||||||
|
self._prefix_cached_tokens: int = 0
|
||||||
self.arrival_time = time.time()
|
self.arrival_time = time.time()
|
||||||
self.finish_time: Optional[float] = None
|
self.finish_time: Optional[float] = None
|
||||||
self.stream_callback = stream_callback
|
self.stream_callback = stream_callback
|
||||||
|
|
@ -88,8 +89,8 @@ class InferenceScheduler:
|
||||||
max_seq_len: Optional[int] = None,
|
max_seq_len: Optional[int] = None,
|
||||||
max_prompt_len: int = 512,
|
max_prompt_len: int = 512,
|
||||||
page_size: int = 64,
|
page_size: int = 64,
|
||||||
device: str = "cuda",
|
device: Optional[str] = None,
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
dtype: Optional[torch.dtype] = None,
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
|
|
||||||
|
|
@ -180,6 +181,11 @@ class InferenceScheduler:
|
||||||
for idx in indices:
|
for idx in indices:
|
||||||
self.page_cache.free(idx)
|
self.page_cache.free(idx)
|
||||||
|
|
||||||
|
def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None:
    """Register the task's full prompt pages with the page cache.

    Only pages completely covered by ``task.prompt_ids`` are registered; a
    trailing partial page is skipped.

    Args:
        task: Task whose ``page_table`` and ``prompt_ids`` are read.
        start_logical_page: First logical page index to register; earlier
            pages are left untouched.
    """
    n_full = len(task.prompt_ids) // self.page_size
    logical_idx = start_logical_page
    while logical_idx < n_full:
        physical_idx = task.page_table[logical_idx]
        self.page_cache.record_page(physical_idx, task.prompt_ids, logical_idx)
        logical_idx += 1
|
||||||
|
|
||||||
def _remove_finished_tasks(self) -> None:
|
def _remove_finished_tasks(self) -> None:
|
||||||
finished = []
|
finished = []
|
||||||
for task in self.active_tasks:
|
for task in self.active_tasks:
|
||||||
|
|
@ -214,12 +220,25 @@ class InferenceScheduler:
|
||||||
failed: List[Task] = []
|
failed: List[Task] = []
|
||||||
for task in to_add:
|
for task in to_add:
|
||||||
prompt_len = len(task.prompt_ids)
|
prompt_len = len(task.prompt_ids)
|
||||||
n_pages = self._n_pages_for(prompt_len)
|
|
||||||
task.page_table = self.page_cache.alloc_n(n_pages)
|
hit_pages = self.page_cache.lookup_prefix(task.prompt_ids)
|
||||||
if not task.page_table:
|
cached_tokens = len(hit_pages) * self.page_size
|
||||||
|
for p in hit_pages:
|
||||||
|
self.page_cache.inc_ref(p)
|
||||||
|
|
||||||
|
remaining = prompt_len - cached_tokens
|
||||||
|
n_new = self._n_pages_for(remaining) if remaining > 0 else 0
|
||||||
|
new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else []
|
||||||
|
|
||||||
|
if remaining > 0 and not new_pages:
|
||||||
|
for p in hit_pages:
|
||||||
|
self.page_cache.free(p)
|
||||||
failed.append(task)
|
failed.append(task)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
task.page_table = hit_pages + new_pages
|
||||||
task.n_pages = len(task.page_table)
|
task.n_pages = len(task.page_table)
|
||||||
|
task._prefix_cached_tokens = cached_tokens
|
||||||
task.status = TaskStatus.RUNNING
|
task.status = TaskStatus.RUNNING
|
||||||
self.active_tasks.append(task)
|
self.active_tasks.append(task)
|
||||||
|
|
||||||
|
|
@ -227,42 +246,20 @@ class InferenceScheduler:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self.waiting_queue[:0] = failed
|
self.waiting_queue[:0] = failed
|
||||||
|
|
||||||
def _execute_prefill(self) -> None:
|
def _execute_prefill(
|
||||||
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
self, tasks: List[Task], prompt_len: int, start_pos: int = 0
|
||||||
if not to_prefill:
|
) -> None:
|
||||||
return
|
|
||||||
|
|
||||||
for t in to_prefill:
|
|
||||||
prompt_len = len(t.prompt_ids)
|
|
||||||
t.input_tokens = prompt_len
|
|
||||||
t.output_tokens = 0
|
|
||||||
|
|
||||||
groups: Dict[int, List[Task]] = {}
|
|
||||||
for t in to_prefill:
|
|
||||||
groups.setdefault(len(t.prompt_ids), []).append(t)
|
|
||||||
|
|
||||||
for prompt_len, group in groups.items():
|
|
||||||
self._execute_prefill_batch(group, prompt_len)
|
|
||||||
|
|
||||||
def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None:
|
|
||||||
tasks = sorted(tasks, key=lambda t: t.task_id)
|
tasks = sorted(tasks, key=lambda t: t.task_id)
|
||||||
batch_sz = len(tasks)
|
batch_sz = len(tasks)
|
||||||
|
|
||||||
input_ids = torch.zeros(
|
seq_len = prompt_len - start_pos
|
||||||
batch_sz,
|
input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device)
|
||||||
prompt_len,
|
input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device)
|
||||||
dtype=torch.long,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
input_mask = torch.ones(
|
|
||||||
batch_sz,
|
|
||||||
prompt_len,
|
|
||||||
dtype=torch.bool,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, t in enumerate(tasks):
|
for i, t in enumerate(tasks):
|
||||||
input_ids[i] = torch.tensor(t.prompt_ids, device=self.device)
|
input_ids[i] = torch.tensor(
|
||||||
|
t.prompt_ids[start_pos:prompt_len], device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
page_tables = self._make_page_table_tensor(tasks)
|
page_tables = self._make_page_table_tensor(tasks)
|
||||||
|
|
||||||
|
|
@ -270,10 +267,14 @@ class InferenceScheduler:
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
start_pos=0,
|
start_pos=start_pos,
|
||||||
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
start_logical_page = start_pos // self.page_size
|
||||||
|
for t in tasks:
|
||||||
|
self._record_page_hashes(t, start_logical_page=start_logical_page)
|
||||||
|
|
||||||
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
def _execute_decode(self, tasks: List[Task], start_pos: int) -> None:
|
||||||
if not tasks:
|
if not tasks:
|
||||||
return
|
return
|
||||||
|
|
@ -349,7 +350,19 @@ class InferenceScheduler:
|
||||||
self._task_event.wait(timeout=1.0)
|
self._task_event.wait(timeout=1.0)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self._execute_prefill()
|
to_prefill = [t for t in self.active_tasks if t.output_tokens == 0]
|
||||||
|
if to_prefill:
|
||||||
|
for t in to_prefill:
|
||||||
|
t.input_tokens = len(t.prompt_ids)
|
||||||
|
|
||||||
|
groups: Dict[Tuple[int, int], List[Task]] = {}
|
||||||
|
for t in to_prefill:
|
||||||
|
key = (len(t.prompt_ids), t._prefix_cached_tokens)
|
||||||
|
groups.setdefault(key, []).append(t)
|
||||||
|
|
||||||
|
for (prompt_len, start_pos), group in groups.items():
|
||||||
|
if start_pos < prompt_len:
|
||||||
|
self._execute_prefill(group, prompt_len, start_pos)
|
||||||
|
|
||||||
pos_groups: Dict[int, List[Task]] = {}
|
pos_groups: Dict[int, List[Task]] = {}
|
||||||
for t in self.active_tasks:
|
for t in self.active_tasks:
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import time
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
from astrai.inference.scheduler import InferenceScheduler
|
from astrai.inference.scheduler import InferenceScheduler
|
||||||
|
|
||||||
|
|
@ -19,6 +20,9 @@ def mock_model_and_tokenizer():
|
||||||
mock_model.config.dim = 128
|
mock_model.config.dim = 128
|
||||||
mock_model.config.n_layers = 2
|
mock_model.config.n_layers = 2
|
||||||
mock_model.config.max_len = 100
|
mock_model.config.max_len = 100
|
||||||
|
mock_model.parameters.return_value = iter(
|
||||||
|
[MagicMock(dtype=torch.float32, device=torch.device("cpu"))]
|
||||||
|
)
|
||||||
|
|
||||||
mock_tokenizer = MagicMock()
|
mock_tokenizer = MagicMock()
|
||||||
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue