# Task-level page bookkeeping for PagedCache (astrai/inference/cache.py).
# The enclosing ``class PagedCache`` statement is outside this span, so the
# methods are written at top level; each still takes ``self`` and can be placed
# in the class body unchanged.  They rely on two dicts that ``__init__`` creates
# empty: ``self._task_pages`` (task_id -> list of physical page indices) and
# ``self._task_cached`` (task_id -> number of prompt tokens served from the
# prefix cache).


def pages_needed(self, n_tokens: int) -> int:
    """Return the number of pages needed to hold ``n_tokens`` tokens.

    This is a ceiling division by ``self.page_size``; zero tokens need zero
    pages.
    """
    return (n_tokens + self.page_size - 1) // self.page_size


def task_alloc(self, task_id: str, prompt_ids: List[int]) -> bool:
    """Reserve pages for a task's prompt, reusing prefix-cache hits.

    Pages returned by ``lookup_prefix`` are shared: a reference is taken on
    each with ``inc_ref``.  The tokens they do not cover get fresh pages from
    ``alloc_n``.

    Args:
        task_id: Key under which the page table is stored.
        prompt_ids: Prompt token ids.

    Returns:
        ``True`` when the page table was stored; ``False`` when not enough
        pages could be allocated.  On failure every reference taken by this
        call is released again and nothing is stored for ``task_id``.
    """
    hit_pages = self.lookup_prefix(prompt_ids)
    for page in hit_pages:
        self.inc_ref(page)
    cached_tokens = len(hit_pages) * self.page_size

    remaining = len(prompt_ids) - cached_tokens
    n_new = self.pages_needed(remaining) if remaining > 0 else 0
    new_pages = list(self.alloc_n(n_new) or []) if n_new > 0 else []

    # NOTE(review): alloc_n's failure contract is not visible here.  An empty
    # result is known to mean failure; a short result is treated the same way
    # so that a partial allocation can neither leak pages nor register a page
    # table that is too small for the prompt.
    if len(new_pages) < n_new:
        for page in new_pages:
            self.free(page)
        for page in hit_pages:
            self.free(page)
        return False

    self._task_pages[task_id] = hit_pages + new_pages
    self._task_cached[task_id] = cached_tokens
    return True


def task_free(self, task_id: str) -> None:
    """Release every page of ``task_id`` and forget the task.

    Unknown (or already freed) task ids are ignored, so the call is safe to
    repeat.
    """
    pages = self._task_pages.pop(task_id, None)
    self._task_cached.pop(task_id, None)
    for page in pages or ():
        self.free(page)


def task_extend(self, task_id: str, pos: int) -> bool:
    """Grow the task's page table until it covers token position ``pos``.

    Returns:
        ``True`` when the table holds at least ``pages_needed(pos + 1)``
        pages.  ``False`` when the task is unknown or ``alloc`` reports
        exhaustion (a negative index).  Pages appended before the failure
        stay in the table and are released by ``task_free``.
    """
    page_table = self._task_pages.get(task_id)
    if page_table is None:
        # The sibling accessors tolerate unknown ids; report failure through
        # the return value instead of raising KeyError.
        return False

    needed = self.pages_needed(pos + 1)
    while len(page_table) < needed:
        page = self.alloc()
        if page < 0:
            return False
        page_table.append(page)
    return True


def task_cached(self, task_id: str) -> int:
    """Return the prefix-cached token count stored for the task (0 if unknown)."""
    return self._task_cached.get(task_id, 0)


def task_page_table(self, task_id: str) -> List[int]:
    """Return the task's page table.

    For a known task this is the stored list itself, not a copy; for an
    unknown task it is a new empty list.
    """
    return self._task_pages.get(task_id, [])


def task_n_pages(self, task_id: str) -> int:
    """Return how many pages the task currently holds (0 if unknown)."""
    return len(self._task_pages.get(task_id, []))


def task_record_hashes(
    self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0
) -> None:
    """Register the task's full prompt pages with the prefix cache.

    Calls ``record_page(page_table[i], prompt_ids, i)`` for every logical
    page ``i`` in ``[start_logical_page, len(prompt_ids) // page_size)``.  A
    trailing, partially filled page is never recorded.

    Raises:
        KeyError: If ``task_id`` has no page table.
    """
    page_table = self._task_pages[task_id]
    full_pages = len(prompt_ids) // self.page_size
    for logical_idx in range(start_logical_page, full_pages):
        self.record_page(page_table[logical_idx], prompt_ids, logical_idx)


def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
    """Stack the tasks' page tables into one padded ``long`` tensor.

    Rows follow the order of ``task_ids``.  Shorter tables are right-padded
    with ``-1``, and an unknown task id yields an all-padding row.

    Returns:
        Tensor of shape ``[len(task_ids), max_pages]`` on ``device``.  The
        result is always 2-D, including for an empty ``task_ids`` list.
    """
    tables = [self._task_pages.get(task_id, []) for task_id in task_ids]
    max_pages = max((len(table) for table in tables), default=0)
    rows = [table + [-1] * (max_pages - len(table)) for table in tables]
    table_tensor = torch.tensor(rows, dtype=torch.long, device=device)
    # torch.tensor([]) is 1-D; reshape keeps the documented 2-D layout.
    return table_tensor.reshape(len(rows), max_pages)
