63 lines
2.2 KiB
Python
63 lines
2.2 KiB
Python
from astrai.trainer import Trainer
|
|
|
|
# train_config_factory is injected via fixture
|
|
|
|
|
|
def test_different_batch_sizes(base_test_env, random_dataset, train_config_factory):
|
|
"""Test training with different batch sizes"""
|
|
batch_sizes = [1, 2, 4, 8]
|
|
|
|
for batch_per_device in batch_sizes:
|
|
train_config = train_config_factory(
|
|
model=base_test_env["model"],
|
|
dataset=random_dataset,
|
|
test_dir=base_test_env["test_dir"],
|
|
device=base_test_env["device"],
|
|
batch_per_device=batch_per_device,
|
|
)
|
|
|
|
assert train_config.batch_per_device == batch_per_device
|
|
|
|
|
|
def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory):
|
|
"""Test training with different gradient accumulation steps"""
|
|
grad_accum_steps_list = [1, 2, 4]
|
|
|
|
for grad_accum_steps in grad_accum_steps_list:
|
|
train_config = train_config_factory(
|
|
model=base_test_env["model"],
|
|
dataset=random_dataset,
|
|
test_dir=base_test_env["test_dir"],
|
|
device=base_test_env["device"],
|
|
batch_per_device=2,
|
|
grad_accum_steps=grad_accum_steps,
|
|
)
|
|
|
|
trainer = Trainer(train_config)
|
|
trainer.train()
|
|
|
|
assert train_config.grad_accum_steps == grad_accum_steps
|
|
|
|
|
|
def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory):
|
|
"""Test training with memory-efficient configurations"""
|
|
# Test with smaller batch sizes and gradient checkpointing
|
|
small_batch_configs = [
|
|
{"batch_per_device": 1, "grad_accum_steps": 8},
|
|
{"batch_per_device": 2, "grad_accum_steps": 4},
|
|
{"batch_per_device": 4, "grad_accum_steps": 2},
|
|
]
|
|
|
|
for config in small_batch_configs:
|
|
train_config = train_config_factory(
|
|
model=base_test_env["model"],
|
|
dataset=random_dataset,
|
|
test_dir=base_test_env["test_dir"],
|
|
device=base_test_env["device"],
|
|
batch_per_device=config["batch_per_device"],
|
|
grad_accum_steps=config["grad_accum_steps"],
|
|
)
|
|
|
|
assert train_config.grad_accum_steps == config["grad_accum_steps"]
|
|
assert train_config.batch_per_device == config["batch_per_device"]
|