From 7d4029c2a4b85483ec1d0af1e6f190b77ca98c9d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Tue, 12 May 2026 12:17:57 +0800 Subject: [PATCH] =?UTF-8?q?test:=20inference=20=E6=A8=A1=E5=9D=97=E8=A1=A5?= =?UTF-8?q?=E5=85=A8=E5=8D=95=E5=85=83=E6=B5=8B=E8=AF=95=EF=BC=8Ccache/sam?= =?UTF-8?q?ple/engine/task?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - test_cache: page_hash, PagePool, PrefixCache, TaskTable, PagedCache write/gather - test_sample: TemperatureStrategy, TopKStrategy, TopPStrategy, SamplingPipeline, sample() - test_engine: _Result 线程安全, generate stream/non-stream batch/single - test_task: Task 生命周期, TaskManager 队列操作 - 4 新文件, +771 行, 116 total tests --- tests/inference/test_cache.py | 293 +++++++++++++++++++++++++++++++++ tests/inference/test_engine.py | 181 ++++++++++++++++++++ tests/inference/test_sample.py | 127 ++++++++++++++ tests/inference/test_task.py | 170 +++++++++++++++++++ 4 files changed, 771 insertions(+) create mode 100644 tests/inference/test_cache.py create mode 100644 tests/inference/test_engine.py create mode 100644 tests/inference/test_sample.py create mode 100644 tests/inference/test_task.py diff --git a/tests/inference/test_cache.py b/tests/inference/test_cache.py new file mode 100644 index 0000000..f9c80d3 --- /dev/null +++ b/tests/inference/test_cache.py @@ -0,0 +1,293 @@ +"""Unit tests for inference cache components.""" + +import torch + +from astrai.inference.cache import ( + PagedCache, + PagePool, + PrefixCache, + TaskTable, + page_hash, +) + + +def test_page_hash_full_page(): + token_ids = list(range(256)) + h = page_hash(token_ids, 0, 64) + assert isinstance(h, int) + assert h >= 0 + + +def test_page_hash_different_page_differs(): + token_ids = list(range(256)) + assert page_hash(token_ids, 0, 64) != page_hash(token_ids, 1, 64) + + +def test_page_pool_alloc_free_cycle(): + pool = PagePool(n_pages=4) + a = pool.alloc() + b = pool.alloc() + assert a != b + pool.free(a) + pool.free(b) + c = pool.alloc() + assert c in (a, b) + + +def test_page_pool_alloc_when_full(): + pool = PagePool(n_pages=2) + pool.alloc() + pool.alloc() + assert pool.alloc() == -1 + + +def test_page_pool_lru_eviction(): + evicted = [] + + def on_evict(idx): + evicted.append(idx) + + pool = PagePool(n_pages=2, on_evict=on_evict) + p0 = pool.alloc() + p1 = pool.alloc() + pool.free(p0, keep_cached=True) + pool.free(p1, keep_cached=True) + pool.alloc() + assert len(evicted) == 1 + assert evicted[0] == p0 + + +def test_page_pool_inc_ref_and_free(): + pool = PagePool(n_pages=2) + p = pool.alloc() + pool.inc_ref(p) + assert pool._refs[p] == 2 + pool.free(p) + assert pool._refs[p] == 1 + pool.free(p) + assert pool._refs[p] == 0 + + +def test_page_pool_touch_moves_to_end(): + pool = PagePool(n_pages=4) + p0 = pool.alloc() + p1 = pool.alloc() + p2 = pool.alloc() + pool.free(p0, keep_cached=True) + pool.free(p1, keep_cached=True) + pool.free(p2, keep_cached=True) + assert next(iter(pool._lru)) == p0 + pool.touch(p0) + assert next(reversed(pool._lru)) == p0 + + +def test_page_pool_remove_from_lru(): + pool = PagePool(n_pages=4) + p0 = pool.alloc() + pool.free(p0, keep_cached=True) + assert p0 in pool._lru + pool.remove_from_lru(p0) + assert p0 not in pool._lru + + +def test_page_pool_keep_cached_realloc(): + """Free mask has priority over LRU; cached page returned only when no free pages.""" + pool = PagePool(n_pages=3) + p0 = pool.alloc() + p1 = pool.alloc() + p2 = pool.alloc() + pool.free(p0, keep_cached=True) + pool.free(p1, keep_cached=True) + pool.free(p2, keep_cached=True) + assert pool.alloc() == p0 + + +def _record_then_cache(pool, prefix, page, token_ids, logical_idx): + """Simulate the real lifecycle: record → ref stays >0, then free cached returns to LRU.""" + prefix.record(page, token_ids, logical_idx, pool) + pool.free(page, keep_cached=True) + + +def test_prefix_cache_lookup_returns_hits(): + token_ids = list(range(256)) + pool = PagePool(n_pages=16) + prefix = PrefixCache(page_size=64) + pages = [pool.alloc() for _ in range(4)] + for i, p in enumerate(pages): + _record_then_cache(pool, prefix, p, token_ids, i) + hits = prefix.lookup(token_ids, pool) + assert hits == pages + + +def test_prefix_cache_lookup_stops_at_first_miss(): + token_ids = list(range(256)) + pool = PagePool(n_pages=16) + prefix = PrefixCache(page_size=64) + p0 = pool.alloc() + _record_then_cache(pool, prefix, p0, token_ids, 0) + p1 = pool.alloc() + _record_then_cache(pool, prefix, p1, [99] * 64, 1) + hits = prefix.lookup(token_ids, pool) + assert len(hits) == 1 + assert hits[0] == p0 + + +def test_prefix_cache_ignores_partial_last_page(): + token_ids = list(range(100)) + pool = PagePool(n_pages=16) + prefix = PrefixCache(page_size=64) + p = pool.alloc() + _record_then_cache(pool, prefix, p, token_ids, 0) + hits = prefix.lookup(token_ids, pool) + assert len(hits) == 1 + + +def test_prefix_cache_on_evict_clears_mappings(): + pool = PagePool(n_pages=4) + prefix = PrefixCache(page_size=64) + p = pool.alloc() + _record_then_cache(pool, prefix, p, list(range(64)), 0) + assert prefix.has_page(p) + prefix.on_evict(p) + assert not prefix.has_page(p) + + +def test_prefix_cache_has_page(): + pool = PagePool(n_pages=4) + prefix = PrefixCache(page_size=64) + p = pool.alloc() + assert not prefix.has_page(p) + _record_then_cache(pool, prefix, p, list(range(64)), 0) + assert prefix.has_page(p) + + +def test_task_table_set_get(): + pool = PagePool(n_pages=8) + table = TaskTable(pool, page_size=64) + table.set("task1", [0, 1, 2], 128) + assert table.get("task1") == [0, 1, 2] + assert table.get_cached("task1") == 128 + + +def test_task_table_get_missing(): + pool = PagePool(n_pages=8) + table = TaskTable(pool, page_size=64) + assert table.get("nonexistent") == [] + assert table.get_cached("nonexistent") == 0 + + +def test_task_table_pop(): + pool = PagePool(n_pages=8) + table = TaskTable(pool, page_size=64) + table.set("task1", [0, 1], 64) + pages, cached = table.pop("task1") + assert pages == [0, 1] + assert cached == 64 + assert table.get("task1") == [] + + +def test_task_table_extend_allocates_pages(): + pool = PagePool(n_pages=8) + table = TaskTable(pool, page_size=64) + table.set("task1", [], 0) + ok = table.extend("task1", 200) + assert ok + assert len(table.get("task1")) == 4 + + +def test_task_table_extend_fails_when_pool_full(): + pool = PagePool(n_pages=2) + table = TaskTable(pool, page_size=64) + table.set("task1", [pool.alloc(), pool.alloc()], 0) + ok = table.extend("task1", 300) + assert not ok + + +def test_task_table_table_tensor(): + pool = PagePool(n_pages=16) + table = TaskTable(pool, page_size=64) + table.set("a", [0, 1], 0) + table.set("b", [2, 3, 4], 0) + t = table.table_tensor(["a", "b"], torch.device("cpu")) + assert t.shape == (2, 3) + assert t[0].tolist() == [0, 1, -1] + assert t[1].tolist() == [2, 3, 4] + + +def test_task_table_table_tensor_empty_input(): + pool = PagePool(n_pages=4) + table = TaskTable(pool, page_size=64) + t = table.table_tensor([], torch.device("cpu")) + assert t.numel() == 0 + + +def test_paged_cache_write_gather_single_page(): + cache = PagedCache( + n_layers=2, + n_pages=8, + page_size=4, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + page_table = torch.tensor([[0]], dtype=torch.long) + k = torch.randn(1, 2, 2, 8) + v = torch.randn(1, 2, 2, 8) + + cache.write(0, page_table, 0, k, v) + gk, gv = cache.gather(0, page_table, 2) + assert torch.allclose(gk, k) + + +def test_paged_cache_write_cross_page(): + cache = PagedCache( + n_layers=1, + n_pages=8, + page_size=4, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + page_table = torch.tensor([[0, 1]], dtype=torch.long) + k = torch.randn(1, 8, 2, 8) + v = torch.randn(1, 8, 2, 8) + + cache.write(0, page_table, 0, k, v) + gk, gv = cache.gather(0, page_table, 8) + assert torch.allclose(gk, k) + + +def test_paged_cache_gather_truncates_to_total_len(): + cache = PagedCache( + n_layers=1, + n_pages=8, + page_size=4, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + page_table = torch.tensor([[0, 1]], dtype=torch.long) + k = torch.randn(1, 6, 2, 8) + v = torch.randn(1, 6, 2, 8) + cache.write(0, page_table, 0, k, v) + + gk, gv = cache.gather(0, page_table, 5) + assert gk.shape == (1, 5, 2, 8) + + +def test_paged_cache_gather_clamps_negative_padding(): + cache = PagedCache( + n_layers=1, + n_pages=8, + page_size=4, + n_kv_heads=2, + head_dim=8, + device=torch.device("cpu"), + dtype=torch.float32, + ) + page_table = torch.tensor([[0, -1]], dtype=torch.long) + gk, gv = cache.gather(0, page_table, 4) + assert gk.shape == (1, 4, 2, 8) diff --git a/tests/inference/test_engine.py b/tests/inference/test_engine.py new file mode 100644 index 0000000..11180ba --- /dev/null +++ b/tests/inference/test_engine.py @@ -0,0 +1,181 @@ +"""Unit tests for _Result accumulator and InferenceEngine.generate().""" + +import threading +from unittest.mock import MagicMock, patch + +from astrai.inference.engine import _Result +from astrai.inference.task import STOP + + +def test_result_append_single(): + r = _Result(count=1) + r.append("hello", 0) + assert r.results[0] == "hello" + + +def test_result_append_multiple_tasks(): + r = _Result(count=3) + r.append("a", 0) + r.append("b", 1) + r.append("c", 2) + assert r.results[0] == "a" + assert r.results[1] == "b" + assert r.results[2] == "c" + + +def test_result_stop_marks_complete(): + r = _Result(count=2) + r.append("text", 0) + r.append(STOP, 0) + r.append("more", 1) + assert r._done[0] is True + assert r._done[1] is False + assert r._completed == 1 + + +def test_result_stop_does_not_double_count(): + r = _Result(count=1) + r.append(STOP, 0) + r.append(STOP, 0) + assert r._completed == 1 + + +def test_result_pop_all_returns_and_clears(): + r = _Result(count=2) + r.append("a", 0) + r.append("b", 1) + out = r.pop_all() + assert len(out) == 2 + assert out[0] == (0, "a") + assert out[1] == (1, "b") + assert r.pop_all() == [] + + +def test_result_wait_blocks_until_data(): + r = _Result(count=1) + + def delayed_append(): + import time + + time.sleep(0.05) + r.append("delayed", 0) + + t = threading.Thread(target=delayed_append) + t.start() + ok = r.wait(timeout=5.0) + t.join() + assert ok + assert r.results[0] == "delayed" + + +def test_result_wait_timeout(): + r = _Result(count=1) + ok = r.wait(timeout=0.01) + assert not ok + + +def test_result_wait_completion_non_streaming(): + r = _Result(count=2) + + def finish_later(): + import time + + time.sleep(0.05) + r.append(STOP, 0) + time.sleep(0.05) + r.append(STOP, 1) + + t = threading.Thread(target=finish_later) + t.start() + r.wait_completion() + t.join() + assert r._completed == 2 + + +def test_result_get_results(): + r = _Result(count=2) + r.append("hello", 0) + r.append("world", 1) + results = r.get_results() + assert results == ["hello", "world"] + + +def test_engine_generate_non_streaming_single(): + from astrai.inference.engine import InferenceEngine + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1, 2, 3] + mock_tokenizer.decode.return_value = "response" + mock_tokenizer.stop_ids = [0] + + with patch("astrai.inference.engine.InferenceScheduler") as MockSched: + instance = MockSched.return_value + + def fake_add(prompt, **kw): + cb = kw["stream_callback"] + cb("response") + cb(STOP) + + instance.add_task.side_effect = fake_add + instance.remove_task.return_value = [] + + eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1) + result = eng.generate("hello") + assert result == "response" + + +def test_engine_generate_streaming_yields_tokens(): + from astrai.inference.engine import InferenceEngine + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1, 2, 3] + mock_tokenizer.decode.return_value = "tok" + mock_tokenizer.stop_ids = [0] + + callbacks_saved = [] + + def capture_cb(prompt, **kw): + callbacks_saved.append(kw.get("stream_callback")) + + with patch("astrai.inference.engine.InferenceScheduler") as MockSched: + instance = MockSched.return_value + instance.add_task.side_effect = capture_cb + instance.remove_task.return_value = [] + + eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1) + gen = eng.generate("hello", stream=True) + + cb = callbacks_saved[0] + cb("t1", 0) + cb("t2", 0) + cb(STOP, 0) + + tokens = list(gen) + assert tokens == ["t1", "t2"] + + +def test_engine_generate_non_streaming_batch(): + from astrai.inference.engine import InferenceEngine + + mock_model = MagicMock() + mock_tokenizer = MagicMock() + mock_tokenizer.encode.return_value = [1, 2, 3] + mock_tokenizer.decode.return_value = "r" + mock_tokenizer.stop_ids = [0] + + with patch("astrai.inference.engine.InferenceScheduler") as MockSched: + instance = MockSched.return_value + + def fake_add(prompt, **kw): + cb = kw["stream_callback"] + cb("r") + cb(STOP) + + instance.add_task.side_effect = fake_add + instance.remove_task.return_value = [] + + eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2) + results = eng.generate(["hello", "world"]) + assert results == ["r", "r"] diff --git a/tests/inference/test_sample.py b/tests/inference/test_sample.py new file mode 100644 index 0000000..b5b9022 --- /dev/null +++ b/tests/inference/test_sample.py @@ -0,0 +1,127 @@ +"""Unit tests for inference sampling strategies.""" + +import torch + +from astrai.inference.sample import ( + SamplingPipeline, + TemperatureStrategy, + TopKStrategy, + TopPStrategy, + sample, +) + + +def test_temperature_scalar(): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + s = TemperatureStrategy(0.5) + result = s.apply(logits.clone()) + assert torch.allclose(result, logits / 0.5) + + +def test_temperature_skip_when_one(): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + s = TemperatureStrategy(1.0) + result = s.apply(logits.clone()) + assert torch.equal(result, logits) + + +def test_temperature_per_sample_tensor(): + logits = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + s = TemperatureStrategy(torch.tensor([0.5, 0.5])) + result = s.apply(logits.clone()) + assert torch.allclose(result, logits / 0.5) + + +def test_top_k_keeps_top(): + logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]]) + s = TopKStrategy(top_k=2) + result = s.apply(logits.clone(), filter_value=-1e9) + kept = (result > -1e9).sum().item() + assert kept == 2 + + +def test_top_k_skip_when_zero(): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + s = TopKStrategy(top_k=0) + result = s.apply(logits.clone()) + assert torch.equal(result, logits) + + +def test_top_k_batch_tensor(): + """When top_k is a batch tensor, max element governs k for all rows.""" + logits = torch.tensor([[0.1, 0.5, 0.3], [0.9, 0.2, 0.1]]) + s = TopKStrategy(top_k=torch.tensor([2, 1])) + result = s.apply(logits.clone(), filter_value=-1e9) + assert (result[0] > -1e9).sum() == 2 + assert (result[1] > -1e9).sum() == 2 + + +def test_top_p_nucleus_filtering(): + logits = torch.tensor([[10.0, 1.0, 1.0, 1.0, 1.0]]) + s = TopPStrategy(top_p=0.5) + result = s.apply(logits.clone(), filter_value=-1e9) + kept = (result > -1e9).sum().item() + assert kept >= 1 + + +def test_top_p_skip_when_one(): + logits = torch.tensor([[1.0, 2.0, 3.0]]) + s = TopPStrategy(top_p=1.0) + result = s.apply(logits.clone()) + assert torch.equal(result, logits) + + +def test_top_p_filter_all_except_max_when_zero(): + logits = torch.tensor([[0.1, 0.5, 0.3, 0.9, 0.2]]) + s = TopPStrategy(top_p=0.0) + result = s.apply(logits.clone(), filter_value=-1e9) + kept = (result > -1e9).sum().item() + assert kept == 1 + + +def test_sampling_pipeline_composes_strategies(): + logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + pipeline = SamplingPipeline( + [ + TemperatureStrategy(0.8), + TopKStrategy(3), + TopPStrategy(0.95), + ] + ) + result = pipeline.apply(logits.clone(), filter_value=-1e9) + kept = (result > -1e9).sum().item() + assert 1 <= kept <= 3 + + +def test_sampling_pipeline_sample_returns_valid_token(): + logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + pipeline = SamplingPipeline( + [ + TemperatureStrategy(0.8), + TopKStrategy(3), + TopPStrategy(0.95), + ] + ) + tokens = pipeline.sample(logits) + assert tokens.shape == (1,) + assert 0 <= tokens[0] < logits.size(-1) + + +def test_module_sample_shortcut(): + logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95) + assert tokens.shape == (1,) + assert 0 <= tokens[0] < logits.size(-1) + + +def test_module_sample_batch(): + logits = torch.tensor( + [ + [1.0, 2.0, 3.0, 4.0, 5.0], + [5.0, 4.0, 3.0, 2.0, 1.0], + ] + ) + tokens = sample(logits, temperature=0.8, top_k=3, top_p=0.95) + assert tokens.shape == (2,) + for t in tokens: + assert 0 <= t < logits.size(-1) diff --git a/tests/inference/test_task.py b/tests/inference/test_task.py new file mode 100644 index 0000000..916400f --- /dev/null +++ b/tests/inference/test_task.py @@ -0,0 +1,170 @@ +"""Unit tests for Task and TaskManager.""" + +from unittest.mock import MagicMock + +from astrai.inference.task import STOP, Task, TaskManager, TaskStatus + + +def _make_mock_tokenizer(): + t = MagicMock() + t.encode.return_value = [1, 2, 3, 4, 5] + t.stop_ids = [0] + return t + + +def test_task_default_status_is_pending(): + task = Task("id1", [1, 2, 3]) + assert task.status == TaskStatus.PENDING + + +def test_task_next_pos(): + task = Task("id1", [1, 2, 3]) + task.input_tokens = 5 + assert task.next_pos == 5 + task.output_ids.append(4) + assert task.next_pos == 6 + + +def test_task_is_finished_max_tokens(): + task = Task("id1", [1, 2, 3], max_tokens=2) + task.output_tokens = 2 + assert task.is_finished([]) + + +def test_task_is_finished_stop_id(): + task = Task("id1", [1, 2, 3]) + task.output_ids = [5, 0] + assert task.is_finished([0]) + + +def test_task_is_finished_not_yet(): + task = Task("id1", [1, 2, 3], max_tokens=10) + task.output_ids = [1, 2] + assert not task.is_finished([0]) + + +def test_task_manager_add_task(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tid = tm.add_task("hello") + assert tid.startswith("task_") + assert tm._total_tasks == 1 + assert len(tm.waiting_queue) == 1 + + +def test_task_manager_add_task_too_long_immediate_stop(): + t = _make_mock_tokenizer() + t.encode.return_value = list(range(9000)) + cb_calls = [] + + tm = TaskManager(tokenizer=t, max_seq_len=16) + tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok)) + assert cb_calls[0] is STOP + assert len(tm.waiting_queue) == 0 + + +def test_task_manager_remove_task(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tid = tm.add_task("test") + tm.remove_task(tid) + assert len(tm.waiting_queue) == 0 + + +def test_task_manager_remove_active_task(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tid = tm.add_task("test") + tasks = tm.pull_candidates(1) + tm.activate(tasks[0]) + assert len(tm.active_tasks) == 1 + removed = tm.remove_task(tid) + assert len(removed) == 1 + assert len(tm.active_tasks) == 0 + + +def test_task_manager_pull_candidates_fifo(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("a") + tm.add_task("b") + tm.add_task("c") + pulled = tm.pull_candidates(2) + assert len(pulled) == 2 + assert pulled[0].prompt_ids == [1, 2, 3, 4, 5] + assert len(tm.waiting_queue) == 1 + + +def test_task_manager_activate(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("test") + task = tm.pull_candidates(1)[0] + tm.activate(task) + assert task.status == TaskStatus.RUNNING + assert task in tm.active_tasks + + +def test_task_manager_return_to_waiting(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("a") + tm.add_task("b") + t1 = tm.pull_candidates(1)[0] + tm.return_to_waiting([t1]) + assert len(tm.waiting_queue) == 2 + assert tm.waiting_queue[0] == t1 + + +def test_task_manager_remove_finished_aborted(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("test") + task = tm.pull_candidates(1)[0] + tm.activate(task) + task.status = TaskStatus.ABORTED + finished = tm.remove_finished_tasks([0]) + assert len(finished) == 1 + assert len(tm.active_tasks) == 0 + + +def test_task_manager_remove_finished_stop_id(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("test") + task = tm.pull_candidates(1)[0] + tm.activate(task) + task.output_ids = [0] + task.output_tokens = 1 + finished = tm.remove_finished_tasks([0]) + assert len(finished) == 1 + assert task.status == TaskStatus.FINISHED + assert len(tm.active_tasks) == 0 + + +def test_task_manager_has_work(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + assert not tm.has_work() + tm.add_task("test") + assert tm.has_work() + + +def test_task_manager_wake(): + import threading + + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + called = threading.Event() + + def waiter(): + tm.wait_for_tasks(timeout=5.0) + called.set() + + t = threading.Thread(target=waiter) + t.start() + import time + + time.sleep(0.05) + tm.wake() + t.join(timeout=2.0) + assert called.is_set() + + +def test_task_manager_get_stats(): + tm = TaskManager(tokenizer=_make_mock_tokenizer()) + tm.add_task("test") + stats = tm.get_stats() + assert stats["total_tasks"] == 1 + assert stats["waiting_queue"] == 1 + assert stats["active_tasks"] == 0