perf: 测试优化,model 改为 session 共享,scheduler 用 Event 替代 sleep
- 拆出 session-scoped test_tokenizer + test_model,14 次创建 → 1 次 - 删除无用 test_env fixture - 固定模型维度,消除随机性 - 添加 pytest markers 配置
This commit is contained in:
parent
5889179c54
commit
5203b7f53e
|
|
@ -3,9 +3,7 @@ import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
import pytest
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
import torch
|
||||||
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
@ -15,6 +13,12 @@ from astrai.model.transformer import Transformer
|
||||||
from astrai.tokenize import AutoTokenizer
|
from astrai.tokenize import AutoTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
config.addinivalue_line("markers", "slow: marks tests as slow")
|
||||||
|
config.addinivalue_line("markers", "integration: integration tests")
|
||||||
|
config.addinivalue_line("markers", "unit: fast unit tests")
|
||||||
|
|
||||||
|
|
||||||
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
|
def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
|
||||||
"""Create a simple tokenizer for testing purposes."""
|
"""Create a simple tokenizer for testing purposes."""
|
||||||
tokenizer = Tokenizer(models.BPE())
|
tokenizer = Tokenizer(models.BPE())
|
||||||
|
|
@ -22,7 +26,6 @@ def create_test_tokenizer(vocab_size: int = 1000) -> AutoTokenizer:
|
||||||
trainer = trainers.BpeTrainer(
|
trainer = trainers.BpeTrainer(
|
||||||
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
|
vocab_size=vocab_size, min_frequency=1, special_tokens=["<unk>", "<pad>"]
|
||||||
)
|
)
|
||||||
# Train on empty iterator with single character
|
|
||||||
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
|
tokenizer.train_from_iterator([chr(i) for i in range(256)], trainer)
|
||||||
auto_tokenizer = AutoTokenizer()
|
auto_tokenizer = AutoTokenizer()
|
||||||
auto_tokenizer._tokenizer = tokenizer
|
auto_tokenizer._tokenizer = tokenizer
|
||||||
|
|
@ -34,7 +37,7 @@ class RandomDataset(Dataset):
|
||||||
"""Random dataset for testing purposes."""
|
"""Random dataset for testing purposes."""
|
||||||
|
|
||||||
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
||||||
self.length = length or int(np.random.randint(100, 200))
|
self.length = length or int(torch.randint(100, 200, (1,)).item())
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
|
@ -52,7 +55,7 @@ class MultiTurnDataset(Dataset):
|
||||||
"""Multi-turn dataset with loss mask for SFT training tests."""
|
"""Multi-turn dataset with loss mask for SFT training tests."""
|
||||||
|
|
||||||
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
def __init__(self, length=None, max_length=64, vocab_size=1000):
|
||||||
self.length = length or int(np.random.randint(100, 200))
|
self.length = length or int(torch.randint(100, 200, (1,)).item())
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
|
|
||||||
|
|
@ -93,46 +96,65 @@ class EarlyStoppingDataset(Dataset):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture(scope="session")
|
||||||
def base_test_env(request: pytest.FixtureRequest):
|
def test_tokenizer():
|
||||||
"""Create base test environment with randomly configured model and tokenizer"""
|
"""Session-scoped tokenizer, created once for the entire test run."""
|
||||||
func_name = request.function.__name__
|
return create_test_tokenizer()
|
||||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
|
||||||
|
|
||||||
n_dim_choices = [8, 16, 32]
|
|
||||||
n_head_choices = [2, 4]
|
|
||||||
|
|
||||||
dim = int(np.random.choice(n_dim_choices))
|
@pytest.fixture(scope="session")
|
||||||
n_heads = int(np.random.choice(n_head_choices))
|
def test_model():
|
||||||
n_kv_heads = n_heads // 2
|
"""Session-scoped small Transformer model, created once."""
|
||||||
dim_ffn = dim * 2
|
config = ModelConfig(
|
||||||
|
vocab_size=1000,
|
||||||
|
dim=16,
|
||||||
|
n_heads=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
dim_ffn=32,
|
||||||
|
max_len=1024,
|
||||||
|
n_layers=4,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = Transformer(config).to(device=device)
|
||||||
|
|
||||||
config = {
|
return {
|
||||||
"vocab_size": 1000,
|
"model": model,
|
||||||
"dim": dim,
|
"device": device,
|
||||||
"n_heads": n_heads,
|
"config": config,
|
||||||
"n_kv_heads": n_kv_heads,
|
|
||||||
"dim_ffn": dim_ffn,
|
|
||||||
"max_len": 1024,
|
|
||||||
"n_layers": 4,
|
|
||||||
"norm_eps": 1e-5,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def base_test_env(test_model, test_tokenizer):
|
||||||
|
"""Function-scoped test environment with isolated temp directory.
|
||||||
|
|
||||||
|
Composes session-scoped model and tokenizer with a per-test temp dir.
|
||||||
|
"""
|
||||||
|
test_dir = tempfile.mkdtemp()
|
||||||
|
config_path = os.path.join(test_dir, "config.json")
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
json.dump(config, f)
|
json.dump(
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
{
|
||||||
transformer_config = ModelConfig().load(config_path)
|
"vocab_size": 1000,
|
||||||
model = Transformer(transformer_config).to(device=device)
|
"dim": 16,
|
||||||
tokenizer = create_test_tokenizer()
|
"n_heads": 4,
|
||||||
|
"n_kv_heads": 2,
|
||||||
|
"dim_ffn": 32,
|
||||||
|
"max_len": 1024,
|
||||||
|
"n_layers": 4,
|
||||||
|
"norm_eps": 1e-5,
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
|
||||||
yield {
|
yield {
|
||||||
"device": device,
|
"device": test_model["device"],
|
||||||
"test_dir": str(test_dir),
|
"test_dir": str(test_dir),
|
||||||
"config_path": config_path,
|
"config_path": config_path,
|
||||||
"transformer_config": transformer_config,
|
"transformer_config": test_model["config"],
|
||||||
"model": model,
|
"model": test_model["model"],
|
||||||
"tokenizer": tokenizer,
|
"tokenizer": test_tokenizer,
|
||||||
}
|
}
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
shutil.rmtree(test_dir)
|
||||||
|
|
@ -154,43 +176,3 @@ def multi_turn_dataset():
|
||||||
def early_stopping_dataset():
|
def early_stopping_dataset():
|
||||||
dataset = EarlyStoppingDataset()
|
dataset = EarlyStoppingDataset()
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_env(request: pytest.FixtureRequest):
|
|
||||||
"""Create a test environment with saved model and tokenizer files."""
|
|
||||||
|
|
||||||
func_name = request.function.__name__
|
|
||||||
test_dir = tempfile.mkdtemp(prefix=f"{func_name}_")
|
|
||||||
config_path = os.path.join(test_dir, "config.json")
|
|
||||||
tokenizer_path = os.path.join(test_dir, "tokenizer.json")
|
|
||||||
model_path = os.path.join(test_dir, "model.safetensors")
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"vocab_size": 1000,
|
|
||||||
"dim": 128,
|
|
||||||
"n_heads": 4,
|
|
||||||
"n_kv_heads": 2,
|
|
||||||
"dim_ffn": 256,
|
|
||||||
"max_len": 64,
|
|
||||||
"n_layers": 2,
|
|
||||||
"norm_eps": 1e-5,
|
|
||||||
}
|
|
||||||
with open(config_path, "w") as f:
|
|
||||||
json.dump(config, f)
|
|
||||||
|
|
||||||
tokenizer = create_test_tokenizer(vocab_size=config["vocab_size"])
|
|
||||||
tokenizer.save(tokenizer_path)
|
|
||||||
|
|
||||||
transformer_config = ModelConfig().load(config_path)
|
|
||||||
model = Transformer(transformer_config)
|
|
||||||
st.save_file(model.state_dict(), model_path)
|
|
||||||
|
|
||||||
yield {
|
|
||||||
"test_dir": test_dir,
|
|
||||||
"model": model,
|
|
||||||
"tokenizer": tokenizer,
|
|
||||||
"transformer_config": transformer_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
shutil.rmtree(test_dir)
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
"""Tests for scheduler concurrency."""
|
"""Tests for scheduler concurrency."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -63,14 +62,11 @@ def test_scheduler_concurrent_add_task(mock_model_and_tokenizer):
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
# Let some tasks be processed
|
|
||||||
time.sleep(0.1)
|
|
||||||
|
|
||||||
scheduler.stop()
|
|
||||||
|
|
||||||
for t in threads:
|
for t in threads:
|
||||||
t.join()
|
t.join()
|
||||||
|
|
||||||
|
scheduler.stop()
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||||
assert len(results["task_ids"]) == 50
|
assert len(results["task_ids"]) == 50
|
||||||
|
|
||||||
|
|
@ -89,19 +85,21 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {"added": [], "removed": [], "errors": []}
|
results = {"added": [], "removed": [], "errors": []}
|
||||||
|
add_ready = threading.Event()
|
||||||
|
|
||||||
def add_worker():
|
def add_worker():
|
||||||
try:
|
try:
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
task_id = scheduler.add_task(f"prompt {i}")
|
task_id = scheduler.add_task(f"prompt {i}")
|
||||||
results["added"].append(task_id)
|
results["added"].append(task_id)
|
||||||
time.sleep(0.001)
|
if len(results["added"]) >= 10:
|
||||||
|
add_ready.set()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["errors"].append(f"Add: {str(e)}")
|
results["errors"].append(f"Add: {str(e)}")
|
||||||
|
|
||||||
def remove_worker():
|
def remove_worker():
|
||||||
try:
|
try:
|
||||||
time.sleep(0.05) # Wait for some tasks to be added
|
add_ready.wait(timeout=5.0)
|
||||||
for task_id in results["added"][:10]:
|
for task_id in results["added"][:10]:
|
||||||
scheduler.remove_task(task_id)
|
scheduler.remove_task(task_id)
|
||||||
results["removed"].append(task_id)
|
results["removed"].append(task_id)
|
||||||
|
|
@ -114,11 +112,9 @@ def test_scheduler_concurrent_add_remove_task(mock_model_and_tokenizer):
|
||||||
add_thread.start()
|
add_thread.start()
|
||||||
remove_thread.start()
|
remove_thread.start()
|
||||||
|
|
||||||
time.sleep(0.2)
|
|
||||||
scheduler.stop()
|
|
||||||
|
|
||||||
add_thread.join()
|
add_thread.join()
|
||||||
remove_thread.join()
|
remove_thread.join()
|
||||||
|
scheduler.stop()
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||||
assert len(results["added"]) == 20
|
assert len(results["added"]) == 20
|
||||||
|
|
@ -138,21 +134,24 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
)
|
)
|
||||||
|
|
||||||
results = {"stats": [], "errors": []}
|
results = {"stats": [], "errors": []}
|
||||||
|
started = threading.Event()
|
||||||
|
stats_done = threading.Event()
|
||||||
|
|
||||||
def add_tasks():
|
def add_tasks():
|
||||||
try:
|
try:
|
||||||
for i in range(20):
|
for i in range(20):
|
||||||
scheduler.add_task(f"prompt {i}")
|
scheduler.add_task(f"prompt {i}")
|
||||||
time.sleep(0.001)
|
started.set()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["errors"].append(f"Add: {str(e)}")
|
results["errors"].append(f"Add: {str(e)}")
|
||||||
|
|
||||||
def get_stats():
|
def get_stats():
|
||||||
try:
|
try:
|
||||||
|
started.wait(timeout=5.0)
|
||||||
for _ in range(50):
|
for _ in range(50):
|
||||||
stats = scheduler.get_stats()
|
stats = scheduler.get_stats()
|
||||||
results["stats"].append(stats)
|
results["stats"].append(stats)
|
||||||
time.sleep(0.001)
|
stats_done.set()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["errors"].append(f"Get stats: {str(e)}")
|
results["errors"].append(f"Get stats: {str(e)}")
|
||||||
|
|
||||||
|
|
@ -162,16 +161,15 @@ def test_scheduler_concurrent_get_stats(mock_model_and_tokenizer):
|
||||||
add_thread.start()
|
add_thread.start()
|
||||||
stats_thread.start()
|
stats_thread.start()
|
||||||
|
|
||||||
time.sleep(0.3)
|
add_thread.join()
|
||||||
|
stats_done.wait(timeout=5.0)
|
||||||
scheduler.stop()
|
scheduler.stop()
|
||||||
|
|
||||||
add_thread.join()
|
|
||||||
stats_thread.join()
|
stats_thread.join()
|
||||||
|
|
||||||
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
assert len(results["errors"]) == 0, f"Errors: {results['errors']}"
|
||||||
assert len(results["stats"]) == 50
|
assert len(results["stats"]) == 50
|
||||||
|
|
||||||
# Verify stats are consistent
|
|
||||||
for stats in results["stats"]:
|
for stats in results["stats"]:
|
||||||
assert "total_tasks" in stats
|
assert "total_tasks" in stats
|
||||||
assert stats["total_tasks"] >= 0
|
assert stats["total_tasks"] >= 0
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue