test: inference 模块补全单元测试,cache/sample/engine/task
- test_cache: page_hash, PagePool, PrefixCache, TaskTable, PagedCache write/gather - test_sample: TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample() - test_engine: _Result 线程安全, generate stream/non-stream batch/single - test_task: Task 生命周期, TaskManager 队列操作 - 4 新文件, +771 行, 116 total tests
This commit is contained in:
parent
0ca6c9e6eb
commit
7d4029c2a4
|
|
@ -0,0 +1,293 @@
|
||||||
|
"""Unit tests for inference cache components."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.inference.cache import (
|
||||||
|
PagedCache,
|
||||||
|
PagePool,
|
||||||
|
PrefixCache,
|
||||||
|
TaskTable,
|
||||||
|
page_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_hash_full_page():
|
||||||
|
token_ids = list(range(256))
|
||||||
|
h = page_hash(token_ids, 0, 64)
|
||||||
|
assert isinstance(h, int)
|
||||||
|
assert h >= 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_hash_different_page_differs():
|
||||||
|
token_ids = list(range(256))
|
||||||
|
assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64)
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_alloc_free_cycle():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
a = pool.alloc()
|
||||||
|
b = pool.alloc()
|
||||||
|
assert a != b
|
||||||
|
pool.free(a)
|
||||||
|
pool.free(b)
|
||||||
|
c = pool.alloc()
|
||||||
|
assert c in (a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_alloc_when_full():
|
||||||
|
pool = PagePool(n_pages=2)
|
||||||
|
pool.alloc()
|
||||||
|
pool.alloc()
|
||||||
|
assert pool.alloc() == -1
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_lru_eviction():
|
||||||
|
evicted = []
|
||||||
|
|
||||||
|
def on_evict(idx):
|
||||||
|
evicted.append(idx)
|
||||||
|
|
||||||
|
pool = PagePool(n_pages=2, on_evict=on_evict)
|
||||||
|
p0 = pool.alloc()
|
||||||
|
p1 = pool.alloc()
|
||||||
|
pool.free(p0, keep_cached=True)
|
||||||
|
pool.free(p1, keep_cached=True)
|
||||||
|
pool.alloc()
|
||||||
|
assert len(evicted) == 1
|
||||||
|
assert evicted[0] == p0
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_inc_ref_and_free():
|
||||||
|
pool = PagePool(n_pages=2)
|
||||||
|
p = pool.alloc()
|
||||||
|
pool.inc_ref(p)
|
||||||
|
assert pool._refs[p] == 2
|
||||||
|
pool.free(p)
|
||||||
|
assert pool._refs[p] == 1
|
||||||
|
pool.free(p)
|
||||||
|
assert pool._refs[p] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_touch_moves_to_end():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
p0 = pool.alloc()
|
||||||
|
p1 = pool.alloc()
|
||||||
|
p2 = pool.alloc()
|
||||||
|
pool.free(p0, keep_cached=True)
|
||||||
|
pool.free(p1, keep_cached=True)
|
||||||
|
pool.free(p2, keep_cached=True)
|
||||||
|
assert next(iter(pool._lru)) == p0
|
||||||
|
pool.touch(p0)
|
||||||
|
assert next(reversed(pool._lru)) == p0
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_remove_from_lru():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
p0 = pool.alloc()
|
||||||
|
pool.free(p0, keep_cached=True)
|
||||||
|
assert p0 in pool._lru
|
||||||
|
pool.remove_from_lru(p0)
|
||||||
|
assert p0 not in pool._lru
|
||||||
|
|
||||||
|
|
||||||
|
def test_page_pool_keep_cached_realloc():
|
||||||
|
"""Free mask has priority over LRU; cached page returned only when no free pages."""
|
||||||
|
pool = PagePool(n_pages=3)
|
||||||
|
p0 = pool.alloc()
|
||||||
|
p1 = pool.alloc()
|
||||||
|
p2 = pool.alloc()
|
||||||
|
pool.free(p0, keep_cached=True)
|
||||||
|
pool.free(p1, keep_cached=True)
|
||||||
|
pool.free(p2, keep_cached=True)
|
||||||
|
assert pool.alloc() == p0
|
||||||
|
|
||||||
|
|
||||||
|
def _record_then_cache(pool, prefix, page, token_ids, logical_idx):
|
||||||
|
"""Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU."""
|
||||||
|
prefix.record(page, token_ids, logical_idx, pool)
|
||||||
|
pool.free(page, keep_cached=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_lookup_returns_hits():
|
||||||
|
token_ids = list(range(256))
|
||||||
|
pool = PagePool(n_pages=16)
|
||||||
|
prefix = PrefixCache(page_size=64)
|
||||||
|
pages = [pool.alloc() for _ in range(4)]
|
||||||
|
for i, p in enumerate(pages):
|
||||||
|
_record_then_cache(pool, prefix, p, token_ids, i)
|
||||||
|
hits = prefix.lookup(token_ids, pool)
|
||||||
|
assert hits == pages
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_lookup_stops_at_first_miss():
|
||||||
|
token_ids = list(range(256))
|
||||||
|
pool = PagePool(n_pages=16)
|
||||||
|
prefix = PrefixCache(page_size=64)
|
||||||
|
p0 = pool.alloc()
|
||||||
|
_record_then_cache(pool, prefix, p0, token_ids, 0)
|
||||||
|
p1 = pool.alloc()
|
||||||
|
_record_then_cache(pool, prefix, p1, [99] * 64, 1)
|
||||||
|
hits = prefix.lookup(token_ids, pool)
|
||||||
|
assert len(hits) == 1
|
||||||
|
assert hits[0] == p0
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_ignores_partial_last_page():
|
||||||
|
token_ids = list(range(100))
|
||||||
|
pool = PagePool(n_pages=16)
|
||||||
|
prefix = PrefixCache(page_size=64)
|
||||||
|
p = pool.alloc()
|
||||||
|
_record_then_cache(pool, prefix, p, token_ids, 0)
|
||||||
|
hits = prefix.lookup(token_ids, pool)
|
||||||
|
assert len(hits) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_on_evict_clears_mappings():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
prefix = PrefixCache(page_size=64)
|
||||||
|
p = pool.alloc()
|
||||||
|
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
||||||
|
assert prefix.has_page(p)
|
||||||
|
prefix.on_evict(p)
|
||||||
|
assert not prefix.has_page(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefix_cache_has_page():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
prefix = PrefixCache(page_size=64)
|
||||||
|
p = pool.alloc()
|
||||||
|
assert not prefix.has_page(p)
|
||||||
|
_record_then_cache(pool, prefix, p, list(range(64)), 0)
|
||||||
|
assert prefix.has_page(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_set_get():
|
||||||
|
pool = PagePool(n_pages=8)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
table.set("task1", [0, 1, 2], 128)
|
||||||
|
assert table.get("task1") == [0, 1, 2]
|
||||||
|
assert table.get_cached("task1") == 128
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_get_missing():
|
||||||
|
pool = PagePool(n_pages=8)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
assert table.get("nonexistent") == []
|
||||||
|
assert table.get_cached("nonexistent") == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_pop():
|
||||||
|
pool = PagePool(n_pages=8)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
table.set("task1", [0, 1], 64)
|
||||||
|
pages, cached = table.pop("task1")
|
||||||
|
assert pages == [0, 1]
|
||||||
|
assert cached == 64
|
||||||
|
assert table.get("task1") == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_extend_allocates_pages():
|
||||||
|
pool = PagePool(n_pages=8)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
table.set("task1", [], 0)
|
||||||
|
ok = table.extend("task1", 200)
|
||||||
|
assert ok
|
||||||
|
assert len(table.get("task1")) == 4
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_extend_fails_when_pool_full():
|
||||||
|
pool = PagePool(n_pages=2)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
table.set("task1", [pool.alloc(), pool.alloc()], 0)
|
||||||
|
ok = table.extend("task1", 300)
|
||||||
|
assert not ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_table_tensor():
|
||||||
|
pool = PagePool(n_pages=16)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
table.set("a", [0, 1], 0)
|
||||||
|
table.set("b", [2, 3, 4], 0)
|
||||||
|
t = table.table_tensor(["a", "b"], torch.device("cpu"))
|
||||||
|
assert t.shape == (2, 3)
|
||||||
|
assert t[0].tolist() == [0, 1, -1]
|
||||||
|
assert t[1].tolist() == [2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_table_table_tensor_empty_input():
|
||||||
|
pool = PagePool(n_pages=4)
|
||||||
|
table = TaskTable(pool, page_size=64)
|
||||||
|
t = table.table_tensor([], torch.device("cpu"))
|
||||||
|
assert t.numel() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_paged_cache_write_gather_single_page():
|
||||||
|
cache = PagedCache(
|
||||||
|
n_layers=2,
|
||||||
|
n_pages=8,
|
||||||
|
page_size=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
page_table = torch.tensor([[0]], dtype=torch.long)
|
||||||
|
k = torch.randn(1, 2, 2, 8)
|
||||||
|
v = torch.randn(1, 2, 2, 8)
|
||||||
|
|
||||||
|
cache.write(0, page_table, 0, k, v)
|
||||||
|
gk, gv = cache.gather(0, page_table, 2)
|
||||||
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
||||||
|
def test_paged_cache_write_cross_page():
|
||||||
|
cache = PagedCache(
|
||||||
|
n_layers=1,
|
||||||
|
n_pages=8,
|
||||||
|
page_size=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||||
|
k = torch.randn(1, 8, 2, 8)
|
||||||
|
v = torch.randn(1, 8, 2, 8)
|
||||||
|
|
||||||
|
cache.write(0, page_table, 0, k, v)
|
||||||
|
gk, gv = cache.gather(0, page_table, 8)
|
||||||
|
assert torch.allclose(gk, k)
|
||||||
|
|
||||||
|
|
||||||
|
def test_paged_cache_gather_truncates_to_total_len():
|
||||||
|
cache = PagedCache(
|
||||||
|
n_layers=1,
|
||||||
|
n_pages=8,
|
||||||
|
page_size=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
page_table = torch.tensor([[0, 1]], dtype=torch.long)
|
||||||
|
k = torch.randn(1, 6, 2, 8)
|
||||||
|
v = torch.randn(1, 6, 2, 8)
|
||||||
|
cache.write(0, page_table, 0, k, v)
|
||||||
|
|
||||||
|
gk, gv = cache.gather(0, page_table, 5)
|
||||||
|
assert gk.shape == (1, 5, 2, 8)
|
||||||
|
|
||||||
|
|
||||||
|
def test_paged_cache_gather_clamps_negative_padding():
|
||||||
|
cache = PagedCache(
|
||||||
|
n_layers=1,
|
||||||
|
n_pages=8,
|
||||||
|
page_size=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
head_dim=8,
|
||||||
|
device=torch.device("cpu"),
|
||||||
|
dtype=torch.float32,
|
||||||
|
)
|
||||||
|
page_table = torch.tensor([[0, -1]], dtype=torch.long)
|
||||||
|
gk, gv = cache.gather(0, page_table, 4)
|
||||||
|
assert gk.shape == (1, 4, 2, 8)
|
||||||
|
|
@ -0,0 +1,181 @@
|
||||||
|
"""Unit tests for _Result accumulator and InferenceEngine.generate()."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from astrai.inference.engine import _Result
|
||||||
|
from astrai.inference.task import STOP
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_append_single():
|
||||||
|
r = _Result(count=1)
|
||||||
|
r.append("hello", 0)
|
||||||
|
assert r.results[0] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_append_multiple_tasks():
|
||||||
|
r = _Result(count=3)
|
||||||
|
r.append("a", 0)
|
||||||
|
r.append("b", 1)
|
||||||
|
r.append("c", 2)
|
||||||
|
assert r.results[0] == "a"
|
||||||
|
assert r.results[1] == "b"
|
||||||
|
assert r.results[2] == "c"
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_stop_marks_complete():
|
||||||
|
r = _Result(count=2)
|
||||||
|
r.append("text", 0)
|
||||||
|
r.append(STOP, 0)
|
||||||
|
r.append("more", 1)
|
||||||
|
assert r._done[0] is True
|
||||||
|
assert r._done[1] is False
|
||||||
|
assert r._completed == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_stop_does_not_double_count():
|
||||||
|
r = _Result(count=1)
|
||||||
|
r.append(STOP, 0)
|
||||||
|
r.append(STOP, 0)
|
||||||
|
assert r._completed == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_pop_all_returns_and_clears():
|
||||||
|
r = _Result(count=2)
|
||||||
|
r.append("a", 0)
|
||||||
|
r.append("b", 1)
|
||||||
|
out = r.pop_all()
|
||||||
|
assert len(out) == 2
|
||||||
|
assert out[0] == (0, "a")
|
||||||
|
assert out[1] == (1, "b")
|
||||||
|
assert r.pop_all() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_wait_blocks_until_data():
|
||||||
|
r = _Result(count=1)
|
||||||
|
|
||||||
|
def delayed_append():
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
r.append("delayed", 0)
|
||||||
|
|
||||||
|
t = threading.Thread(target=delayed_append)
|
||||||
|
t.start()
|
||||||
|
ok = r.wait(timeout=5.0)
|
||||||
|
t.join()
|
||||||
|
assert ok
|
||||||
|
assert r.results[0] == "delayed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_wait_timeout():
|
||||||
|
r = _Result(count=1)
|
||||||
|
ok = r.wait(timeout=0.01)
|
||||||
|
assert not ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_wait_completion_non_streaming():
|
||||||
|
r = _Result(count=2)
|
||||||
|
|
||||||
|
def finish_later():
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
r.append(STOP, 0)
|
||||||
|
time.sleep(0.05)
|
||||||
|
r.append(STOP, 1)
|
||||||
|
|
||||||
|
t = threading.Thread(target=finish_later)
|
||||||
|
t.start()
|
||||||
|
r.wait_completion()
|
||||||
|
t.join()
|
||||||
|
assert r._completed == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_result_get_results():
|
||||||
|
r = _Result(count=2)
|
||||||
|
r.append("hello", 0)
|
||||||
|
r.append("world", 1)
|
||||||
|
results = r.get_results()
|
||||||
|
assert results == ["hello", "world"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_generate_non_streaming_single():
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||||
|
mock_tokenizer.decode.return_value = "response"
|
||||||
|
mock_tokenizer.stop_ids = [0]
|
||||||
|
|
||||||
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||||
|
instance = MockSched.return_value
|
||||||
|
|
||||||
|
def fake_add(prompt, **kw):
|
||||||
|
cb = kw["stream_callback"]
|
||||||
|
cb("response")
|
||||||
|
cb(STOP)
|
||||||
|
|
||||||
|
instance.add_task.side_effect = fake_add
|
||||||
|
instance.remove_task.return_value = []
|
||||||
|
|
||||||
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
||||||
|
result = eng.generate("hello")
|
||||||
|
assert result == "response"
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_generate_streaming_yields_tokens():
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||||
|
mock_tokenizer.decode.return_value = "tok"
|
||||||
|
mock_tokenizer.stop_ids = [0]
|
||||||
|
|
||||||
|
callbacks_saved = []
|
||||||
|
|
||||||
|
def capture_cb(prompt, **kw):
|
||||||
|
callbacks_saved.append(kw.get("stream_callback"))
|
||||||
|
|
||||||
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||||
|
instance = MockSched.return_value
|
||||||
|
instance.add_task.side_effect = capture_cb
|
||||||
|
instance.remove_task.return_value = []
|
||||||
|
|
||||||
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
||||||
|
gen = eng.generate("hello", stream=True)
|
||||||
|
|
||||||
|
cb = callbacks_saved[0]
|
||||||
|
cb("t1", 0)
|
||||||
|
cb("t2", 0)
|
||||||
|
cb(STOP, 0)
|
||||||
|
|
||||||
|
tokens = list(gen)
|
||||||
|
assert tokens == ["t1", "t2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_engine_generate_non_streaming_batch():
|
||||||
|
from astrai.inference.engine import InferenceEngine
|
||||||
|
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_tokenizer = MagicMock()
|
||||||
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
||||||
|
mock_tokenizer.decode.return_value = "r"
|
||||||
|
mock_tokenizer.stop_ids = [0]
|
||||||
|
|
||||||
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
||||||
|
instance = MockSched.return_value
|
||||||
|
|
||||||
|
def fake_add(prompt, **kw):
|
||||||
|
cb = kw["stream_callback"]
|
||||||
|
cb("r")
|
||||||
|
cb(STOP)
|
||||||
|
|
||||||
|
instance.add_task.side_effect = fake_add
|
||||||
|
instance.remove_task.return_value = []
|
||||||
|
|
||||||
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
|
||||||
|
results = eng.generate(["hello", "world"])
|
||||||
|
assert results == ["r", "r"]
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
"""Unit tests for inference sampling strategies."""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.inference.sample import (
|
||||||
|
SamplingPipeline,
|
||||||
|
TemperatureStrategy,
|
||||||
|
TopKStrategy,
|
||||||
|
TopPStrategy,
|
||||||
|
sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temperature_scalar():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||||
|
s = TemperatureStrategy(0.5)
|
||||||
|
result = s.apply(logits.clone())
|
||||||
|
assert torch.allclose(result, logits / 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temperature_skip_when_one():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||||
|
s = TemperatureStrategy(1.0)
|
||||||
|
result = s.apply(logits.clone())
|
||||||
|
assert torch.equal(result, logits)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temperature_per_sample_tensor():
|
||||||
|
logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
|
||||||
|
s = TemperatureStrategy(torch.tensor([0.5, 0.5]))
|
||||||
|
result = s.apply(logits.clone())
|
||||||
|
assert torch.allclose(result, logits / 0.5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_k_keeps_top():
|
||||||
|
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
|
||||||
|
s = TopKStrategy(top_k=2)
|
||||||
|
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||||
|
kept = (result > -1e9).sum().item()
|
||||||
|
assert kept == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_k_skip_when_zero():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||||
|
s = TopKStrategy(top_k=0)
|
||||||
|
result = s.apply(logits.clone())
|
||||||
|
assert torch.equal(result, logits)
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_k_batch_tensor():
|
||||||
|
"""When top_k is a batch tensor, max element governs k for all rows."""
|
||||||
|
logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]])
|
||||||
|
s = TopKStrategy(top_k=torch.tensor([2, 1]))
|
||||||
|
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||||
|
assert (result[0] > -1e9).sum() == 2
|
||||||
|
assert (result[1] > -1e9).sum() == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_p_nucleus_filtering():
|
||||||
|
logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]])
|
||||||
|
s = TopPStrategy(top_p=0.5)
|
||||||
|
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||||
|
kept = (result > -1e9).sum().item()
|
||||||
|
assert kept >= 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_p_skip_when_one():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||||
|
s = TopPStrategy(top_p=1.0)
|
||||||
|
result = s.apply(logits.clone())
|
||||||
|
assert torch.equal(result, logits)
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_p_filter_all_except_max_when_zero():
|
||||||
|
logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]])
|
||||||
|
s = TopPStrategy(top_p=0.0)
|
||||||
|
result = s.apply(logits.clone(), filter_value=-1e9)
|
||||||
|
kept = (result > -1e9).sum().item()
|
||||||
|
assert kept == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_pipeline_composes_strategies():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||||
|
pipeline = SamplingPipeline(
|
||||||
|
[
|
||||||
|
TemperatureStrategy(0.8),
|
||||||
|
TopKStrategy(3),
|
||||||
|
TopPStrategy(0.95),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
result = pipeline.apply(logits.clone(), filter_value=-1e9)
|
||||||
|
kept = (result > -1e9).sum().item()
|
||||||
|
assert 1 <= kept <= 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_sampling_pipeline_sample_returns_valid_token():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||||
|
pipeline = SamplingPipeline(
|
||||||
|
[
|
||||||
|
TemperatureStrategy(0.8),
|
||||||
|
TopKStrategy(3),
|
||||||
|
TopPStrategy(0.95),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tokens = pipeline.sample(logits)
|
||||||
|
assert tokens.shape == (1,)
|
||||||
|
assert 0 <= tokens[0] < logits.size(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_sample_shortcut():
|
||||||
|
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
|
||||||
|
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
|
||||||
|
assert tokens.shape == (1,)
|
||||||
|
assert 0 <= tokens[0] < logits.size(-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_sample_batch():
|
||||||
|
logits = torch.tensor(
|
||||||
|
[
|
||||||
|
[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||||
|
[5.0, 4.0, 3.0, 2.0, 1.0],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95)
|
||||||
|
assert tokens.shape == (2,)
|
||||||
|
for t in tokens:
|
||||||
|
assert 0 <= t < logits.size(-1)
|
||||||
|
|
@ -0,0 +1,170 @@
|
||||||
|
"""Unit tests for Task and TaskManager."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
from astrai.inference.task import STOP, Task, TaskManager, TaskStatus
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_tokenizer():
|
||||||
|
t = MagicMock()
|
||||||
|
t.encode.return_value = [1, 2, 3, 4, 5]
|
||||||
|
t.stop_ids = [0]
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_default_status_is_pending():
|
||||||
|
task = Task("id1", [1, 2, 3])
|
||||||
|
assert task.status == TaskStatus.PENDING
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_next_pos():
|
||||||
|
task = Task("id1", [1, 2, 3])
|
||||||
|
task.input_tokens = 5
|
||||||
|
assert task.next_pos == 5
|
||||||
|
task.output_ids.append(4)
|
||||||
|
assert task.next_pos == 6
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_is_finished_max_tokens():
|
||||||
|
task = Task("id1", [1, 2, 3], max_tokens=2)
|
||||||
|
task.output_tokens = 2
|
||||||
|
assert task.is_finished([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_is_finished_stop_id():
|
||||||
|
task = Task("id1", [1, 2, 3])
|
||||||
|
task.output_ids = [5, 0]
|
||||||
|
assert task.is_finished([0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_is_finished_not_yet():
|
||||||
|
task = Task("id1", [1, 2, 3], max_tokens=10)
|
||||||
|
task.output_ids = [1, 2]
|
||||||
|
assert not task.is_finished([0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_add_task():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tid = tm.add_task("hello")
|
||||||
|
assert tid.startswith("task_")
|
||||||
|
assert tm._total_tasks == 1
|
||||||
|
assert len(tm.waiting_queue) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_add_task_too_long_immediate_stop():
|
||||||
|
t = _make_mock_tokenizer()
|
||||||
|
t.encode.return_value = list(range(9000))
|
||||||
|
cb_calls = []
|
||||||
|
|
||||||
|
tm = TaskManager(tokenizer=t, max_seq_len=16)
|
||||||
|
tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok))
|
||||||
|
assert cb_calls[0] is STOP
|
||||||
|
assert len(tm.waiting_queue) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_remove_task():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tid = tm.add_task("test")
|
||||||
|
tm.remove_task(tid)
|
||||||
|
assert len(tm.waiting_queue) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_remove_active_task():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tid = tm.add_task("test")
|
||||||
|
tasks = tm.pull_candidates(1)
|
||||||
|
tm.activate(tasks[0])
|
||||||
|
assert len(tm.active_tasks) == 1
|
||||||
|
removed = tm.remove_task(tid)
|
||||||
|
assert len(removed) == 1
|
||||||
|
assert len(tm.active_tasks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_pull_candidates_fifo():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("a")
|
||||||
|
tm.add_task("b")
|
||||||
|
tm.add_task("c")
|
||||||
|
pulled = tm.pull_candidates(2)
|
||||||
|
assert len(pulled) == 2
|
||||||
|
assert pulled[0].prompt_ids == [1, 2, 3, 4, 5]
|
||||||
|
assert len(tm.waiting_queue) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_activate():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("test")
|
||||||
|
task = tm.pull_candidates(1)[0]
|
||||||
|
tm.activate(task)
|
||||||
|
assert task.status == TaskStatus.RUNNING
|
||||||
|
assert task in tm.active_tasks
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_return_to_waiting():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("a")
|
||||||
|
tm.add_task("b")
|
||||||
|
t1 = tm.pull_candidates(1)[0]
|
||||||
|
tm.return_to_waiting([t1])
|
||||||
|
assert len(tm.waiting_queue) == 2
|
||||||
|
assert tm.waiting_queue[0] == t1
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_remove_finished_aborted():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("test")
|
||||||
|
task = tm.pull_candidates(1)[0]
|
||||||
|
tm.activate(task)
|
||||||
|
task.status = TaskStatus.ABORTED
|
||||||
|
finished = tm.remove_finished_tasks([0])
|
||||||
|
assert len(finished) == 1
|
||||||
|
assert len(tm.active_tasks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_remove_finished_stop_id():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("test")
|
||||||
|
task = tm.pull_candidates(1)[0]
|
||||||
|
tm.activate(task)
|
||||||
|
task.output_ids = [0]
|
||||||
|
task.output_tokens = 1
|
||||||
|
finished = tm.remove_finished_tasks([0])
|
||||||
|
assert len(finished) == 1
|
||||||
|
assert task.status == TaskStatus.FINISHED
|
||||||
|
assert len(tm.active_tasks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_has_work():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
assert not tm.has_work()
|
||||||
|
tm.add_task("test")
|
||||||
|
assert tm.has_work()
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_wake():
|
||||||
|
import threading
|
||||||
|
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
called = threading.Event()
|
||||||
|
|
||||||
|
def waiter():
|
||||||
|
tm.wait_for_tasks(timeout=5.0)
|
||||||
|
called.set()
|
||||||
|
|
||||||
|
t = threading.Thread(target=waiter)
|
||||||
|
t.start()
|
||||||
|
import time
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
tm.wake()
|
||||||
|
t.join(timeout=2.0)
|
||||||
|
assert called.is_set()
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_manager_get_stats():
|
||||||
|
tm = TaskManager(tokenizer=_make_mock_tokenizer())
|
||||||
|
tm.add_task("test")
|
||||||
|
stats = tm.get_stats()
|
||||||
|
assert stats["total_tasks"] == 1
|
||||||
|
assert stats["waiting_queue"] == 1
|
||||||
|
assert stats["active_tasks"] == 0
|
||||||
Loading…
Reference in New Issue