182 lines
4.6 KiB
Python
182 lines
4.6 KiB
Python
"""Unit tests for _Result accumulator and InferenceEngine.generate()."""
|
|
|
|
import threading
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from astrai.inference.engine import _Result
|
|
from astrai.inference.task import STOP
|
|
|
|
|
|
def test_result_append_single():
|
|
r = _Result(count=1)
|
|
r.append("hello", 0)
|
|
assert r.results[0] == "hello"
|
|
|
|
|
|
def test_result_append_multiple_tasks():
|
|
r = _Result(count=3)
|
|
r.append("a", 0)
|
|
r.append("b", 1)
|
|
r.append("c", 2)
|
|
assert r.results[0] == "a"
|
|
assert r.results[1] == "b"
|
|
assert r.results[2] == "c"
|
|
|
|
|
|
def test_result_stop_marks_complete():
|
|
r = _Result(count=2)
|
|
r.append("text", 0)
|
|
r.append(STOP, 0)
|
|
r.append("more", 1)
|
|
assert r._done[0] is True
|
|
assert r._done[1] is False
|
|
assert r._completed == 1
|
|
|
|
|
|
def test_result_stop_does_not_double_count():
|
|
r = _Result(count=1)
|
|
r.append(STOP, 0)
|
|
r.append(STOP, 0)
|
|
assert r._completed == 1
|
|
|
|
|
|
def test_result_pop_all_returns_and_clears():
|
|
r = _Result(count=2)
|
|
r.append("a", 0)
|
|
r.append("b", 1)
|
|
out = r.pop_all()
|
|
assert len(out) == 2
|
|
assert out[0] == (0, "a")
|
|
assert out[1] == (1, "b")
|
|
assert r.pop_all() == []
|
|
|
|
|
|
def test_result_wait_blocks_until_data():
|
|
r = _Result(count=1)
|
|
|
|
def delayed_append():
|
|
import time
|
|
|
|
time.sleep(0.05)
|
|
r.append("delayed", 0)
|
|
|
|
t = threading.Thread(target=delayed_append)
|
|
t.start()
|
|
ok = r.wait(timeout=5.0)
|
|
t.join()
|
|
assert ok
|
|
assert r.results[0] == "delayed"
|
|
|
|
|
|
def test_result_wait_timeout():
|
|
r = _Result(count=1)
|
|
ok = r.wait(timeout=0.01)
|
|
assert not ok
|
|
|
|
|
|
def test_result_wait_completion_non_streaming():
|
|
r = _Result(count=2)
|
|
|
|
def finish_later():
|
|
import time
|
|
|
|
time.sleep(0.05)
|
|
r.append(STOP, 0)
|
|
time.sleep(0.05)
|
|
r.append(STOP, 1)
|
|
|
|
t = threading.Thread(target=finish_later)
|
|
t.start()
|
|
r.wait_completion()
|
|
t.join()
|
|
assert r._completed == 2
|
|
|
|
|
|
def test_result_get_results():
|
|
r = _Result(count=2)
|
|
r.append("hello", 0)
|
|
r.append("world", 1)
|
|
results = r.get_results()
|
|
assert results == ["hello", "world"]
|
|
|
|
|
|
def test_engine_generate_non_streaming_single():
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
mock_model = MagicMock()
|
|
mock_tokenizer = MagicMock()
|
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
|
mock_tokenizer.decode.return_value = "response"
|
|
mock_tokenizer.stop_ids = [0]
|
|
|
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
|
instance = MockSched.return_value
|
|
|
|
def fake_add(prompt, **kw):
|
|
cb = kw["stream_callback"]
|
|
cb("response")
|
|
cb(STOP)
|
|
|
|
instance.add_task.side_effect = fake_add
|
|
instance.remove_task.return_value = []
|
|
|
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
|
result = eng.generate("hello")
|
|
assert result == "response"
|
|
|
|
|
|
def test_engine_generate_streaming_yields_tokens():
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
mock_model = MagicMock()
|
|
mock_tokenizer = MagicMock()
|
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
|
mock_tokenizer.decode.return_value = "tok"
|
|
mock_tokenizer.stop_ids = [0]
|
|
|
|
callbacks_saved = []
|
|
|
|
def capture_cb(prompt, **kw):
|
|
callbacks_saved.append(kw.get("stream_callback"))
|
|
|
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
|
instance = MockSched.return_value
|
|
instance.add_task.side_effect = capture_cb
|
|
instance.remove_task.return_value = []
|
|
|
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=1)
|
|
gen = eng.generate("hello", stream=True)
|
|
|
|
cb = callbacks_saved[0]
|
|
cb("t1", 0)
|
|
cb("t2", 0)
|
|
cb(STOP, 0)
|
|
|
|
tokens = list(gen)
|
|
assert tokens == ["t1", "t2"]
|
|
|
|
|
|
def test_engine_generate_non_streaming_batch():
|
|
from astrai.inference.engine import InferenceEngine
|
|
|
|
mock_model = MagicMock()
|
|
mock_tokenizer = MagicMock()
|
|
mock_tokenizer.encode.return_value = [1, 2, 3]
|
|
mock_tokenizer.decode.return_value = "r"
|
|
mock_tokenizer.stop_ids = [0]
|
|
|
|
with patch("astrai.inference.engine.InferenceScheduler") as MockSched:
|
|
instance = MockSched.return_value
|
|
|
|
def fake_add(prompt, **kw):
|
|
cb = kw["stream_callback"]
|
|
cb("r")
|
|
cb(STOP)
|
|
|
|
instance.add_task.side_effect = fake_add
|
|
instance.remove_task.return_value = []
|
|
|
|
eng = InferenceEngine(mock_model, mock_tokenizer, max_batch_size=2)
|
|
results = eng.generate(["hello", "world"])
|
|
assert results == ["r", "r"]
|