fix: count the LR schedule in optimizer steps and keep warmup within bounds

- Divide total_steps by accumulation_steps so it matches how often optimizer.step() is called
- Clamp warmup_steps with min so that lr_decay_steps cannot go negative
parent 04c0dc7a47
commit 7242eedbf4
@@ -260,13 +260,13 @@ def train(
         },
     )
 
-    total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+    total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
     scheduler_fn = partial(
         create_scheduler,
         **{
             "schedule_type": "cosine",
-            "warmup_steps": warmup_steps,
-            "lr_decay_steps": total_steps - warmup_steps,
+            "warmup_steps": min(warmup_steps, total_steps),
+            "lr_decay_steps": total_steps - min(warmup_steps, total_steps),
         },
     )
 
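To make the arithmetic in this change concrete, here is a minimal sketch of the step computation as a standalone function. The helper name `schedule_steps` and its `num_samples` parameter are hypothetical (the commit computes these values inline in `train` from `len(dataset)`); the formulas mirror the added lines of the diff.

```python
def schedule_steps(num_samples, n_epoch, batch_size, nprocs,
                   accumulation_steps, warmup_steps):
    """Return (total_steps, warmup, lr_decay_steps), counted in optimizer steps.

    Hypothetical helper for illustration; not part of the repository.
    """
    # Micro-batches each process sees over the whole run.
    micro_steps = num_samples * n_epoch // (batch_size * nprocs)
    # optimizer.step() runs once per accumulation_steps micro-batches.
    total_steps = micro_steps // accumulation_steps
    # Clamp warmup so the decay phase never gets a negative length.
    warmup = min(warmup_steps, total_steps)
    return total_steps, warmup, total_steps - warmup


# No accumulation: 937 optimizer steps, 100 of them warmup.
print(schedule_steps(10_000, 3, 8, 4, 1, 100))   # (937, 100, 837)
# 16-step accumulation: only 58 optimizer steps, so warmup is clamped to 58.
print(schedule_steps(10_000, 3, 8, 4, 16, 100))  # (58, 58, 0)
```

In the second case the old formula would have given `total_steps = 937` even though the optimizer only steps 58 times. With a 100-step warmup, the schedule would still be warming up when training ended. The clamp matters when `warmup_steps` exceeds the corrected `total_steps`, where the unclamped subtraction would yield `58 - 100 = -42`.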