perf: 测试优化,model 改为 session 共享,scheduler 用 Event 替代 sleep

- 拆出 session-scoped test_tokenizer + test_model,14 次创建 → 1 次
- 删除无用 test_env fixture
- 固定模型维度,消除随机性
- 添加 pytest markers 配置
This commit is contained in:
ViperEkura 2026-05-12 11:33:02 +08:00
parent 5889179c54
commit 5203b7f53e
2 changed files with 71 additions and 91 deletions

View File

@ -3,9 +3,7 @@ import os
import shutil
import tempfile
import numpy as np
import pytest
import safetensors.torch as st
import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset
@ -15,6 +13,12 @@ from astrai.model.transformer import Transformer
from astrai.tokenize import AutoTokenizer
def pytest_configure(config):
config.addinivalue_line("markers", "slow: marks tests as slow")
config.addinivalue_line("markers", "integration: integration tests")
config.addinivalue_line("markers", "unit: fast unit tests")
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
"""Create a simple tokenizer for testing purposes."""
tokenizer = Tokenizer(models.BPE())
@ -22,7 +26,6 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
trainer = trainers.BpeTrainer(
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
)
# Train on empty iterator with single character
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
auto_tokenizer = AutoTokenizer()
auto_tokenizer._tokenizer = tokenizer
@ -34,7 +37,7 @@ class RandomDataset(Dataset):
"""Random dataset for testing purposes."""
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length
self.vocab_size = vocab_size
@ -52,7 +55,7 @@ class MultiTurnDataset(Dataset):
"""Multi-turn dataset with loss mask for SFT training tests."""
def __init__(self, length=None, max_length=64, vocab_size=1000):
self.length = length or int(np.random.randint(100, 200))
self.length = length or int(torch.randint(100, 200, (1,)).item())
self.max_length = max_length
self.vocab_size = vocab_size
@ -93,46 +96,65 @@ class EarlyStoppingDataset(Dataset):
}
@pytest.fixture(scope="session")
def test_tokenizer():
"""Session-scoped tokenizer, created once for the entire test run."""
return create_test_tokenizer()
@pytest.fixture(scope="session")
def test_model():
"""Session-scoped small Transformer model, created once."""
config = ModelConfig(
vocab_size=1000,
dim=16,
n_heads=4,
n_kv_heads=2,
dim_ffn=32,
max_len=1024,
n_layers=4,
norm_eps=1e-5,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Transformer(config).to(device=device)
return {
"model": model,
"device": device,
"config": config,
}
@pytest.fixture
def base_test_env(request: pytest.FixtureRequest):
"""Create base test environment with randomly configured model and tokenizer"""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
def base_test_env(test_model, test_tokenizer):
"""Function-scoped test environment with isolated temp directory.
Composes session-scoped model and tokenizer with a per-test temp dir.
"""
test_dir = tempfile.mkdtemp()
config_path = os.path.join(test_dir, "config.json")
n_dim_choices = [8, 16, 32]
n_head_choices = [2, 4]
dim = int(np.random.choice(n_dim_choices))
n_heads = int(np.random.choice(n_head_choices))
n_kv_heads = n_heads // 2
dim_ffn = dim * 2
config = {
with open(config_path, "w") as f:
json.dump(
{
"vocab_size": 1000,
"dim": dim,
"n_heads": n_heads,
"n_kv_heads": n_kv_heads,
"dim_ffn": dim_ffn,
"dim": 16,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 32,
"max_len": 1024,
"n_layers": 4,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
device = "cuda" if torch.cuda.is_available() else "cpu"
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config).to(device=device)
tokenizer = create_test_tokenizer()
},
f,
)
yield {
"device": device,
"device": test_model["device"],
"test_dir": str(test_dir),
"config_path": config_path,
"transformer_config": transformer_config,
"model": model,
"tokenizer": tokenizer,
"transformer_config": test_model["config"],
"model": test_model["model"],
"tokenizer": test_tokenizer,
}
shutil.rmtree(test_dir)
@ -154,43 +176,3 @@ def multi_turn_dataset():
def early_stopping_dataset():
dataset = EarlyStoppingDataset()
yield dataset
@pytest.fixture
def test_env(request: pytest.FixtureRequest):
"""Create a test environment with saved model and tokenizer files."""
func_name = request.function.__name__
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
config_path = os.path.join(test_dir, "config.json")
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
model_path = os.path.join(test_dir, "model.safetensors")
config = {
"vocab_size": 1000,
"dim": 128,
"n_heads": 4,
"n_kv_heads": 2,
"dim_ffn": 256,
"max_len": 64,
"n_layers": 2,
"norm_eps": 1e-5,
}
with open(config_path, "w") as f:
json.dump(config, f)
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
tokenizer.save(tokenizer_path)
transformer_config = ModelConfig().load(config_path)
model = Transformer(transformer_config)
st.save_file(model.state_dict(), model_path)
yield {
"test_dir": test_dir,
"model": model,
"tokenizer": tokenizer,
"transformer_config": transformer_config,
}
shutil.rmtree(test_dir)

View File

@ -1,7 +1,6 @@
"""Tests for scheduler concurrency."""
import threading
import time
from unittest.mock import MagicMock, patch
import pytest
@ -63,14 +62,11 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
for t in threads:
t.start()
# Let some tasks be processed
time.sleep(0.1)
scheduler.stop()
for t in threads:
t.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["task_ids"]) == 50
@ -89,19 +85,21 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
)
results = {"added": [], "removed": [], "errors": []}
add_ready = threading.Event()
def add_worker():
try:
for i in range(20):
task_id = scheduler.add_task(f"prompt {i}")
results["added"].append(task_id)
time.sleep(0.001)
if len(results["added"]) >= 10:
add_ready.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def remove_worker():
try:
time.sleep(0.05) # Wait for some tasks to be added
add_ready.wait(timeout=5.0)
for task_id in results["added"][:10]:
scheduler.remove_task(task_id)
results["removed"].append(task_id)
@ -114,11 +112,9 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
add_thread.start()
remove_thread.start()
time.sleep(0.2)
scheduler.stop()
add_thread.join()
remove_thread.join()
scheduler.stop()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["added"]) == 20
@ -138,21 +134,24 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
)
results = {"stats": [], "errors": []}
started = threading.Event()
stats_done = threading.Event()
def add_tasks():
try:
for i in range(20):
scheduler.add_task(f"prompt {i}")
time.sleep(0.001)
started.set()
except Exception as e:
results["errors"].append(f"Add: {str(e)}")
def get_stats():
try:
started.wait(timeout=5.0)
for _ in range(50):
stats = scheduler.get_stats()
results["stats"].append(stats)
time.sleep(0.001)
stats_done.set()
except Exception as e:
results["errors"].append(f"Get stats: {str(e)}")
@ -162,16 +161,15 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
add_thread.start()
stats_thread.start()
time.sleep(0.3)
add_thread.join()
stats_done.wait(timeout=5.0)
scheduler.stop()
add_thread.join()
stats_thread.join()
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
assert len(results["stats"]) == 50
# Verify stats are consistent
for stats in results["stats"]:
assert "total_tasks" in stats
assert stats["total_tasks"] >= 0