def ceil_div(a: int, b: int) -> int:
    """Return ``a / b`` rounded up to the nearest integer.

    Uses negated floor division so the result is the true ceiling for any
    non-zero divisor, including negative ones (the ``(a + b - 1) // b``
    form is only correct when ``b`` is positive).

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # floor(-a / b) == -ceil(a / b), so negating the floor yields the ceiling.
    return -(-a // b)


def compute_total_steps(
    dataset_len: int,
    n_epoch: int,
    batch_size: int,
    nprocs: int,
    accumulation_steps: int,
) -> int:
    """Return the step count over ``n_epoch`` epochs, rounding up at each stage.

    The per-epoch count is derived in three ceiling divisions: samples per
    replica, then batches per replica, then groups of ``accumulation_steps``
    batches. The per-epoch result is multiplied by ``n_epoch``.

    NOTE(review): rounding up at every stage presumes each replica receives
    a padded share of the dataset, the trailing partial batch is kept, and a
    trailing partial accumulation window still counts as a step -- confirm
    against the sampler / data loader / training loop configuration.

    Args:
        dataset_len: Number of samples in the dataset.
        n_epoch: Number of epochs.
        batch_size: Samples per batch on each replica.
        nprocs: Number of replicas the dataset is split across.
        accumulation_steps: Number of batches grouped into one step.

    Returns:
        Steps per epoch multiplied by ``n_epoch``.

    Raises:
        ZeroDivisionError: If ``batch_size``, ``nprocs`` or
            ``accumulation_steps`` is zero.
    """
    samples_per_replica = ceil_div(dataset_len, nprocs)
    batches_per_replica = ceil_div(samples_per_replica, batch_size)
    steps_per_epoch = ceil_div(batches_per_replica, accumulation_steps)
    return steps_per_epoch * n_epoch