171 lines
4.4 KiB
Python
171 lines
4.4 KiB
Python
"""Unit tests for Task and TaskManager."""
|
|
|
|
from unittest.mock import MagicMock
|
|
|
|
from astrai.inference import STOP, Task, TaskManager, TaskStatus
|
|
|
|
|
|
def _make_mock_tokenizer():
|
|
t = MagicMock()
|
|
t.encode.return_value = [1, 2, 3, 4, 5]
|
|
t.stop_ids = [0]
|
|
return t
|
|
|
|
|
|
def test_task_default_status_is_pending():
    """A Task built from just an id and prompt ids starts in PENDING."""
    fresh_task = Task("id1", [1, 2, 3])
    expected_status = TaskStatus.PENDING
    assert fresh_task.status == expected_status
|
|
|
|
|
|
def test_task_next_pos():
    """next_pos starts at input_tokens and advances with each appended output id."""
    task = Task("id1", [1, 2, 3])
    prompt_len = 5
    task.input_tokens = prompt_len
    assert task.next_pos == prompt_len

    # One generated token should move the position forward by exactly one.
    task.output_ids.append(4)
    assert task.next_pos == prompt_len + 1
|
|
|
|
|
|
def test_task_is_finished_max_tokens():
    """Hitting max_tokens finishes a task even when no stop ids are given."""
    limit = 2
    task = Task("id1", [1, 2, 3], max_tokens=limit)
    task.output_tokens = limit
    no_stop_ids = []
    assert task.is_finished(no_stop_ids)
|
|
|
|
|
|
def test_task_is_finished_stop_id():
    """Emitting a stop id (here as the final output id) finishes the task."""
    stop_ids = [0]
    task = Task("id1", [1, 2, 3])
    task.output_ids = [5, *stop_ids]
    assert task.is_finished(stop_ids)
|
|
|
|
|
|
def test_task_is_finished_not_yet():
    """A task below max_tokens with no stop id in its output is not finished."""
    task = Task("id1", [1, 2, 3], max_tokens=10)
    generated = [1, 2]
    task.output_ids = generated
    # Keep the token counter in sync with output_ids (as the other tests do),
    # so the max_tokens check is exercised with a realistic non-zero count
    # rather than the constructor default.
    task.output_tokens = len(generated)
    assert not task.is_finished([0])
|
|
|
|
|
|
def test_task_manager_add_task():
    """add_task returns a "task_"-prefixed id and enqueues exactly one task."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    task_id = manager.add_task("hello")

    assert task_id.startswith("task_")
    assert manager._total_tasks == 1
    assert len(manager.waiting_queue) == 1
|
|
|
|
|
|
def test_task_manager_add_task_too_long_immediate_stop():
    """A prompt longer than max_seq_len gets an immediate STOP and is not queued."""
    tokenizer = _make_mock_tokenizer()
    # Far more prompt tokens than max_seq_len=16 allows.
    tokenizer.encode.return_value = list(range(9000))
    cb_calls = []

    tm = TaskManager(tokenizer=tokenizer, max_seq_len=16)
    tm.add_task("long", stream_callback=lambda tok: cb_calls.append(tok))

    # Check non-emptiness first so a missing callback fails with a clear
    # assertion message instead of an IndexError on cb_calls[0].
    assert cb_calls, "stream_callback was never invoked for an over-long prompt"
    assert cb_calls[0] is STOP
    # The rejected task must not be enqueued.
    assert len(tm.waiting_queue) == 0
|
|
|
|
|
|
def test_task_manager_remove_task():
    """Removing a task that is still waiting empties the waiting queue."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    manager.remove_task(manager.add_task("test"))
    assert len(manager.waiting_queue) == 0
