perf: 测试优化,model 改为 session 共享,scheduler 用 Event 替代 sleep

- 拆出 session-scoped test_tokenizer + test_model,14 次创建 → 1 次
- 删除无用 test_env fixture
- 固定模型维度,消除随机性
- 添加 pytest markers 配置
This commit is contained in:
ViperEkura 2026-05-12 11:33:02 +08:00
parent 5889179c54
commit 5203b7f53e
2 changed files with 71 additions and 91 deletions

View File

@ -3,9 +3,7 @@ import os
import shutil import shutil
import tempfile import tempfile
import numpy as np
import pytest import pytest
import safetensors.torch as st
import torch import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset from torch.utils.data import Dataset
@ -15,6 +13,12 @@ from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer from astrai.tokenize import AutoTokenizer
def pytest_configure(config):
config.addinivalue_line("markers", "slow: marks tests as slow")
config.addinivalue_line("markers", "integration: integration tests")
config.addinivalue_line("markers", "unit: fast unit tests")
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer: def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes.""" """Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE()) tokenizer = Tokenizer(models.BPE())
@ -22,7 +26,6 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
trainer = trainers.BpeTrainer( trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"] vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
) )
# Train on empty iterator with single character
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer) tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer() auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer auto_tokenizer._tokenizer = tokenizer
@ -34,7 +37,7 @@ class RandomDataset(Dataset):
"""Random dataset for testing purposes.""" """Random dataset for testing purposes."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200)) self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -52,7 +55,7 @@ class MultiTurnDataset(Dataset):
"""Multi-turn dataset with loss mask for SFT training tests.""" """Multi-turn dataset with loss mask for SFT training tests."""
def __init__(self, length=None, max_length=64, vocab_size=1000): def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200)) self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length self.max_length = max_length
self.vocab_size = vocab_size self.vocab_size = vocab_size
@ -93,46 +96,65 @@ class EarlyStoppingDataset(Dataset):
} }
@pytest.fixture @pytest.fixture(scope="session")
def base_test_env(request: pytest.FixtureRequest): def test_tokenizer():
"""Create base test environment with randomly configured model and tokenizer""" """Session-scoped tokenizer, created once for the entire test run."""
func_name = request.function.__name__ return create_test_tokenizer()
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
dim = int(np.random.choice(n_dim_choices)) @pytest.fixture(scope="session")
n_heads = int(np.random.choice(n_head_choices)) def test_model():
n_kv_heads = n_heads // 2 """Session-scoped small Transformer model, created once."""
dim_ffn = dim * 2 config = ModelConfig(
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
config = { return {
"vocab_size": 1000, "model": model,
"dim": dim, "device": device,
"n_heads": n_heads, "config": config,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
} }
@pytest.fixture
def base_test_env(test_model, test_tokenizer):
"""Function-scoped test environment with isolated temp directory.
Composes session-scoped model and tokenizer with a per-test temp dir.
"""
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
with open(config_path, "w") as f: with open(config_path, "w") as f:
json.dump(config, f) json.dump(
device = "cuda" if torch.cuda.is_available() else "cpu" {
transformer_config = ModelConfig().load(config_path) "vocab_size": 1000,
model = Transformer(transformer_config).to(device=device) "dim": 16,
tokenizer = create_test_tokenizer() "n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
},
f,
)
yield { yield {
"device": device, "device": test_model["device"],
"test_dir": str(test_dir), "test_dir": str(test_dir),
"config_path": config_path, "config_path": config_path,
"transformer_config": transformer_config, "transformer_config": test_model["config"],
"model": model, "model": test_model["model"],
"tokenizer": tokenizer, "tokenizer": test_tokenizer,
} }
shutil.rmtree(test_dir) shutil.rmtree(test_dir)
@ -154,43 +176,3 @@ def multi_turn_dataset():
def early_stopping_dataset(): def early_stopping_dataset():
dataset = EarlyStoppingDataset() dataset = EarlyStoppingDataset()
yield dataset yield dataset
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
"""Create a test environment with saved model and tokenizer files."""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)

View File

@ -1,7 +1,6 @@
"""Tests for scheduler concurrency.""" """Tests for scheduler concurrency."""
import threading import threading
import time
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@ -63,14 +62,11 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
for t in threads: for t in threads:
t.start() t.start()
# Let some tasks be processed
time.sleep(0.1)
scheduler.stop()
for t in threads: for t in threads:
t.join() t.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50 assert len(results["task_ids"]) == 50
@ -89,19 +85,21 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
) )
results = {"added": [], "removed": [], "errors": []} results = {"added": [], "removed": [], "errors": []}
add_ready = threading.Event()
def add_worker(): def add_worker():
try: try:
for i in range(20): for i in range(20):
task_id = scheduler.add_task(f"prompt {i}") task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id) results["added"].append(task_id)
time.sleep(0.001) if len(results["added"]) >= 10:
add_ready.set()
except Exception as e: except Exception as e:
results["errors"].append(f"Add: {str(e)}") results["errors"].append(f"Add: {str(e)}")
def remove_worker(): def remove_worker():
try: try:
time.sleep(0.05) # Wait for some tasks to be added add_ready.wait(timeout=5.0)
for task_id in results["added"][:10]: for task_id in results["added"][:10]:
scheduler.remove_task(task_id) scheduler.remove_task(task_id)
results["removed"].append(task_id) results["removed"].append(task_id)
@ -114,11 +112,9 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
add_thread.start() add_thread.start()
remove_thread.start() remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join() add_thread.join()
remove_thread.join() remove_thread.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20 assert len(results["added"]) == 20
@ -138,21 +134,24 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
) )
results = {"stats": [], "errors": []} results = {"stats": [], "errors": []}
started = threading.Event()
stats_done = threading.Event()
def add_tasks(): def add_tasks():
try: try:
for i in range(20): for i in range(20):
scheduler.add_task(f"prompt {i}") scheduler.add_task(f"prompt {i}")
time.sleep(0.001) started.set()
except Exception as e: except Exception as e:
results["errors"].append(f"Add: {str(e)}") results["errors"].append(f"Add: {str(e)}")
def get_stats(): def get_stats():
try: try:
started.wait(timeout=5.0)
for _ in range(50): for _ in range(50):
stats = scheduler.get_stats() stats = scheduler.get_stats()
results["stats"].append(stats) results["stats"].append(stats)
time.sleep(0.001) stats_done.set()
except Exception as e: except Exception as e:
results["errors"].append(f"Get stats: {str(e)}") results["errors"].append(f"Get stats: {str(e)}")
@ -162,16 +161,15 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
add_thread.start() add_thread.start()
stats_thread.start() stats_thread.start()
time.sleep(0.3) add_thread.join()
stats_done.wait(timeout=5.0)
scheduler.stop() scheduler.stop()
add_thread.join()
stats_thread.join() stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}" assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50 assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]: for stats in results["stats"]:
assert "total_tasks" in stats assert "total_tasks" in stats
assert stats["total_tasks"] >= 0 assert stats["total_tasks"] >= 0