From 3583c46b6633a870336fd4b28e65ec55c1edf6cf Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 9 May 2026 23:53:23 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=8E=A8=E7=90=86=E5=BC=95=E6=93=8E?= =?UTF-8?q?=E5=89=8D=E7=BC=80=E7=BC=93=E5=AD=98=EF=BC=88KV=20cache=20?= =?UTF-8?q?=E5=A4=8D=E7=94=A8=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cache.py: 新增模块级 page_hash() 多项式滚动哈希函数;PagedCache 新增 record_page/lookup_prefix/inc_ref,free() 自动清理哈希映射 - scheduler.py: Task 新增 _prefix_cached_tokens;_refill_active_batch 先查 缓存命中页(inc_ref)再分配剩余页;合并 _execute_prefill 为单一方法, 按 (prompt_len, start_pos) 分组批量执行全量/部分 prefill; _record_page_hashes 注册完整页哈希;修复 device/dtype 默认值从硬编码 改为 None(自动检测模型设备) - test: mock model 补充 dtype/device 适配自动检测 --- astrai/inference/cache.py | 41 ++++++++- astrai/inference/scheduler.py | 91 +++++++++++-------- tests/inference/test_scheduler_concurrency.py | 4 + 3 files changed, 96 insertions(+), 40 deletions(-) diff --git a/astrai/inference/cache.py b/astrai/inference/cache.py index 08812c4..1b17d14 100644 --- a/astrai/inference/cache.py +++ b/astrai/inference/cache.py @@ -4,7 +4,7 @@ Provides: - PagedCache: paged KV cache combining page pool and tensor storage. """ -from typing import List, Tuple +from typing import Dict, List, Tuple import torch from torch import Tensor @@ -12,12 +12,22 @@ from torch import Tensor STOP = object() +def page_hash(token_ids: List[int], page_idx: int, page_size: int) -> int: + start = page_idx * page_size + end = min(start + page_size, len(token_ids)) + h = 0 + for i in range(start, end): + h = (h * 31 + token_ids[i]) & 0xFFFFFFFFFFFFFFFF + return h + + class PagedCache: """Paged KV cache with page-table-indirected read/write. Combines: - Page pool (ref-counted alloc/free via bitmask) - KV tensor storage (k_cache, v_cache) + - Prefix-cache hash lookup (page_content_hash -> physical_page_idx) Call :meth:`bind` to obtain a batch view for the attention layers. """ @@ -45,6 +55,32 @@ class PagedCache: device=device, dtype=dtype, ) + self._page_to_hash: Dict[int, int] = {} + self._hash_to_page: Dict[int, int] = {} + + def record_page( + self, page_idx: int, token_ids: List[int], logical_page_idx: int + ) -> None: + h = page_hash(token_ids, logical_page_idx, self.page_size) + old_h = self._page_to_hash.pop(page_idx, None) + if old_h is not None: + self._hash_to_page.pop(old_h, None) + self._page_to_hash[page_idx] = h + self._hash_to_page[h] = page_idx + + def lookup_prefix(self, token_ids: List[int]) -> List[int]: + full_pages = len(token_ids) // self.page_size + hits: List[int] = [] + for i in range(full_pages): + h = page_hash(token_ids, i, self.page_size) + p = self._hash_to_page.get(h) + if p is None: + break + hits.append(p) + return hits + + def inc_ref(self, idx: int) -> None: + self._refs[idx] += 1 def alloc(self) -> int: lsb = self._free_mask & -self._free_mask @@ -68,6 +104,9 @@ class PagedCache: self._refs[idx] -= 1 if self._refs[idx] == 0: self._free_mask |= 1 << idx + h = self._page_to_hash.pop(idx, None) + if h is not None: + self._hash_to_page.pop(h, None) def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": return CacheView(self, page_table, total_len) diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index c4979d0..c501cc5 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -5,7 +5,7 @@ import threading import time import uuid from enum import Enum -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Tuple import torch from torch import Tensor @@ -53,6 +53,7 @@ class Task: self.output_tokens: int = 0 self.page_table: List[int] = [] self.n_pages: int = 0 + self._prefix_cached_tokens: int = 0 self.arrival_time = time.time() self.finish_time: Optional[float] = None self.stream_callback = stream_callback @@ -88,8 +89,8 @@ class InferenceScheduler: max_seq_len: Optional[int] = None, max_prompt_len: int = 512, page_size: int = 64, - device: str = "cuda", - dtype: torch.dtype = torch.bfloat16, + device: Optional[str] = None, + dtype: Optional[torch.dtype] = None, ): config = model.config @@ -180,6 +181,11 @@ class InferenceScheduler: for idx in indices: self.page_cache.free(idx) + def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: + full_pages = len(task.prompt_ids) // self.page_size + for i in range(start_logical_page, full_pages): + self.page_cache.record_page(task.page_table[i], task.prompt_ids, i) + def _remove_finished_tasks(self) -> None: finished = [] for task in self.active_tasks: @@ -214,12 +220,25 @@ class InferenceScheduler: failed: List[Task] = [] for task in to_add: prompt_len = len(task.prompt_ids) - n_pages = self._n_pages_for(prompt_len) - task.page_table = self.page_cache.alloc_n(n_pages) - if not task.page_table: + + hit_pages = self.page_cache.lookup_prefix(task.prompt_ids) + cached_tokens = len(hit_pages) * self.page_size + for p in hit_pages: + self.page_cache.inc_ref(p) + + remaining = prompt_len - cached_tokens + n_new = self._n_pages_for(remaining) if remaining > 0 else 0 + new_pages = self.page_cache.alloc_n(n_new) if n_new > 0 else [] + + if remaining > 0 and not new_pages: + for p in hit_pages: + self.page_cache.free(p) failed.append(task) continue + + task.page_table = hit_pages + new_pages task.n_pages = len(task.page_table) + task._prefix_cached_tokens = cached_tokens task.status = TaskStatus.RUNNING self.active_tasks.append(task) @@ -227,42 +246,20 @@ class InferenceScheduler: with self._lock: self.waiting_queue[:0] = failed - def _execute_prefill(self) -> None: - to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] - if not to_prefill: - return - - for t in to_prefill: - prompt_len = len(t.prompt_ids) - t.input_tokens = prompt_len - t.output_tokens = 0 - - groups: Dict[int, List[Task]] = {} - for t in to_prefill: - groups.setdefault(len(t.prompt_ids), []).append(t) - - for prompt_len, group in groups.items(): - self._execute_prefill_batch(group, prompt_len) - - def _execute_prefill_batch(self, tasks: List[Task], prompt_len: int) -> None: + def _execute_prefill( + self, tasks: List[Task], prompt_len: int, start_pos: int = 0 + ) -> None: tasks = sorted(tasks, key=lambda t: t.task_id) batch_sz = len(tasks) - input_ids = torch.zeros( - batch_sz, - prompt_len, - dtype=torch.long, - device=self.device, - ) - input_mask = torch.ones( - batch_sz, - prompt_len, - dtype=torch.bool, - device=self.device, - ) + seq_len = prompt_len - start_pos + input_ids = torch.zeros(batch_sz, seq_len, dtype=torch.long, device=self.device) + input_mask = torch.ones(batch_sz, seq_len, dtype=torch.bool, device=self.device) for i, t in enumerate(tasks): - input_ids[i] = torch.tensor(t.prompt_ids, device=self.device) + input_ids[i] = torch.tensor( + t.prompt_ids[start_pos:prompt_len], device=self.device + ) page_tables = self._make_page_table_tensor(tasks) @@ -270,10 +267,14 @@ class InferenceScheduler: self.model( input_ids, input_mask=input_mask, - start_pos=0, + start_pos=start_pos, paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) + start_logical_page = start_pos // self.page_size + for t in tasks: + self._record_page_hashes(t, start_logical_page=start_logical_page) + def _execute_decode(self, tasks: List[Task], start_pos: int) -> None: if not tasks: return @@ -349,7 +350,19 @@ class InferenceScheduler: self._task_event.wait(timeout=1.0) continue - self._execute_prefill() + to_prefill = [t for t in self.active_tasks if t.output_tokens == 0] + if to_prefill: + for t in to_prefill: + t.input_tokens = len(t.prompt_ids) + + groups: Dict[Tuple[int, int], List[Task]] = {} + for t in to_prefill: + key = (len(t.prompt_ids), t._prefix_cached_tokens) + groups.setdefault(key, []).append(t) + + for (prompt_len, start_pos), group in groups.items(): + if start_pos < prompt_len: + self._execute_prefill(group, prompt_len, start_pos) pos_groups: Dict[int, List[Task]] = {} for t in self.active_tasks: diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index 3d771c3..cc7a9d2 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -5,6 +5,7 @@ import time from unittest.mock import MagicMock, patch import pytest +import torch from astrai.inference.scheduler import InferenceScheduler @@ -19,6 +20,9 @@ def mock_model_and_tokenizer(): mock_model.config.dim = 128 mock_model.config.n_layers = 2 mock_model.config.max_len = 100 + mock_model.parameters.return_value = iter( + [MagicMock(dtype=torch.float32, device=torch.device("cpu"))] + ) mock_tokenizer = MagicMock() mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]