# Source: AstrAI/tests/inference/test_scheduler.py (194 lines, 5.9 KiB, Python)

"""Tests for scheduler concurrency."""
import threading
from unittest.mock import MagicMock, patch
import pytest
import torch
from astrai.inference import InferenceScheduler
@pytest.fixture
def mock_model_and_tokenizer():
    """Build a mocked model/tokenizer pair for the scheduler tests.

    The model mock carries a ``config`` with small fixed sizes and a
    ``parameters()`` method yielding a single float32 CPU parameter mock.
    The tokenizer mock returns fixed values from ``encode``/``decode`` and
    exposes ``stop_ids`` and ``pad_id`` attributes.

    Returns:
        tuple: ``(mock_model, mock_tokenizer)``.
    """
    mock_model = MagicMock()
    mock_model.config = MagicMock()
    mock_model.config.n_kv_heads = 8
    mock_model.config.n_heads = 8
    mock_model.config.dim = 128
    mock_model.config.n_layers = 2
    mock_model.config.max_len = 100

    param = MagicMock(dtype=torch.float32, device=torch.device("cpu"))
    # Use side_effect so every parameters() call gets a fresh iterator; a
    # single iter(...) stored as return_value would be exhausted after the
    # first consumer and later callers would see no parameters at all.
    mock_model.parameters.side_effect = lambda *args, **kwargs: iter([param])

    mock_tokenizer = MagicMock()
    mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
    mock_tokenizer.decode.return_value = "token"
    mock_tokenizer.stop_ids = [0]
    mock_tokenizer.pad_id = None

    return mock_model, mock_tokenizer
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
    """Queue tasks from five threads at once and expect all fifty to succeed."""
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    with patch("astrai.inference.core.scheduler.AutoModel"), patch(
        "astrai.inference.core.scheduler.AutoTokenizer"
    ):
        scheduler = InferenceScheduler(
            model=mock_model,
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            device="cpu",
        )

    results = {"task_ids": [], "errors": []}
    results_lock = threading.Lock()

    def add_task_worker(worker_id):
        # Each worker submits ten prompts and records the ids it gets back.
        try:
            for i in range(10):
                prompt = f"prompt from worker {worker_id}-{i}"
                new_id = scheduler.add_task(prompt)
                with results_lock:
                    results["task_ids"].append(new_id)
        except Exception as exc:
            results["errors"].append(str(exc))

    # Build every thread before starting any of them.
    workers = []
    for worker_id in range(5):
        workers.append(
            threading.Thread(target=add_task_worker, args=(worker_id,))
        )
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    scheduler.stop()

    error_count = len(results["errors"])
    assert error_count == 0, f"Errors: {results['errors']}"
    task_count = len(results["task_ids"])
    assert task_count == 50
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
    """Add twenty tasks in one thread while another removes the first ten.

    The remover waits until at least ten tasks have been added, then calls
    ``remove_task`` on the first ten ids. The test passes when neither
    thread recorded an exception, all twenty adds completed, and all ten
    removals completed.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer
    with patch("astrai.inference.core.scheduler.AutoModel"):
        with patch("astrai.inference.core.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

    results = {"added": [], "removed": [], "errors": []}
    add_ready = threading.Event()

    def add_worker():
        try:
            for i in range(20):
                task_id = scheduler.add_task(f"prompt {i}")
                results["added"].append(task_id)
                # Release the remover once there are ten ids to work on.
                if len(results["added"]) >= 10:
                    add_ready.set()
        except Exception as e:
            results["errors"].append(f"Add: {str(e)}")

    def remove_worker():
        try:
            # Event.wait returns False on timeout; treat that as a failure
            # instead of silently removing fewer than ten tasks.
            if not add_ready.wait(timeout=5.0):
                raise TimeoutError("timed out waiting for 10 tasks to be added")
            for task_id in results["added"][:10]:
                scheduler.remove_task(task_id)
                results["removed"].append(task_id)
        except Exception as e:
            results["errors"].append(f"Remove: {str(e)}")

    add_thread = threading.Thread(target=add_worker)
    remove_thread = threading.Thread(target=remove_worker)
    try:
        add_thread.start()
        remove_thread.start()
        add_thread.join()
        remove_thread.join()
    finally:
        # Always stop the scheduler, even if the body above raised.
        scheduler.stop()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert len(results["added"]) == 20
    # Verify the remove path actually ran for every targeted task.
    assert len(results["removed"]) == 10
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
    """Poll ``get_stats`` fifty times while another thread adds twenty tasks.

    The test passes when neither thread recorded an exception, fifty stats
    snapshots were collected, and each snapshot has a non-negative
    ``total_tasks`` entry.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer
    with patch("astrai.inference.core.scheduler.AutoModel"):
        with patch("astrai.inference.core.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

    results = {"stats": [], "errors": []}
    started = threading.Event()
    stats_done = threading.Event()

    def add_tasks():
        try:
            for i in range(20):
                scheduler.add_task(f"prompt {i}")
                # Signal as soon as the first task is queued so that polling
                # overlaps with the remaining adds.
                started.set()
        except Exception as e:
            results["errors"].append(f"Add: {str(e)}")

    def get_stats():
        try:
            # Event.wait returns False on timeout; surface that as an error
            # rather than polling a scheduler that never received a task.
            if not started.wait(timeout=5.0):
                raise TimeoutError("timed out waiting for the first task")
            for _ in range(50):
                stats = scheduler.get_stats()
                results["stats"].append(stats)
        except Exception as e:
            results["errors"].append(f"Get stats: {str(e)}")
        finally:
            # Set on failure too, so the main thread never waits out the
            # full timeout for a poller that has already given up.
            stats_done.set()

    add_thread = threading.Thread(target=add_tasks)
    stats_thread = threading.Thread(target=get_stats)
    try:
        add_thread.start()
        stats_thread.start()
        add_thread.join()
        stats_done.wait(timeout=5.0)
    finally:
        # Always stop the scheduler and reap the poller, even on failure.
        scheduler.stop()
        stats_thread.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert len(results["stats"]) == 50
    for stats in results["stats"]:
        assert "total_tasks" in stats
        assert stats["total_tasks"] >= 0
def test_prefill_skips_fully_cached_tasks(mock_model_and_tokenizer):
    """Queue one task with a stream callback and check the returned id.

    NOTE(review): the test name refers to skipping prefill for fully cached
    prompts, but the only thing asserted here is that ``add_task`` returns an
    id starting with ``task_``. TODO: assert the prefill behaviour itself.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer
    with patch("astrai.inference.core.scheduler.AutoModel"):
        with patch("astrai.inference.core.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

    try:
        task_id = scheduler.add_task("short prompt", stream_callback=lambda t: None)
    finally:
        # Stop the scheduler even when add_task raises.
        scheduler.stop()

    assert task_id.startswith("task_")