242 lines
8.0 KiB
Python
242 lines
8.0 KiB
Python
"""KV cache slot allocation and prefix cache management.

Provides:

- SlotAllocator: Object Pool pattern for KV cache slot alloc/free via an
  integer bitmask (effectively O(1) while the mask fits in a few machine
  words; Python int bit operations scale with the mask's bit width).
- PrefixCacheManager: Radix-tree prefix cache with LRU eviction for KV cache
  reuse.
"""
|
|
|
|
import time
|
|
from collections import OrderedDict
|
|
from typing import Dict, List, Tuple
|
|
|
|
from torch import Tensor
|
|
|
|
# Module-level sentinel object.
# NOTE(review): not referenced anywhere in this module; presumably used by
# importing code (e.g. as an end-of-queue marker) — confirm before removing.
_STOP = object()
|
|
|
|
|
|
class _RadixNode:
    """A single node of the prefix-cache radix tree.

    The path of token IDs from the root down to a node spells out one cached
    prefix; the node carries the bookkeeping for that prefix.

    Attributes:
        children: Child nodes keyed by the next token ID.
        slot: KV cache slot holding data for the prefix ending here, or -1
            when no slot is recorded.
        slot_ver: Version counter of ``slot`` at the time it was recorded.
        ref_count: Number of tasks currently referencing this node.
        last_access: Timestamp of the most recent access to this node.
    """

    __slots__ = ("children", "slot", "slot_ver", "ref_count", "last_access")

    def __init__(self):
        # A fresh node is unreferenced, has never been touched, holds no
        # slot, and has no children yet.
        self.ref_count: int = 0
        self.last_access: float = 0.0
        self.slot: int = -1
        self.slot_ver: int = 0
        self.children: Dict[int, "_RadixNode"] = {}