|
|
|
|
|
|
def test_task_manager_remove_active_task():
    """remove_task also evicts a task that has already been activated."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    task_id = manager.add_task("test")

    first = manager.pull_candidates(1)[0]
    manager.activate(first)
    assert len(manager.active_tasks) == 1

    removed = manager.remove_task(task_id)
    assert len(removed) == 1
    assert len(manager.active_tasks) == 0
|
|
|
|
|
|
def test_task_manager_pull_candidates_fifo():
    """pull_candidates hands out waiting tasks in the order they were added."""
    import itertools

    tokenizer = _make_mock_tokenizer()
    # The stub's fixed return_value gives every task identical prompt_ids, so
    # the order of pulled tasks could not be observed. Instead, give each
    # encode() call a distinct, increasing single-token prompt.
    counter = itertools.count(1)
    tokenizer.encode.side_effect = lambda *args, **kwargs: [next(counter)]

    tm = TaskManager(tokenizer=tokenizer)
    for prompt in ("a", "b", "c"):
        tm.add_task(prompt)

    pulled = tm.pull_candidates(2)
    assert len(pulled) == 2
    assert len(tm.waiting_queue) == 1

    # Prompts are numbered in insertion order, so FIFO means strictly
    # increasing ids across pulled[0], pulled[1] and the task left waiting.
    first, second = (task.prompt_ids[0] for task in pulled)
    remaining = tm.waiting_queue[0].prompt_ids[0]
    assert first < second < remaining
|
|
|
|
|
|
def test_task_manager_activate():
    """activate() marks a pulled task RUNNING and registers it as active."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    manager.add_task("test")
    candidate = manager.pull_candidates(1)[0]

    manager.activate(candidate)

    assert candidate.status == TaskStatus.RUNNING
    assert candidate in manager.active_tasks
|
|
|
|
|
|
def test_task_manager_return_to_waiting():
    """return_to_waiting puts a pulled task back at the front of the queue."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    for prompt in ("a", "b"):
        manager.add_task(prompt)

    pulled = manager.pull_candidates(1)[0]
    manager.return_to_waiting([pulled])

    assert len(manager.waiting_queue) == 2
    assert manager.waiting_queue[0] == pulled
|
|
|
|
|
|
def test_task_manager_remove_finished_aborted():
    """An active task whose status is set to ABORTED is reaped as finished."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    manager.add_task("test")
    running = manager.pull_candidates(1)[0]
    manager.activate(running)

    running.status = TaskStatus.ABORTED
    reaped = manager.remove_finished_tasks([0])

    assert len(reaped) == 1
    assert len(manager.active_tasks) == 0
|
|
|
|
|
|
def test_task_manager_remove_finished_stop_id():
    """An active task that emitted a stop id is reaped and marked FINISHED."""
    stop_id = 0
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    manager.add_task("test")
    running = manager.pull_candidates(1)[0]
    manager.activate(running)

    running.output_ids = [stop_id]
    running.output_tokens = 1
    reaped = manager.remove_finished_tasks([stop_id])

    assert len(reaped) == 1
    assert running.status == TaskStatus.FINISHED
    assert len(manager.active_tasks) == 0
|
|
|
|
|
|
def test_task_manager_has_work():
    """has_work() is falsy for an empty manager and truthy once a task is added."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    idle = manager.has_work()
    assert not idle

    manager.add_task("test")
    busy = manager.has_work()
    assert busy
|
|
|
|
|
|
def test_task_manager_wake():
    """wake() unblocks a thread parked in wait_for_tasks()."""
    import threading
    import time

    tm = TaskManager(tokenizer=_make_mock_tokenizer())
    called = threading.Event()

    def waiter():
        # The 5 s wait timeout is well beyond the ~2 s budget below, so the
        # waiter returning in time means wake() worked, not that it timed out.
        tm.wait_for_tasks(timeout=5.0)
        called.set()

    # daemon=True so a broken wake() cannot keep the interpreter alive for
    # the full wait timeout after the assertion fails.
    t = threading.Thread(target=waiter, daemon=True)
    t.start()

    # Wake repeatedly until the waiter returns instead of relying on a fixed
    # sleep: a single wake() issued before the thread actually blocks may be
    # missed (depending on how wake() is implemented), making the test flaky.
    deadline = time.monotonic() + 2.0
    while not called.is_set() and time.monotonic() < deadline:
        tm.wake()
        called.wait(0.05)

    t.join(timeout=1.0)
    assert called.is_set()
|
|
|
|
|
|
def test_task_manager_get_stats():
    """get_stats() reports total, waiting and active task counts."""
    manager = TaskManager(tokenizer=_make_mock_tokenizer())
    manager.add_task("test")
    stats = manager.get_stats()

    expected = {"total_tasks": 1, "waiting_queue": 1, "active_tasks": 0}
    for key, value in expected.items():
        assert stats[key] == value, key
|