import os import numpy as np import torch from astrai.config.train_config import TrainConfig from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.trainer import Trainer def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior""" def optimizer_fn(model): return torch.optim.AdamW(model.parameters()) def scheduler_fn(optim): return SchedulerFactory.create( optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 ) train_config = TrainConfig( strategy="seq", optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, model_fn=lambda: base_test_env["model"], dataset=early_stopping_dataset, ckpt_dir=base_test_env["test_dir"], log_dir=os.path.join(base_test_env["test_dir"], "logs"), n_epoch=2, batch_per_device=2, ckpt_interval=1, grad_accum_steps=2, random_seed=np.random.randint(1e4), device_type=base_test_env["device"], ) trainer = Trainer(train_config) # Should handle early stopping gracefully try: trainer.train() except Exception: pass # Resume from latest checkpoint load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") trainer = Trainer(train_config) trainer.train(resume_dir=load_dir) # Verify checkpoint was saved at expected iteration load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") import json with open(os.path.join(load_dir, "meta.json")) as f: meta = json.load(f) assert meta["iteration"] == 10