class PagePool:
    """Bitmask page allocator with ref-counting and LRU eviction.

    State:
      * ``_free_mask`` - bit ``i`` is set while page ``i`` is free.
      * ``_refs``      - per-page reference count.
      * ``_lru``       - pages released with ``keep_cached=True`` whose count
                         dropped to zero, oldest first. Only these pages may
                         be evicted by :meth:`alloc`.
    """

    def __init__(self, n_pages: int, on_evict: Optional[Callable[[int], None]] = None):
        """Create a pool of ``n_pages`` free pages.

        Args:
            n_pages: Number of physical pages managed by the pool.
            on_evict: Optional callback invoked with the page index right
                before a cached page is recycled by :meth:`alloc`.
        """
        self._free_mask = (1 << n_pages) - 1
        self._refs: List[int] = [0] * n_pages
        self._lru: OrderedDict[int, None] = OrderedDict()
        self._on_evict = on_evict

    def alloc(self) -> int:
        """Return a page index with a reference count of 1, or -1 if none.

        Free pages are preferred (lowest index first); otherwise the least
        recently used cached page is evicted and reused.
        """
        if self._free_mask:
            lsb = self._free_mask & -self._free_mask
            idx = lsb.bit_length() - 1
            self._free_mask ^= lsb
            self._refs[idx] = 1
            return idx
        if self._lru:
            idx, _ = self._lru.popitem(last=False)
            if self._on_evict:
                self._on_evict(idx)
            self._refs[idx] = 1
            self._free_mask &= ~(1 << idx)
            return idx
        return -1

    def free(self, idx: int, keep_cached: bool = False) -> None:
        """Drop one reference to ``idx``.

        When the count reaches zero the page either joins the LRU list
        (``keep_cached=True``) or goes straight back to the free mask.
        """
        self._refs[idx] -= 1
        if self._refs[idx] == 0:
            if keep_cached:
                self._lru[idx] = None
            else:
                self._free_mask |= 1 << idx

    def inc_ref(self, idx: int) -> None:
        """Add one reference to ``idx`` and make sure it cannot be evicted."""
        self._refs[idx] += 1
        # A referenced page must never sit in the eviction list; otherwise a
        # later alloc() could recycle it while it is still in use.
        self._lru.pop(idx, None)

    def touch(self, idx: int) -> None:
        """Mark ``idx`` as most recently used if it is in the LRU list.

        Pages that are currently referenced are not in the list; touching
        them is a no-op rather than a ``KeyError``.
        """
        if idx in self._lru:
            self._lru.move_to_end(idx)

    def remove_from_lru(self, idx: int) -> None:
        """Remove ``idx`` from the LRU list if present."""
        self._lru.pop(idx, None)


class PrefixCache:
    """Hash-based prefix matching: maps page hashes to physical page indices."""

    def __init__(self, page_size: int):
        """Create an empty cache for pages of ``page_size`` tokens."""
        self._page_size = page_size
        self._page_to_hash: Dict[int, int] = {}
        self._hash_to_page: Dict[int, int] = {}

    def _unlink(self, idx: int) -> None:
        """Forget the hash recorded for page ``idx``, if any."""
        h = self._page_to_hash.pop(idx, None)
        # Only drop the reverse entry when it still points at this page; the
        # same hash may have been re-recorded for a different page since.
        if h is not None and self._hash_to_page.get(h) == idx:
            del self._hash_to_page[h]

    def on_evict(self, idx: int) -> None:
        """Drop the hash bookkeeping for a page that is being recycled."""
        self._unlink(idx)

    def has_page(self, idx: int) -> bool:
        """Return True if page ``idx`` currently has a recorded hash."""
        return idx in self._page_to_hash

    def lookup(self, token_ids: List[int], pool: PagePool) -> List[int]:
        """Return the physical pages matching the leading full pages of ``token_ids``.

        Matching stops at the first page whose hash is unknown. Every hit is
        touched in ``pool``.
        """
        full_pages = len(token_ids) // self._page_size
        hits: List[int] = []
        for i in range(full_pages):
            h = page_hash(token_ids, i, self._page_size)
            p = self._hash_to_page.get(h)
            if p is None:
                break
            pool.touch(p)
            hits.append(p)
        return hits

    def record(
        self,
        page_idx: int,
        token_ids: List[int],
        logical_page_idx: int,
        pool: PagePool,
    ) -> None:
        """Associate physical page ``page_idx`` with the hash of a logical page.

        Any hash previously recorded for ``page_idx`` is replaced, and the
        page is removed from the pool's LRU list.
        """
        h = page_hash(token_ids, logical_page_idx, self._page_size)
        self._unlink(page_idx)
        self._page_to_hash[page_idx] = h
        self._hash_to_page[h] = page_idx
        pool.remove_from_lru(page_idx)


