fix: 学习率调度按 optimizer step 计数并防止 warmup 越界

- total_steps 除以 accumulation_steps,匹配 optimizer.step() 的调用频率
- warmup_steps 用 min(warmup_steps, total_steps) 截断,避免 lr_decay_steps 为负
This commit is contained in:
ViperEkura 2026-05-16 17:07:36 +08:00
parent 04c0dc7a47
commit 7242eedbf4
1 changed file with 3 additions and 3 deletions

View File

@@ -260,13 +260,13 @@ def train(
},
)
- total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+ total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
scheduler_fn = partial(
create_scheduler,
**{
"schedule_type": "cosine",
- "warmup_steps": warmup_steps,
- "lr_decay_steps": total_steps - warmup_steps,
+ "warmup_steps": min(warmup_steps, total_steps),
+ "lr_decay_steps": total_steps - min(warmup_steps, total_steps),
},
)