303 lines
7.7 KiB
Python
303 lines
7.7 KiB
Python
"""Unit tests for inference cache components."""
|
|
|
|
import torch
|
|
|
|
from astrai.inference import (
|
|
PagedCache,
|
|
PagePool,
|
|
PrefixCache,
|
|
TaskTable,
|
|
page_hash,
|
|
)
|
|
|
|
|
|
def test_page_hash_full_page():
    """Check that ``page_hash`` returns a non-negative ``int``.

    Hashes 256 sequential token ids with trailing arguments ``(0, 64)``;
    presumably these are the page index and the page size (confirm against
    the ``page_hash`` signature).
    """
    tokens = list(range(256))
    digest = page_hash(tokens, 0, 64)
    assert isinstance(digest, int)
    assert digest >= 0
|
|
|
|
|
|
def test_page_hash_different_page_differs():
    """Check that changing the second ``page_hash`` argument changes the hash.

    Both calls share the same token list and the same third argument (64);
    only the second argument (0 vs. 1) differs.
    """
    tokens = list(range(256))
    first_hash = page_hash(tokens, 0, 64)
    second_hash = page_hash(tokens, 1, 64)
    assert first_hash != second_hash
|
|
|
|
|
|
def test_page_pool_alloc_free_cycle():
    """Check that freed pages are handed out again by a later ``alloc()``."""
    pool = PagePool(n_pages=4)
    first, second = pool.alloc(), pool.alloc()
    assert first != second

    # Release both pages in allocation order before asking for a new one.
    for page in (first, second):
        pool.free(page)

    reused = pool.alloc()
    assert reused in (first, second)
|
|
|
|
|
|
def test_page_pool_alloc_when_full():
    """Check that a two-page pool returns ``-1`` on the third ``alloc()``."""
    pool = PagePool(n_pages=2)
    for _ in range(2):
        pool.alloc()
    overflow = pool.alloc()
    assert overflow == -1
|
|
|
|
|
|
def test_page_pool_lru_eviction():
    """Check that allocating from a fully cached pool evicts the oldest page.

    Both pages are released with ``keep_cached=True``; the next ``alloc()``
    must invoke ``on_evict`` exactly once, with the page released first.
    """
    victims = []
    pool = PagePool(n_pages=2, on_evict=lambda idx: victims.append(idx))

    oldest = pool.alloc()
    newest = pool.alloc()
    for page in (oldest, newest):
        pool.free(page, keep_cached=True)

    pool.alloc()
    assert len(victims) == 1
    assert victims[0] == oldest
|
|
|
|
|
|
def test_page_pool_inc_ref_and_free():
    """Check the reference count after ``inc_ref`` and each ``free``.

    The count read from ``pool._refs`` must be 2 after ``alloc`` + ``inc_ref``
    and then drop by one per ``free`` call (2 -> 1 -> 0).
    """
    pool = PagePool(n_pages=2)
    page = pool.alloc()
    pool.inc_ref(page)
    assert pool._refs[page] == 2

    for expected in (1, 0):
        pool.free(page)
        assert pool._refs[page] == expected
|
|
|
|
|
|
def test_page_pool_touch_moves_to_end():
    """Check that ``touch`` moves a page from the front of ``_lru`` to the back."""
    pool = PagePool(n_pages=4)
    held = [pool.alloc() for _ in range(3)]
    for page in held:
        pool.free(page, keep_cached=True)

    head = held[0]
    # The page released first sits at the front of the LRU ordering...
    assert next(iter(pool._lru)) == head
    pool.touch(head)
    # ...and is at the back once it has been touched.
    assert next(reversed(pool._lru)) == head
|
|
|
|
|
|
def test_page_pool_remove_from_lru():
    """Check that ``remove_from_lru`` drops a cached page from ``pool._lru``."""
    pool = PagePool(n_pages=4)
    page = pool.alloc()
    pool.free(page, keep_cached=True)
    assert page in pool._lru

    pool.remove_from_lru(page)
    assert page not in pool._lru
|
|
|
|
|
|
def test_page_pool_keep_cached_realloc():
    """Check that a fully cached pool re-issues the page released first.

    All three pages are released with ``keep_cached=True``, so presumably no
    plain free page remains and ``alloc()`` has to fall back to the cached
    (LRU) pages; the page released first must come back.
    """
    pool = PagePool(n_pages=3)
    pages = [pool.alloc() for _ in range(3)]
    for page in pages:
        pool.free(page, keep_cached=True)
    assert pool.alloc() == pages[0]
|
|
|
|
|
|
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
    """Register ``page`` with ``prefix``, then release it as a cached page.

    ``prefix.record`` is called first, followed by
    ``pool.free(page, keep_cached=True)``. Presumably the record step keeps
    the page referenced and the free step parks it in the pool's LRU --
    confirm against ``PrefixCache.record`` and ``PagePool.free``.

    Args:
        pool: Object providing ``free(page, keep_cached=...)``; also forwarded
            to ``prefix.record`` as its last argument.
        prefix: Object providing ``record(page, token_ids, logical_idx, pool)``.
        page: Page handle passed to both calls.
        token_ids: Token ids forwarded to ``prefix.record``.
        logical_idx: Index forwarded to ``prefix.record``.
    """
    record_args = (page, token_ids, logical_idx, pool)
    prefix.record(*record_args)
    pool.free(page, keep_cached=True)