class TaskTable:
    """Maps task_ids to page tables and cached token counts."""

    def __init__(self, pool: PagePool, page_size: int):
        """Create an empty table that allocates new pages from ``pool``."""
        self._pool = pool
        self._page_size = page_size
        self._pages: Dict[str, List[int]] = {}
        self._cached: Dict[str, int] = {}

    def set(self, task_id: str, page_table: List[int], cached: int) -> None:
        """Store the page table and cached-token count for ``task_id``."""
        self._pages[task_id] = page_table
        self._cached[task_id] = cached

    def get(self, task_id: str) -> List[int]:
        """Return the page table for ``task_id`` (empty list if unknown)."""
        return self._pages.get(task_id, [])

    def get_cached(self, task_id: str) -> int:
        """Return the cached-token count for ``task_id`` (0 if unknown)."""
        return self._cached.get(task_id, 0)

    def pop(self, task_id: str) -> Tuple[List[int], int]:
        """Remove ``task_id`` and return ``(page_table, cached)``."""
        pages = self._pages.pop(task_id, [])
        cached = self._cached.pop(task_id, 0)
        return pages, cached

    def extend(self, task_id: str, pos: int) -> bool:
        """Grow the task's page table until it covers token position ``pos``.

        Returns:
            False if the pool ran out of pages (pages appended so far stay in
            the table), True otherwise.

        Raises:
            KeyError: If ``task_id`` has no page table.
        """
        page_table = self._pages[task_id]
        # Ceiling division of (pos + 1) tokens by the page size.
        needed = (pos + self._page_size) // self._page_size
        while len(page_table) < needed:
            p = self._pool.alloc()
            if p < 0:
                return False
            page_table.append(p)
        return True

    def table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor:
        """Build a long tensor with one row per task, right-padded with -1.

        Unknown task ids contribute an all-padding row.
        """
        states = [self._pages.get(tid, []) for tid in task_ids]
        max_pages = max((len(s) for s in states), default=0)
        rows = [s + [-1] * (max_pages - len(s)) for s in states]
        return torch.tensor(rows, dtype=torch.long, device=device)