|
|
|
|
|
|
class SlotAllocator:
|
|
"""KV cache slot allocator using bitmask for O(1) alloc/free.
|
|
|
|
Implements the Object Pool pattern: pre-allocated KV cache slots
|
|
are managed via a bitmask, providing constant-time allocation and
|
|
deallocation with version counters for staleness detection.
|
|
"""
|
|
|
|
def __init__(self, max_slots: int):
|
|
self._max_slots = max_slots
|
|
self._free_mask = (1 << max_slots) - 1
|
|
self._versions: List[int] = [0] * max_slots
|
|
|
|
def alloc(self) -> int:
|
|
"""Allocates a free slot.
|
|
|
|
Returns:
|
|
Slot index on success, -1 if all slots are occupied.
|
|
"""
|
|
lsb = self._free_mask & -self._free_mask
|
|
if lsb == 0:
|
|
return -1
|
|
idx = lsb.bit_length() - 1
|
|
self._free_mask ^= lsb
|
|
self._versions[idx] += 1
|
|
return idx
|
|
|
|
def free(self, idx: int) -> None:
|
|
"""Releases a slot back to the free pool."""
|
|
self._free_mask |= 1 << idx
|
|
|
|
def occupy(self, idx: int) -> None:
|
|
"""Marks a currently free slot as occupied without bumping its version.
|
|
|
|
Used for direct slot reuse when a prefix-cached slot is still valid.
|
|
"""
|
|
self._free_mask ^= 1 << idx
|
|
|
|
def is_free(self, idx: int) -> bool:
|
|
"""Checks whether a slot is currently free."""
|
|
return (self._free_mask >> idx) & 1 == 1
|
|
|
|
def version(self, idx: int) -> int:
|
|
"""Returns the current version counter for a slot."""
|
|
return self._versions[idx]
|
|
|
|
@property
|
|
def free_count(self) -> int:
|
|
"""Returns the number of currently free slots."""
|
|
return self._free_mask.bit_count()
|
|
|
|
|
|
class PrefixCacheManager:
    """Radix-tree prefix cache with LRU eviction.

    Maps token ID sequences to KV cache slots. Every node on an inserted path
    (not just the node for the full sequence) records the slot and its
    version. A partial prefix match can therefore reuse that slot directly
    when the slot is free and its version matches (no intervening writes).

    Bookkeeping invariants:
      * every non-root node in the tree is tracked in ``_lru`` (keyed by
        ``id(node)``, least recently used first) and in ``_parent``;
      * only leaf nodes with ``ref_count == 0`` are evicted, and an evicted
        node is detached from its parent. Eviction therefore never orphans a
        descendant, and later lookups never reach an untracked node.
    """

    def __init__(self, max_capacity: int = 1000):
        """Initializes the prefix cache.

        Args:
            max_capacity: Maximum number of tree nodes (one per cached token
                position) kept in the LRU list. Nodes that are pinned or
                still have children are never evicted, so the list can
                temporarily exceed this bound.
        """
        self.root = _RadixNode()
        self.max_capacity = max_capacity
        # id(node) -> node, ordered from least to most recently used.
        self._lru: OrderedDict[int, _RadixNode] = OrderedDict()
        # id(node) -> (parent node, token ID on the edge into node); used to
        # detach a node from the tree when it is evicted.
        self._parent: Dict[int, Tuple[_RadixNode, int]] = {}

    def insert(self, token_ids: Tuple[int, ...], slot: int, slot_ver: int) -> None:
        """Inserts a token sequence into the prefix cache.

        Every node along the path records ``slot`` and ``slot_ver``, enabling
        direct slot reuse for partial prefix matches. The node for the full
        sequence has its ref count incremented; the matching decrement is
        :meth:`release`. An empty sequence is ignored, so the root never
        enters the LRU list.

        Args:
            token_ids: The token ID sequence to cache.
            slot: The KV cache slot containing this prefix's computed keys/values.
            slot_ver: The slot version at insertion time, used for staleness detection.
        """
        if not token_ids:
            return
        now = time.time()
        node = self.root
        path = []
        for tid in token_ids:
            child = node.children.get(tid)
            if child is None:
                child = _RadixNode()
                node.children[tid] = child
                self._parent[id(child)] = (node, tid)
            child.slot = slot
            child.slot_ver = slot_ver
            child.last_access = now
            path.append(child)
            node = child
        node.ref_count += 1
        # Touch deepest-first so ancestors end up more recently used than
        # their descendants; leaves then reach the LRU front first.
        for visited in reversed(path):
            self._touch(visited)
        self._evict_if_needed()

    def find(self, token_ids: List[int]) -> Tuple[int, int, int]:
        """Finds the longest matching prefix in the cache.

        Walks the radix tree token by token, recording the deepest matched
        node that carries a slot. Every matched node is marked as recently
        used.

        Args:
            token_ids: The token sequence to match against.

        Returns:
            Tuple of (prefix_len, slot, slot_ver):
                prefix_len: Number of matching tokens (0 if no match).
                slot: KV cache slot of the matched prefix (-1 if no match).
                slot_ver: Version of that slot when the prefix was inserted.
        """
        node = self.root
        best_len, best_slot, best_ver = 0, -1, 0
        now = time.time()
        matched = []
        for depth, tid in enumerate(token_ids, start=1):
            child = node.children.get(tid)
            if child is None:
                break
            node = child
            node.last_access = now
            matched.append(node)
            # Defensive: never report a match length without a usable slot.
            if node.slot >= 0:
                best_len, best_slot, best_ver = depth, node.slot, node.slot_ver
        for visited in reversed(matched):
            self._touch(visited)
        return best_len, best_slot, best_ver

    def pin(self, token_ids: Tuple[int, ...]) -> None:
        """Increments the reference count of a cached prefix.

        Called when a task reuses a cached prefix to prevent eviction. Does
        nothing if the prefix is not in the tree.

        Args:
            token_ids: The token sequence whose node's ref_count to increment.
        """
        node = self._walk(token_ids)
        if node is not None:
            node.ref_count += 1

    def release(self, token_ids: Tuple[int, ...]) -> None:
        """Decrements the reference count of a cached prefix.

        The node's slot is preserved even when ref_count reaches zero, until
        the node is evicted. Future tasks can therefore reuse the slot
        directly if it remains free. Does nothing if the prefix is not in the
        tree; the count never goes below zero.

        Args:
            token_ids: The token sequence whose node's ref_count to decrement.
        """
        node = self._walk(token_ids)
        if node is not None and node.ref_count > 0:
            node.ref_count -= 1

    def copy_kv(
        self,
        token_ids: Tuple[int, ...],
        target_slot: int,
        kv_cache: Tuple[Tensor, Tensor],
        n_layers: int,
    ) -> None:
        """Copies cached KV data from the source slot to a target slot.

        Copies positions ``[0, len(token_ids))`` of each of the first
        ``n_layers`` layers. The copy is skipped if the prefix is unknown or
        carries no slot. The indexing below treats each cache tensor as
        ``[slot, position, layer, ...]``. NOTE(review): this layout is
        inferred from the indexing; confirm it against the model code.
        This method does not check whether the source slot has been
        reallocated since insertion. Callers should compare ``find()``'s
        ``slot_ver`` with the allocator's current version first.

        Args:
            token_ids: The prefix token sequence identifying the source cache node.
            target_slot: The destination KV cache slot to copy into.
            kv_cache: Tuple of (k_cache, v_cache) tensors.
            n_layers: Number of transformer layers to copy.
        """
        node = self._walk(token_ids)
        if node is None or node.slot < 0:
            return
        src_slot = node.slot
        prefix_len = len(token_ids)
        k_cache, v_cache = kv_cache
        for li in range(n_layers):
            k_cache[target_slot, :prefix_len, li].copy_(
                k_cache[src_slot, :prefix_len, li]
            )
            v_cache[target_slot, :prefix_len, li].copy_(
                v_cache[src_slot, :prefix_len, li]
            )

    def _walk(self, token_ids: Tuple[int, ...]) -> "_RadixNode | None":
        """Returns the node reached by following token_ids from the root, or None."""
        node = self.root
        for tid in token_ids:
            node = node.children.get(tid)
            if node is None:
                return None
        return node

    def _touch(self, node: "_RadixNode") -> None:
        """Marks node as most recently used, adding it to the LRU list if absent."""
        key = id(node)
        self._lru[key] = node
        self._lru.move_to_end(key)

    def _evict_if_needed(self) -> None:
        """Evicts least-recently-used nodes until under capacity.

        Only leaf nodes with ref_count == 0 are eligible. Pinned nodes and
        nodes that still have children are skipped and keep their LRU
        position. The method stops as soon as no eligible node remains.
        Previously it spun forever when every entry was pinned.
        """
        while len(self._lru) > self.max_capacity:
            victim = next(
                (
                    key
                    for key, node in self._lru.items()
                    if node.ref_count == 0 and not node.children
                ),
                None,
            )
            if victim is None:
                return
            self._evict(victim)

    def _evict(self, key: int) -> None:
        """Removes one leaf node from the LRU list and detaches it from the tree."""
        node = self._lru.pop(key)
        link = self._parent.pop(key, None)
        if link is not None:
            parent, tid = link
            if parent.children.get(tid) is node:
                del parent.children[tid]
        node.slot = -1
        node.slot_ver = 0
|