fix: total_steps 改用 ceiling 匹配实际步数
原公式全用 floor，会少算 optimizer step；改用逐层 ceiling（ceil_div，即 (a+b-1)//b），以对齐 DDP sampler padding、DataLoader drop_last=False 的尾批，以及 batched 尾组截断。
This commit is contained in:
parent
7242eedbf4
commit
026d1fc33d
|
|
@ -177,6 +177,22 @@ def create_scheduler(
|
|||
return SchedulerFactory.create(optimizer, **kwargs)
|
||||
|
||||
|
||||
def ceil_div(a: int, b: int) -> int:
    """Return the ceiling of ``a / b`` using integer arithmetic only.

    Args:
        a: Dividend.
        b: Divisor; must be non-zero.

    Returns:
        The smallest integer greater than or equal to ``a / b``.

    Raises:
        ZeroDivisionError: If ``b`` is zero.
    """
    # Negating before and after floor division turns floor into ceiling.
    # Unlike ``(a + b - 1) // b`` this stays correct when ``b`` is negative,
    # and it gives the same result for every positive ``b``.
    return -(-a // b)
|
||||
|
||||
|
||||
def compute_total_steps(
    dataset_len: int,
    n_epoch: int,
    batch_size: int,
    nprocs: int,
    accumulation_steps: int,
) -> int:
    """Return the number of optimizer steps over the whole training run.

    The per-epoch count is obtained by rounding *up* at each level in turn:
    samples per process, then batches per process, then accumulation groups
    per process. The result is multiplied by the number of epochs.

    Args:
        dataset_len: Number of samples in the dataset.
        n_epoch: Number of training epochs.
        batch_size: Per-process batch size.
        nprocs: Number of processes the dataset is split across.
        accumulation_steps: Batches accumulated per optimizer step.

    Returns:
        Total optimizer steps across all epochs.
    """
    # Apply the ceiling divisions in order: replicas -> batches -> groups.
    steps_per_epoch = dataset_len
    for divisor in (nprocs, batch_size, accumulation_steps):
        steps_per_epoch = ceil_div(steps_per_epoch, divisor)
    return steps_per_epoch * n_epoch
|
||||
|
||||
|
||||
def prepare_checkpoint(model: nn.Module) -> dict:
    """Return the state dict of the underlying (unwrapped) model.

    If ``model`` exposes a ``module`` attribute (as a parallel wrapper around
    an inner network does), the inner module's state dict is returned so the
    saved keys carry no wrapper prefix. Otherwise ``model`` is used directly
    instead of raising ``AttributeError``.

    Args:
        model: The model to checkpoint, wrapped or not.

    Returns:
        The mapping produced by the underlying module's ``state_dict()``.
    """
    # Fall back to the model itself when it is not wrapped.
    inner = getattr(model, "module", model)
    return inner.state_dict()
|
||||
|
||||
|
|
@ -260,7 +276,9 @@ def train(
|
|||
},
|
||||
)
|
||||
|
||||
total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
|
||||
total_steps = compute_total_steps(
|
||||
len(dataset), n_epoch, batch_size, nprocs, accumulation_steps
|
||||
)
|
||||
scheduler_fn = partial(
|
||||
create_scheduler,
|
||||
**{
|
||||
|
|
|
|||
Loading…
Reference in New Issue