remaining = len(prompt_ids) - cached + n_new = ( + (remaining + self.page_size - 1) // self.page_size if remaining > 0 else 0 + ) + new_pages: List[int] = [] + if n_new > 0: + for _ in range(n_new): + p = self._pool.alloc() + if p < 0: + for hp in hits: + self.free(hp) + for np in new_pages: + self.free(np) + return False + new_pages.append(p) - if remaining > 0 and not new_pages: - for p in hit_pages: - self.free(p) - return False - - page_table = hit_pages + new_pages - self._task_pages[task_id] = page_table - self._task_cached[task_id] = cached_tokens + self._table.set(task_id, hits + new_pages, cached) return True def task_free(self, task_id: str) -> None: - page_table = self._task_pages.pop(task_id, None) - self._task_cached.pop(task_id, None) - if page_table: - for idx in page_table: - self.free(idx) + page_table, _ = self._table.pop(task_id) + for idx in page_table: + self.free(idx) def task_extend(self, task_id: str, pos: int) -> bool: - needed = self.pages_needed(pos + 1) - page_table = self._task_pages[task_id] - while len(page_table) < needed: - p = self.alloc() - if p < 0: - return False - page_table.append(p) - return True + return self._table.extend(task_id, pos) def task_cached(self, task_id: str) -> int: - return self._task_cached.get(task_id, 0) - - def task_page_table(self, task_id: str) -> List[int]: - return self._task_pages.get(task_id, []) - - def task_n_pages(self, task_id: str) -> int: - return len(self._task_pages.get(task_id, [])) + return self._table.get_cached(task_id) def task_record_hashes( self, task_id: str, prompt_ids: List[int], start_logical_page: int = 0 ) -> None: - page_table = self._task_pages[task_id] + page_table = self._table.get(task_id) full_pages = len(prompt_ids) // self.page_size for i in range(start_logical_page, full_pages): - self.record_page(page_table[i], prompt_ids, i) + self._prefix.record(page_table[i], prompt_ids, i, self._pool) def make_table_tensor(self, task_ids: List[str], device: torch.device) -> Tensor: 
- states = [self._task_pages.get(tid, []) for tid in task_ids] - max_pages = max((len(s) for s in states), default=0) - rows = [s + [-1] * (max_pages - len(s)) for s in states] - return torch.tensor(rows, dtype=torch.long, device=device) - - def _touch(self, idx: int) -> None: - if self._refs[idx] == 0 and idx in self._lru: - self._lru.remove(idx) - self._lru.append(idx) - - def _evict_one(self) -> int: - while self._lru: - idx = self._lru.pop(0) - h = self._page_to_hash.pop(idx, None) - if h is not None: - self._hash_to_page.pop(h, None) - self._pin[idx] = False - self._refs[idx] = 1 - return idx - return -1 - - def record_page( - self, page_idx: int, token_ids: List[int], logical_page_idx: int - ) -> None: - h = page_hash(token_ids, logical_page_idx, self.page_size) - old_h = self._page_to_hash.pop(page_idx, None) - if old_h is not None: - self._hash_to_page.pop(old_h, None) - self._page_to_hash[page_idx] = h - self._hash_to_page[h] = page_idx - self._pin[page_idx] = True - if page_idx in self._lru: - self._lru.remove(page_idx) - - def lookup_prefix(self, token_ids: List[int]) -> List[int]: - full_pages = len(token_ids) // self.page_size - hits: List[int] = [] - for i in range(full_pages): - h = page_hash(token_ids, i, self.page_size) - p = self._hash_to_page.get(h) - if p is None: - break - self._touch(p) - hits.append(p) - return hits - - def inc_ref(self, idx: int) -> None: - self._refs[idx] += 1 - if self._refs[idx] == 1 and idx in self._lru: - self._lru.remove(idx) - - def alloc(self) -> int: - if self._free_mask: - lsb = self._free_mask & -self._free_mask - idx = lsb.bit_length() - 1 - self._free_mask ^= lsb - self._refs[idx] = 1 - if idx in self._lru: - self._lru.remove(idx) - return idx - return self._evict_one() - - def alloc_n(self, n: int) -> List[int]: - pages = [self.alloc() for _ in range(n)] - if any(p < 0 for p in pages): - for p in pages: - if p >= 0: - self.free(p) - return [] - return pages - - def free(self, idx: int) -> None: - 
self._refs[idx] -= 1 - if self._refs[idx] == 0: - h = self._page_to_hash.get(idx) - if h is not None and self._pin[idx]: - self._lru.append(idx) - else: - self._free_mask |= 1 << idx - h = self._page_to_hash.pop(idx, None) - if h is not None: - self._hash_to_page.pop(h, None) - self._pin[idx] = False + return self._table.table_tensor(task_ids, device) def bind(self, page_table: Tensor, total_len: int = 0) -> "CacheView": return CacheView(self, page_table, total_len) def write( - self, layer_id: int, page_table: Tensor, start_pos: int, k: Tensor, v: Tensor + self, + layer_id: int, + page_table: Tensor, + start_pos: int, + k: Tensor, + v: Tensor, ) -> None: seq_len = k.size(1) if seq_len == 0: @@ -232,8 +284,6 @@ class PagedCache: class CacheView: """Bundles PagedCache + page_table + total_len for attention layers.""" - __slots__ = ("_cache", "_page_table", "_total_len") - def __init__(self, cache: PagedCache, page_table: Tensor, total_len: int = 0): self._cache = cache self._page_table = page_table diff --git a/astrai/inference/executor.py b/astrai/inference/executor.py index 8239771..62657f2 100644 --- a/astrai/inference/executor.py +++ b/astrai/inference/executor.py @@ -13,33 +13,22 @@ logger = logging.getLogger(__name__) class Executor: - """Model forward passes for prefill and decode; delegates page ops to PagedCache.""" + """Model forward passes for prefill and decode phases.""" def __init__( self, model: AutoModel, tokenizer: AutoTokenizer, page_cache: PagedCache, - page_size: int = 64, device: Optional[str] = None, dtype: Optional[torch.dtype] = None, ): self.model = model self.tokenizer = tokenizer self.page_cache = page_cache - self.page_size = page_size self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype - def allocate_pages_for_activation(self, task: Task) -> bool: - return self.page_cache.task_alloc(task.task_id, task.prompt_ids) - - def free_task_pages(self, task: Task) -> None: - 
self.page_cache.task_free(task.task_id) - - def get_cached_tokens(self, task: Task) -> int: - return self.page_cache.task_cached(task.task_id) - def execute_prefill( self, tasks: List[Task], prompt_len: int, start_pos: int = 0 ) -> None: @@ -71,7 +60,7 @@ class Executor: paged_cache=self.page_cache.bind(page_tables, total_len=prompt_len), ) - start_logical_page = start_pos // self.page_size + start_logical_page = start_pos // self.page_cache.page_size for t in tasks: self.page_cache.task_record_hashes( t.task_id, t.prompt_ids, start_logical_page=start_logical_page diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index 06d1046..87c278e 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -33,19 +33,16 @@ class InferenceScheduler: self.device = device or next(model.parameters()).device self.dtype = dtype or next(model.parameters()).dtype - n_kv_heads = config.n_kv_heads - head_dim = config.dim // config.n_heads - n_layers = config.n_layers n_pages = ( max_batch_size * (self.max_seq_len + page_size) + page_size - 1 ) // page_size - page_cache = PagedCache( - n_layers, + self._page_cache = PagedCache( + config.n_layers, n_pages, page_size, - n_kv_heads, - head_dim, + config.n_kv_heads, + config.dim // config.n_heads, self.device, self.dtype, ) @@ -60,8 +57,7 @@ class InferenceScheduler: self._executor = Executor( model=model, tokenizer=tokenizer, - page_cache=page_cache, - page_size=page_size, + page_cache=self._page_cache, device=self.device, dtype=self.dtype, ) @@ -73,7 +69,7 @@ class InferenceScheduler: def remove_task(self, task_id: str) -> None: for task in self._task_mgr.remove_task(task_id): - self._executor.free_task_pages(task) + self._page_cache.task_free(task.task_id) def get_stats(self) -> Dict[str, Any]: return self._task_mgr.get_stats() @@ -85,7 +81,7 @@ class InferenceScheduler: self._task_mgr.tokenizer.stop_ids ) for task in finished: - self._executor.free_task_pages(task) + 
self._page_cache.task_free(task.task_id) available = self._task_mgr.max_batch_size - len( self._task_mgr.active_tasks @@ -94,7 +90,7 @@ class InferenceScheduler: candidates = self._task_mgr.pull_candidates(available) failed = [] for task in candidates: - if self._executor.allocate_pages_for_activation(task): + if self._page_cache.task_alloc(task.task_id, task.prompt_ids): self._task_mgr.activate(task) else: failed.append(task) @@ -114,7 +110,10 @@ class InferenceScheduler: groups: Dict[Tuple[int, int], List[Task]] = {} for t in to_prefill: - key = (len(t.prompt_ids), self._executor.get_cached_tokens(t)) + key = ( + len(t.prompt_ids), + self._page_cache.task_cached(t.task_id), + ) groups.setdefault(key, []).append(t) for (prompt_len, start_pos), group in groups.items():