280 lines
6.8 KiB
Python
280 lines
6.8 KiB
Python
"""Unit tests for inference cache components."""
|
|
|
|
import torch
|
|
|
|
from astrai.inference import (
|
|
Allocator,
|
|
KVCache,
|
|
PagePool,
|
|
PrefixCache,
|
|
Storage,
|
|
TaskTable,
|
|
page_hash,
|
|
)
|
|
|
|
|
|
def make_pool(n_pages: int, page_size: int) -> PagePool:
|
|
return PagePool(Allocator(n_pages), PrefixCache(page_size))
|
|
|
|
|
|
def test_page_hash_full_page():
|
|
token_ids = list(range(256))
|
|
h = page_hash(token_ids, 0, 64)
|
|
assert isinstance(h, int)
|
|
assert h >= 0
|
|
|
|
|
|
def test_page_hash_different_page_differs():
|
|
token_ids = list(range(256))
|
|
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
|
|
|
|
|
|
def test_page_pool_alloc_free_cycle():
|
|
pool = make_pool(4, 64)
|
|
a = pool.alloc()
|
|
b = pool.alloc()
|
|
assert a != b
|
|
pool.free(a)
|
|
pool.free(b)
|
|
c = pool.alloc()
|
|
assert c in (a, b)
|
|
|
|
|
|
def test_page_pool_alloc_when_full():
|
|
pool = make_pool(2, 64)
|
|
pool.alloc()
|
|
pool.alloc()
|
|
assert pool.alloc() == -1
|
|
|
|
|
|
def test_page_pool_lru_eviction():
|
|
pool = make_pool(2, 64)
|
|
p0 = pool.alloc()
|
|
p1 = pool.alloc()
|
|
pool.record(p0, list(range(64)), 0)
|
|
pool.record(p1, list(range(64, 128)), 0)
|
|
pool.free(p0)
|
|
pool.free(p1)
|
|
pool.alloc()
|
|
assert p0 in pool._alloc._lru or p1 in pool._alloc._lru
|
|
|
|
|
|
def test_page_pool_inc_ref_and_free():
|
|
pool = make_pool(2, 64)
|
|
p = pool.alloc()
|
|
pool.inc_ref(p)
|
|
assert pool._alloc._refs[p] == 2
|
|
pool.free(p)
|
|
assert pool._alloc._refs[p] == 1
|
|
pool.free(p)
|
|
assert pool._alloc._refs[p] == 0
|
|
|
|
|
|
def test_page_pool_keep_cached_realloc():
|
|
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
|
pool = make_pool(3, 64)
|
|
p0 = pool.alloc()
|
|
p1 = pool.alloc()
|
|
p2 = pool.alloc()
|
|
for p in (p0, p1, p2):
|
|
pool.record(p, [p] * 64, 0)
|
|
pool.free(p0)
|
|
pool.free(p1)
|
|
pool.free(p2)
|
|
assert pool.alloc() == p0
|
|
|
|
|
|
def test_prefix_cache_lookup_returns_hits():
|
|
token_ids = list(range(256))
|
|
pool = make_pool(16, 64)
|
|
pages = [pool.alloc() for _ in range(4)]
|
|
for i, p in enumerate(pages):
|
|
pool.record(p, token_ids, i)
|
|
pool.free(p)
|
|
hits = pool.lookup(token_ids)
|
|
assert hits == pages
|
|
|
|
|
|
def test_prefix_cache_lookup_stops_at_first_miss():
|
|
token_ids = list(range(256))
|
|
pool = make_pool(16, 64)
|
|
p0 = pool.alloc()
|
|
pool.record(p0, token_ids, 0)
|
|
pool.free(p0)
|
|
p1 = pool.alloc()
|
|
pool.record(p1, [99] * 64, 1)
|
|
pool.free(p1)
|
|
hits = pool.lookup(token_ids)
|
|
assert len(hits) == 1
|
|
assert hits[0] == p0
|
|
|
|
|
|
def test_prefix_cache_ignores_partial_last_page():
|
|
token_ids = list(range(100))
|
|
pool = make_pool(16, 64)
|
|
p = pool.alloc()
|
|
pool.record(p, token_ids, 0)
|
|
pool.free(p)
|
|
hits = pool.lookup(token_ids)
|
|
assert len(hits) == 1
|
|
|
|
|
|
def test_prefix_cache_on_evict_clears_mappings():
|
|
pool = make_pool(4, 64)
|
|
p = pool.alloc()
|
|
pool.record(p, list(range(64)), 0)
|
|
pool.free(p)
|
|
assert p in pool._prefix._page_to_hash
|
|
pool._prefix.evict(p)
|
|
assert p not in pool._prefix._page_to_hash
|
|
|
|
|
|
def test_prefix_cache_has_page():
|
|
pool = make_pool(4, 64)
|
|
p = pool.alloc()
|
|
assert p not in pool._prefix._page_to_hash
|
|
pool.record(p, list(range(64)), 0)
|
|
pool.free(p)
|
|
assert p in pool._prefix._page_to_hash
|
|
|
|
|
|
def test_task_table_set_get():
|
|
table = TaskTable(page_size=64)
|
|
table.set("task1", [0, 1, 2], 128)
|
|
assert table.get("task1") == [0, 1, 2]
|
|
assert table.get_cached("task1") == 128
|
|
|
|
|
|
def test_task_table_get_missing():
|
|
table = TaskTable(page_size=64)
|
|
assert table.get("nonexistent") == []
|
|
assert table.get_cached("nonexistent") == 0
|
|
|
|
|
|
def test_task_table_pop():
|
|
table = TaskTable(page_size=64)
|
|
table.set("task1", [0, 1], 64)
|
|
pages, cached = table.pop("task1")
|
|
assert pages == [0, 1]
|
|
assert cached == 64
|
|
assert table.get("task1") == []
|
|
|
|
|
|
def test_kv_cache_task_extend_allocates():
|
|
cache = KVCache(
|
|
n_layers=1,
|
|
n_pages=8,
|
|
page_size=64,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
cache._table.set("task1", [], 0)
|
|
ok = cache.task_extend("task1", 200)
|
|
assert ok
|
|
assert len(cache._table.get("task1")) == 4
|
|
|
|
|
|
def test_kv_cache_task_extend_fails_when_pool_full():
|
|
cache = KVCache(
|
|
n_layers=1,
|
|
n_pages=2,
|
|
page_size=64,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
cache._table.set("task1", [0, 1], 0)
|
|
ok = cache.task_extend("task1", 300)
|
|
assert not ok
|
|
|
|
|
|
def test_task_table_table_tensor():
|
|
table = TaskTable(page_size=64)
|
|
table.set("a", [0, 1], 0)
|
|
table.set("b", [2, 3, 4], 0)
|
|
t = table.table_tensor(["a", "b"], torch.device("cpu"))
|
|
assert t.shape == (2, 3)
|
|
assert t[0].tolist() == [0, 1, -1]
|
|
assert t[1].tolist() == [2, 3, 4]
|
|
|
|
|
|
def test_task_table_table_tensor_empty_input():
|
|
table = TaskTable(page_size=64)
|
|
t = table.table_tensor([], torch.device("cpu"))
|
|
assert t.numel() == 0
|
|
|
|
|
|
def test_storage_write_gather_single_page():
|
|
storage = Storage(
|
|
n_layers=2,
|
|
n_pages=8,
|
|
page_size=4,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
page_table = torch.tensor([[0]], dtype=torch.long)
|
|
k = torch.randn(1, 2, 2, 8)
|
|
v = torch.randn(1, 2, 2, 8)
|
|
|
|
storage.write(0, page_table, 0, k, v)
|
|
gk, gv = storage.gather(0, page_table, 2)
|
|
assert torch.allclose(gk, k)
|
|
|
|
|
|
def test_storage_write_cross_page():
|
|
storage = Storage(
|
|
n_layers=1,
|
|
n_pages=8,
|
|
page_size=4,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
|
k = torch.randn(1, 8, 2, 8)
|
|
v = torch.randn(1, 8, 2, 8)
|
|
|
|
storage.write(0, page_table, 0, k, v)
|
|
gk, gv = storage.gather(0, page_table, 8)
|
|
assert torch.allclose(gk, k)
|
|
|
|
|
|
def test_storage_gather_truncates_to_total_len():
|
|
storage = Storage(
|
|
n_layers=1,
|
|
n_pages=8,
|
|
page_size=4,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
|
k = torch.randn(1, 6, 2, 8)
|
|
v = torch.randn(1, 6, 2, 8)
|
|
storage.write(0, page_table, 0, k, v)
|
|
|
|
gk, gv = storage.gather(0, page_table, 5)
|
|
assert gk.shape == (1, 5, 2, 8)
|
|
|
|
|
|
def test_storage_gather_clamps_negative_padding():
|
|
storage = Storage(
|
|
n_layers=1,
|
|
n_pages=8,
|
|
page_size=4,
|
|
n_kv_heads=2,
|
|
head_dim=8,
|
|
device=torch.device("cpu"),
|
|
dtype=torch.float32,
|
|
)
|
|
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
|
gk, gv = storage.gather(0, page_table, 4)
|
|
assert gk.shape == (1, 4, 2, 8)
|