diff --git a/tests/conftest.py b/tests/conftest.py index 272e7e9..3f17c12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,9 +3,7 @@ import os import shutil import tempfile -import numpy as np import pytest -import safetensors.torch as st import torch from tokenizers import Tokenizer, models, pre_tokenizers, trainers from torch.utils.data import Dataset @@ -15,6 +13,12 @@ from astrai.model.transformer import Transformer from astrai.tokenize import AutoTokenizer +def pytest_configure(config): + config.addinivalue_line("markers", "slow: marks tests as slow") + config.addinivalue_line("markers", "integration: integration tests") + config.addinivalue_line("markers", "unit: fast unit tests") + + def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer: """Create a simple tokenizer for testing purposes.""" tokenizer = Tokenizer(models.BPE()) @@ -22,7 +26,6 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer: trainer = trainers.BpeTrainer( vocab_size=vocab_size, min_frequency=1, special_tokens=["", ""] ) - # Train on empty iterator with single character tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer) auto_tokenizer = AutoTokenizer() auto_tokenizer._tokenizer = tokenizer @@ -34,7 +37,7 @@ class RandomDataset(Dataset): """Random dataset for testing purposes.""" def __init__(self, length=None, max_length=64, vocab_size=1000): - self.length = length or int(np.random.randint(100, 200)) + self.length = length or int(torch.randint(100, 200, (1,)).item()) self.max_length = max_length self.vocab_size = vocab_size @@ -52,7 +55,7 @@ class MultiTurnDataset(Dataset): """Multi-turn dataset with loss mask for SFT training tests.""" def __init__(self, length=None, max_length=64, vocab_size=1000): - self.length = length or int(np.random.randint(100, 200)) + self.length = length or int(torch.randint(100, 200, (1,)).item()) self.max_length = max_length self.vocab_size = vocab_size @@ -93,46 +96,65 @@ class EarlyStoppingDataset(Dataset): } -@pytest.fixture -def base_test_env(request: pytest.FixtureRequest): - """Create base test environment with randomly configured model and tokenizer""" - func_name = request.function.__name__ - test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") - config_path = os.path.join(test_dir, "config.json") +@pytest.fixture(scope="session") +def test_tokenizer(): + """Session-scoped tokenizer, created once for the entire test run.""" + return create_test_tokenizer() - n_dim_choices = [8, 16, 32] - n_head_choices = [2, 4] - dim = int(np.random.choice(n_dim_choices)) - n_heads = int(np.random.choice(n_head_choices)) - n_kv_heads = n_heads // 2 - dim_ffn = dim * 2 +@pytest.fixture(scope="session") +def test_model(): + """Session-scoped small Transformer model, created once.""" + config = ModelConfig( + vocab_size=1000, + dim=16, + n_heads=4, + n_kv_heads=2, + dim_ffn=32, + max_len=1024, + n_layers=4, + norm_eps=1e-5, + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = Transformer(config).to(device=device) - config = { - "vocab_size": 1000, - "dim": dim, - "n_heads": n_heads, - "n_kv_heads": n_kv_heads, - "dim_ffn": dim_ffn, - "max_len": 1024, - "n_layers": 4, - "norm_eps": 1e-5, + return { + "model": model, + "device": device, + "config": config, } + +@pytest.fixture +def base_test_env(test_model, test_tokenizer): + """Function-scoped test environment with isolated temp directory. + + Composes session-scoped model and tokenizer with a per-test temp dir. + """ + test_dir = tempfile.mkdtemp() + config_path = os.path.join(test_dir, "config.json") with open(config_path, "w") as f: - json.dump(config, f) - device = "cuda" if torch.cuda.is_available() else "cpu" - transformer_config = ModelConfig().load(config_path) - model = Transformer(transformer_config).to(device=device) - tokenizer = create_test_tokenizer() + json.dump( + { + "vocab_size": 1000, + "dim": 16, + "n_heads": 4, + "n_kv_heads": 2, + "dim_ffn": 32, + "max_len": 1024, + "n_layers": 4, + "norm_eps": 1e-5, + }, + f, + ) yield { - "device": device, + "device": test_model["device"], "test_dir": str(test_dir), "config_path": config_path, - "transformer_config": transformer_config, - "model": model, - "tokenizer": tokenizer, + "transformer_config": test_model["config"], + "model": test_model["model"], + "tokenizer": test_tokenizer, } shutil.rmtree(test_dir) @@ -154,43 +176,3 @@ def multi_turn_dataset(): def early_stopping_dataset(): dataset = EarlyStoppingDataset() yield dataset - - -@pytest.fixture -def test_env(request: pytest.FixtureRequest): - """Create a test environment with saved model and tokenizer files.""" - - func_name = request.function.__name__ - test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") - config_path = os.path.join(test_dir, "config.json") - tokenizer_path = os.path.join(test_dir, "tokenizer.json") - model_path = os.path.join(test_dir, "model.safetensors") - - config = { - "vocab_size": 1000, - "dim": 128, - "n_heads": 4, - "n_kv_heads": 2, - "dim_ffn": 256, - "max_len": 64, - "n_layers": 2, - "norm_eps": 1e-5, - } - with open(config_path, "w") as f: - json.dump(config, f) - - tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"]) - tokenizer.save(tokenizer_path) - - transformer_config = ModelConfig().load(config_path) - model = Transformer(transformer_config) - st.save_file(model.state_dict(), model_path) - - yield { - "test_dir": test_dir, - "model": model, - "tokenizer": tokenizer, - "transformer_config": transformer_config, - } - - shutil.rmtree(test_dir) diff --git a/tests/inference/test_scheduler_concurrency.py b/tests/inference/test_scheduler_concurrency.py index cc7a9d2..b4dee84 100644 --- a/tests/inference/test_scheduler_concurrency.py +++ b/tests/inference/test_scheduler_concurrency.py @@ -1,7 +1,6 @@ """Tests for scheduler concurrency.""" import threading -import time from unittest.mock import MagicMock, patch import pytest @@ -63,14 +62,11 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer): for t in threads: t.start() - # Let some tasks be processed - time.sleep(0.1) - - scheduler.stop() - for t in threads: t.join() + scheduler.stop() + assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["task_ids"]) == 50 @@ -89,19 +85,21 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer): ) results = {"added": [], "removed": [], "errors": []} + add_ready = threading.Event() def add_worker(): try: for i in range(20): task_id = scheduler.add_task(f"prompt {i}") results["added"].append(task_id) - time.sleep(0.001) + if len(results["added"]) >= 10: + add_ready.set() except Exception as e: results["errors"].append(f"Add: {str(e)}") def remove_worker(): try: - time.sleep(0.05) # Wait for some tasks to be added + add_ready.wait(timeout=5.0) for task_id in results["added"][:10]: scheduler.remove_task(task_id) results["removed"].append(task_id) @@ -114,11 +112,9 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer): add_thread.start() remove_thread.start() - time.sleep(0.2) - scheduler.stop() - add_thread.join() remove_thread.join() + scheduler.stop() assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["added"]) == 20 @@ -138,21 +134,24 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): ) results = {"stats": [], "errors": []} + started = threading.Event() + stats_done = threading.Event() def add_tasks(): try: for i in range(20): scheduler.add_task(f"prompt {i}") - time.sleep(0.001) + started.set() except Exception as e: results["errors"].append(f"Add: {str(e)}") def get_stats(): try: + started.wait(timeout=5.0) for _ in range(50): stats = scheduler.get_stats() results["stats"].append(stats) - time.sleep(0.001) + stats_done.set() except Exception as e: results["errors"].append(f"Get stats: {str(e)}") @@ -162,16 +161,15 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer): add_thread.start() stats_thread.start() - time.sleep(0.3) + add_thread.join() + stats_done.wait(timeout=5.0) scheduler.stop() - add_thread.join() stats_thread.join() assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["stats"]) == 50 - # Verify stats are consistent for stats in results["stats"]: assert "total_tasks" in stats assert stats["total_tasks"] >= 0