176 lines
5.2 KiB
Python
176 lines
5.2 KiB
Python
"""Tests for scheduler concurrency."""
|
|
|
|
import threading
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from astrai.inference.scheduler import InferenceScheduler
|
|
|
|
|
|
@pytest.fixture
def mock_model_and_tokenizer():
    """Build the ``(mock_model, mock_tokenizer)`` pair used by the scheduler tests.

    The model mock exposes a ``config`` with ``n_kv_heads=8``, ``n_heads=8``,
    ``dim=128``, ``n_layers=2`` and ``max_len=100``, and a ``parameters()``
    method yielding a single float32 CPU parameter mock.

    The tokenizer mock always encodes to ``[1, 2, 3, 4, 5]``, decodes to
    ``"token"``, and reports ``stop_ids == [0]`` and ``pad_id is None``.

    Returns:
        tuple: ``(mock_model, mock_tokenizer)``.
    """
    mock_model = MagicMock()
    mock_model.config = MagicMock()
    mock_model.config.n_kv_heads = 8
    mock_model.config.n_heads = 8
    mock_model.config.dim = 128
    mock_model.config.n_layers = 2
    mock_model.config.max_len = 100

    param = MagicMock(dtype=torch.float32, device=torch.device("cpu"))

    def fresh_parameters(*args, **kwargs):
        # Build a new iterator on every call. Assigning ``iter([...])`` to
        # ``return_value`` would hand every caller the same one-shot iterator,
        # so a second ``next(model.parameters())`` would raise StopIteration.
        return iter([param])

    mock_model.parameters.side_effect = fresh_parameters

    mock_tokenizer = MagicMock()
    mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5]
    mock_tokenizer.decode.return_value = "token"
    mock_tokenizer.stop_ids = [0]
    mock_tokenizer.pad_id = None

    return mock_model, mock_tokenizer
|
|
|
|
|
|
def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
    """Submit tasks from five threads at once and check every add is recorded.

    Each of the five workers calls ``scheduler.add_task`` ten times. The test
    passes when no worker raised and 50 task IDs were collected in total.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # Both patches are active only while the scheduler is being constructed.
    model_patch = patch("astrai.inference.scheduler.AutoModel")
    tokenizer_patch = patch("astrai.inference.scheduler.AutoTokenizer")
    with model_patch, tokenizer_patch:
        scheduler = InferenceScheduler(
            model=mock_model,
            tokenizer=mock_tokenizer,
            max_batch_size=4,
            device="cpu",
        )

    results = {"task_ids": [], "errors": []}
    ids_guard = threading.Lock()

    def submit_batch(worker_id):
        # Any exception ends this worker's batch and is reported as a string.
        try:
            for i in range(10):
                new_id = scheduler.add_task(f"prompt from worker {worker_id}-{i}")
                with ids_guard:
                    results["task_ids"].append(new_id)
        except Exception as exc:
            results["errors"].append(str(exc))

    workers = []
    for worker_index in range(5):
        workers.append(threading.Thread(target=submit_batch, args=(worker_index,)))

    # Start every worker before joining any so the adds actually overlap.
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    scheduler.stop()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert len(results["task_ids"]) == 50
|
|
|
|
|
|
def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
    """Add tasks on one thread while a second thread removes the first ten.

    The adder submits 20 prompts and signals ``add_ready`` once ten task IDs
    have been recorded. The remover waits for that signal (up to 5 s) and then
    calls ``scheduler.remove_task`` on a snapshot of the first ten IDs.

    The test passes when neither thread reported an error, all 20 adds were
    recorded, and exactly the first ten added IDs were removed.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # The patches are only active while the scheduler is constructed.
    with patch("astrai.inference.scheduler.AutoModel"):
        with patch("astrai.inference.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

    results = {"added": [], "removed": [], "errors": []}
    add_ready = threading.Event()

    def add_worker():
        try:
            for i in range(20):
                task_id = scheduler.add_task(f"prompt {i}")
                results["added"].append(task_id)
                if len(results["added"]) >= 10:
                    add_ready.set()
        except Exception as e:
            results["errors"].append(f"Add: {e}")
        finally:
            # Release the remover even if add_task failed before ten adds,
            # so a failure is reported immediately instead of after a 5 s stall.
            add_ready.set()

    def remove_worker():
        try:
            if not add_ready.wait(timeout=5.0):
                # Proceeding here would silently remove fewer than ten tasks.
                results["errors"].append(
                    "Remove: timed out waiting for tasks to be added"
                )
                return
            # The slice is a snapshot; the adder may still be appending.
            for task_id in results["added"][:10]:
                scheduler.remove_task(task_id)
                results["removed"].append(task_id)
        except Exception as e:
            results["errors"].append(f"Remove: {e}")

    add_thread = threading.Thread(target=add_worker)
    remove_thread = threading.Thread(target=remove_worker)

    try:
        add_thread.start()
        remove_thread.start()

        add_thread.join()
        remove_thread.join()
    finally:
        # Stop the scheduler even if starting or joining a thread fails.
        scheduler.stop()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert len(results["added"]) == 20
    # The removals were collected but previously never verified.
    assert results["removed"] == results["added"][:10]
|
|
|
|
|
|
def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
    """Poll ``get_stats`` from one thread while another thread adds tasks.

    The adder submits 20 prompts and signals ``started`` after the first
    successful add, so polling overlaps the remaining adds. The poller waits
    for that signal (up to 5 s), calls ``scheduler.get_stats()`` 50 times, and
    signals ``stats_done`` when it finishes.

    The test passes when neither thread reported an error, 50 stats snapshots
    were collected, and each snapshot has a non-negative ``"total_tasks"``.
    """
    mock_model, mock_tokenizer = mock_model_and_tokenizer

    # The patches are only active while the scheduler is constructed.
    with patch("astrai.inference.scheduler.AutoModel"):
        with patch("astrai.inference.scheduler.AutoTokenizer"):
            scheduler = InferenceScheduler(
                model=mock_model,
                tokenizer=mock_tokenizer,
                max_batch_size=4,
                device="cpu",
            )

    results = {"stats": [], "errors": []}
    started = threading.Event()
    stats_done = threading.Event()

    def add_tasks():
        try:
            for i in range(20):
                scheduler.add_task(f"prompt {i}")
                # Event.set() is idempotent; the first call releases the poller.
                started.set()
        except Exception as e:
            results["errors"].append(f"Add: {e}")
        finally:
            # Release the poller even if the very first add_task failed,
            # so it does not sit out the full 5 s timeout.
            started.set()

    def get_stats():
        try:
            # Best effort: polling proceeds even if this wait times out.
            started.wait(timeout=5.0)
            for _ in range(50):
                stats = scheduler.get_stats()
                results["stats"].append(stats)
        except Exception as e:
            results["errors"].append(f"Get stats: {e}")
        finally:
            # Always signal completion so the main thread is not left waiting
            # 5 s when get_stats raises.
            stats_done.set()

    add_thread = threading.Thread(target=add_tasks)
    stats_thread = threading.Thread(target=get_stats)

    try:
        add_thread.start()
        stats_thread.start()

        add_thread.join()
        # Give polling up to 5 s to finish before the scheduler is stopped.
        stats_done.wait(timeout=5.0)
    finally:
        # Stop the scheduler even if starting or joining a thread fails.
        scheduler.stop()

    stats_thread.join()

    assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
    assert len(results["stats"]) == 50

    for stats in results["stats"]:
        assert "total_tasks" in stats
        assert stats["total_tasks"] >= 0
|