# PagedCache.gather (astrai/inference/cache.py).  The enclosing class statement
# is outside this span, so the method is written at top level; it still takes
# ``self`` and can be placed in the class body unchanged.
def gather(
    self, layer_id: int, page_table: Tensor, total_len: int = 0
) -> Tuple[Tensor, Tensor]:
    """Collect one layer's cached keys and values through a page table.

    Args:
        layer_id: Index into the first axis of ``k_cache`` / ``v_cache``.
        page_table: Integer tensor of physical page indices, expected as
            ``[batch, max_pages]`` with ``-1`` padding for shorter tasks.
        total_len: Number of leading positions to keep along axis 1.  A
            falsy value (the default ``0``) keeps everything, so callers
            that predate this parameter get the untruncated result.

    Returns:
        ``(k, v)`` with the page axis and the following axis flattened into
        one, then cut to ``total_len`` when it is non-zero.
    """
    # Padding entries (-1) are mapped to page 0 so the index is valid.  The
    # data read for them is unrelated and has to be cut off via ``total_len``
    # or otherwise ignored by the caller.
    safe = page_table.clamp(min=0)
    k = self.k_cache[layer_id, safe].flatten(1, 2)
    v = self.v_cache[layer_id, safe].flatten(1, 2)
    if total_len:
        k = k[:, :total_len]
        v = v[:, :total_len]
    return k, v
# Page-management entry points of Executor (astrai/inference/executor.py).
# The enclosing class statement is outside this span, so the methods are
# written at top level; each still takes ``self`` and can be placed in the
# class body unchanged.  All page state lives in ``self.page_cache`` and is
# keyed by ``task.task_id``.


def allocate_pages_for_activation(self, task: Task) -> bool:
    """Reserve KV-cache pages for ``task``'s prompt.

    Returns whatever ``PagedCache.task_alloc`` reports for the task's id and
    prompt token ids.
    """
    cache = self.page_cache
    return cache.task_alloc(task.task_id, task.prompt_ids)


def free_task_pages(self, task: Task) -> None:
    """Hand the pages registered for ``task`` back to the page cache."""
    cache = self.page_cache
    cache.task_free(task.task_id)


def get_cached_tokens(self, task: Task) -> int:
    """Return the prefix-cached token count the page cache stores for ``task``."""
    cache = self.page_cache
    return cache.task_cached(task.task_id)
self.page_cache.make_table_tensor(task_ids, self.device) total_len = start_pos + 1 temperatures = torch.tensor([t.temperature for t in tasks], device=self.device) @@ -150,7 +134,7 @@ class Executor: t.output_ids.append(ntok) t.output_tokens += 1 pos = t.input_tokens + t.output_tokens - self._maybe_alloc_page(t, pos) + self.page_cache.task_extend(t.task_id, pos) if t.stream_callback: t.stream_callback(self.tokenizer.decode([ntok])) @@ -158,26 +142,3 @@ class Executor: if t.is_finished(self.tokenizer.stop_ids): if t.stream_callback: t.stream_callback(STOP) - - def _n_pages_for(self, n_tokens: int) -> int: - return (n_tokens + self.page_size - 1) // self.page_size - - def _make_page_table_tensor(self, tasks: List[Task]) -> Tensor: - max_pages = max(t.n_pages for t in tasks) - rows = [t.page_table + [-1] * (max_pages - t.n_pages) for t in tasks] - return torch.tensor(rows, dtype=torch.long, device=self.device) - - def _record_page_hashes(self, task: Task, start_logical_page: int = 0) -> None: - full_pages = len(task.prompt_ids) // self.page_size - for i in range(start_logical_page, full_pages): - self.page_cache.record_page(task.page_table[i], task.prompt_ids, i) - - def _maybe_alloc_page(self, task: Task, pos: int) -> bool: - needed = self._n_pages_for(pos + 1) - while task.n_pages < needed: - p = self.page_cache.alloc() - if p < 0: - return False - task.page_table.append(p) - task.n_pages += 1 - return True diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 32db10f..06d1046 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -14,6 +14,8 @@ logger = logging.getLogger(__name__) class InferenceScheduler: + """Four-phase continuous batching loop: cleanup -> refill -> prefill -> decode.""" + def __init__( self, model: AutoModel, @@ -112,7 +114,7 @@ class InferenceScheduler: groups: Dict[Tuple[int, int], List[Task]] = {} for t in to_prefill: - key = (len(t.prompt_ids), t._prefix_cached_tokens) + key = 
(len(t.prompt_ids), self._executor.get_cached_tokens(t)) groups.setdefault(key, []).append(t) for (prompt_len, start_pos), group in groups.items(): diff --git a/astrai/inference/task.py b/astrai/inference/task.py index 8091692..76a571e 100644 --- a/astrai/inference/task.py +++ b/astrai/inference/task.py @@ -13,6 +13,8 @@ STOP = object() class TaskStatus(Enum): + """Task lifecycle states.""" + PENDING = "pending" RUNNING = "running" FINISHED = "finished" @@ -20,6 +22,8 @@ class TaskStatus(Enum): class Task: + """Single generation request: prompt, sampling params, output state.""" + def __init__( self, task_id: str, @@ -41,13 +45,9 @@ class Task: self.output_ids: List[int] = [] self.input_tokens: int = 0 self.output_tokens: int = 0 - self.page_table: List[int] = [] - self.n_pages: int = 0 - self._prefix_cached_tokens: int = 0 self.arrival_time = time.time() self.finish_time: Optional[float] = None self.stream_callback = stream_callback - self._pages_freed: bool = False @property def next_pos(self) -> int: @@ -62,6 +62,8 @@ class Task: class TaskManager: + """Thread-safe task queues and lifecycle transitions (no page ops).""" + def __init__( self, tokenizer: AutoTokenizer,