|
|
|
|
|
|
def test_prefix_cache_lookup_returns_hits():
    """Check that ``lookup`` returns every recorded page, in logical order.

    Four pages are recorded for logical indices 0..3 of a 256-token sequence
    (``page_size=64``); ``lookup`` on the same tokens must return those pages.
    """
    tokens = list(range(256))
    pool = PagePool(n_pages=16)
    cache = PrefixCache(page_size=64)

    allocated = [pool.alloc() for _ in range(4)]
    for logical_idx, page in enumerate(allocated):
        _record_then_cache(pool, cache, page, tokens, logical_idx)

    assert cache.lookup(tokens, pool) == allocated
|
|
|
|
|
|
def test_prefix_cache_lookup_stops_at_first_miss():
    """Check that ``lookup`` returns only the leading run of matching pages."""
    tokens = list(range(256))
    pool = PagePool(n_pages=16)
    cache = PrefixCache(page_size=64)

    matching_page = pool.alloc()
    _record_then_cache(pool, cache, matching_page, tokens, 0)

    # Logical index 1 is recorded from a different token list ([99] * 64).
    other_page = pool.alloc()
    _record_then_cache(pool, cache, other_page, [99] * 64, 1)

    found = cache.lookup(tokens, pool)
    assert len(found) == 1
    assert found[0] == matching_page
|
|
|
|
|
|
def test_prefix_cache_ignores_partial_last_page():
    """Check that only the full page of a 100-token sequence is reported.

    With ``page_size=64``, 100 tokens cover one full page plus 36 leftover
    tokens; ``lookup`` must return exactly one hit.
    """
    tokens = list(range(100))
    pool = PagePool(n_pages=16)
    cache = PrefixCache(page_size=64)

    page = pool.alloc()
    _record_then_cache(pool, cache, page, tokens, 0)

    found = cache.lookup(tokens, pool)
    assert len(found) == 1
|
|
|
|
|
|
def test_prefix_cache_on_evict_clears_mappings():
    """Check that ``on_evict`` makes ``has_page`` report False for the page."""
    pool = PagePool(n_pages=4)
    cache = PrefixCache(page_size=64)

    page = pool.alloc()
    one_page_of_tokens = list(range(64))
    _record_then_cache(pool, cache, page, one_page_of_tokens, 0)
    assert cache.has_page(page)

    cache.on_evict(page)
    assert not cache.has_page(page)
|
|
|
|
|
|
def test_prefix_cache_has_page():
    """Check that ``has_page`` is False before recording and True after."""
    pool = PagePool(n_pages=4)
    cache = PrefixCache(page_size=64)

    page = pool.alloc()
    assert not cache.has_page(page)

    one_page_of_tokens = list(range(64))
    _record_then_cache(pool, cache, page, one_page_of_tokens, 0)
    assert cache.has_page(page)
|
|
|
|
|
|
def test_task_table_set_get():
    """Check that ``get`` / ``get_cached`` return what ``set`` stored."""
    table = TaskTable(page_size=64)
    task = "task1"
    table.set(task, [0, 1, 2], 128)

    stored_pages = table.get(task)
    assert stored_pages == [0, 1, 2]
    stored_cached = table.get_cached(task)
    assert stored_cached == 128
|
|
|
|
|
|
def test_task_table_get_missing():
    """Check the defaults for an unknown task: ``[]`` pages and ``0`` cached."""
    table = TaskTable(page_size=64)
    unknown = "nonexistent"

    pages = table.get(unknown)
    assert pages == []
    cached = table.get_cached(unknown)
    assert cached == 0
|
|
|
|
|
|
def test_task_table_pop():
    """Check that ``pop`` returns ``(pages, cached)`` and removes the task."""
    table = TaskTable(page_size=64)
    task = "task1"
    table.set(task, [0, 1], 64)

    popped = table.pop(task)
    pages, cached = popped
    assert pages == [0, 1]
    assert cached == 64
    # After the pop the task behaves like one that was never set.
    assert table.get(task) == []
|
|
|
|
|
|
def test_paged_cache_task_extend_allocates():
    """Check that extending an empty task to 200 succeeds with 4 pages.

    Assuming the second ``task_extend`` argument is a token count, 200 tokens
    at ``page_size=64`` need ceil(200 / 64) = 4 pages, which fits in the
    8-page pool.
    """
    settings = {
        "n_layers": 1,
        "n_pages": 8,
        "page_size": 64,
        "n_kv_heads": 2,
        "head_dim": 8,
        "device": torch.device("cpu"),
        "dtype": torch.float32,
    }
    cache = PagedCache(**settings)

    task = "task1"
    cache._table.set(task, [], 0)
    assert cache.task_extend(task, 200)
    assert len(cache._table.get(task)) == 4
