diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 0208a15..95a549d 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -260,13 +260,13 @@ def train( }, ) - total_steps = len(dataset) * n_epoch // (batch_size * nprocs) + total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps scheduler_fn = partial( create_scheduler, **{ "schedule_type": "cosine", - "warmup_steps": warmup_steps, - "lr_decay_steps": total_steps - warmup_steps, + "warmup_steps": min(warmup_steps, total_steps), + "lr_decay_steps": total_steps - min(warmup_steps, total_steps), }, )