|
|
|
|
|
|
def test_paged_cache_task_extend_fails_when_pool_full():
    """Check that ``task_extend`` reports failure when the pool is too small.

    The cache has only 2 pages and the task table already lists pages 0 and 1;
    extending to 300 must return a falsy value.
    """
    settings = {
        "n_layers": 1,
        "n_pages": 2,
        "page_size": 64,
        "n_kv_heads": 2,
        "head_dim": 8,
        "device": torch.device("cpu"),
        "dtype": torch.float32,
    }
    cache = PagedCache(**settings)

    task = "task1"
    cache._table.set(task, [0, 1], 0)
    extended = cache.task_extend(task, 300)
    assert not extended
|
|
|
|
|
|
def test_task_table_table_tensor():
    """Check that ``table_tensor`` stacks page lists and pads short rows with -1."""
    table = TaskTable(page_size=64)
    table.set("a", [0, 1], 0)
    table.set("b", [2, 3, 4], 0)

    cpu = torch.device("cpu")
    stacked = table.table_tensor(["a", "b"], cpu)

    # Two tasks, widened to the longest page list (3 entries).
    assert stacked.shape == (2, 3)
    row_a, row_b = stacked[0].tolist(), stacked[1].tolist()
    assert row_a == [0, 1, -1]
    assert row_b == [2, 3, 4]
|
|
|
|
|
|
def test_task_table_table_tensor_empty_input():
    """Check that ``table_tensor`` with no task ids yields an empty tensor."""
    table = TaskTable(page_size=64)
    cpu = torch.device("cpu")
    empty = table.table_tensor([], cpu)
    assert empty.numel() == 0
|
|
|
|
|
|
def test_paged_cache_write_gather_single_page():
    """Check that K and V written into a single page are gathered back intact.

    Writes 2 positions into layer 0 through a one-entry page table and gathers
    2 positions back. Both the key and the value tensors must round-trip.
    Assumes the tensor layout is (batch, seq, n_kv_heads, head_dim) -- confirm
    against ``PagedCache.write``.
    """
    cache = PagedCache(
        n_layers=2,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0]], dtype=torch.long)
    k = torch.randn(1, 2, 2, 8)
    v = torch.randn(1, 2, 2, 8)

    cache.write(0, page_table, 0, k, v)
    gk, gv = cache.gather(0, page_table, 2)
    assert torch.allclose(gk, k)
    # Values go through the same write/gather path and must round-trip too.
    assert torch.allclose(gv, v)
|
|
|
|
|
|
def test_paged_cache_write_cross_page():
    """Check that a write spanning two pages is gathered back intact.

    With ``page_size=4``, the 8 written positions fill pages 0 and 1 of the
    page table. Both the key and the value tensors must round-trip. Assumes
    the tensor layout is (batch, seq, n_kv_heads, head_dim) -- confirm against
    ``PagedCache.write``.
    """
    cache = PagedCache(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, 1]], dtype=torch.long)
    k = torch.randn(1, 8, 2, 8)
    v = torch.randn(1, 8, 2, 8)

    cache.write(0, page_table, 0, k, v)
    gk, gv = cache.gather(0, page_table, 8)
    assert torch.allclose(gk, k)
    # Values cross the page boundary the same way keys do; verify them as well.
    assert torch.allclose(gv, v)
|
|
|
|
|
|
def test_paged_cache_gather_truncates_to_total_len():
    """Check that ``gather`` returns only the requested number of positions.

    Six positions are written across two 4-slot pages, then 5 are gathered.
    Both outputs must have 5 positions on axis 1 and match the first 5 written
    positions. Assumes the tensor layout is
    (batch, seq, n_kv_heads, head_dim) -- confirm against ``PagedCache.gather``.
    """
    cache = PagedCache(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, 1]], dtype=torch.long)
    k = torch.randn(1, 6, 2, 8)
    v = torch.randn(1, 6, 2, 8)
    cache.write(0, page_table, 0, k, v)

    gk, gv = cache.gather(0, page_table, 5)
    assert gk.shape == (1, 5, 2, 8)
    assert gv.shape == (1, 5, 2, 8)
    # Truncation must keep the leading positions, not just produce the shape.
    assert torch.allclose(gk, k[:, :5])
    assert torch.allclose(gv, v[:, :5])
|
|
|
|
|
|
def test_paged_cache_gather_clamps_negative_padding():
    """Check that ``gather`` accepts a page table containing a ``-1`` entry.

    Nothing is written beforehand; the test only requires that gathering 4
    positions through ``[[0, -1]]`` succeeds and that both outputs have shape
    ``(1, 4, 2, 8)``. The ``-1`` is presumably a padding entry, as the test
    name suggests -- confirm against ``PagedCache.gather``.
    """
    cache = PagedCache(
        n_layers=1,
        n_pages=8,
        page_size=4,
        n_kv_heads=2,
        head_dim=8,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    page_table = torch.tensor([[0, -1]], dtype=torch.long)

    gk, gv = cache.gather(0, page_table, 4)
    assert gk.shape == (1, 4, 2, 8)
    # The value output must have the same shape as the key output.
    assert gv.shape == (1, 4, 2, 8)
|