From 7242eedbf45ce2046319b677b35b4d03bc0f94b5 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sat, 16 May 2026 17:07:36 +0800
Subject: [PATCH] =?UTF-8?q?fix:=20=E5=AD=A6=E4=B9=A0=E7=8E=87=E8=B0=83?=
 =?UTF-8?q?=E5=BA=A6=E6=8C=89=20optimizer=20step=20=E8=AE=A1=E6=95=B0?=
 =?UTF-8?q?=E5=B9=B6=E9=98=B2=E6=AD=A2=20warmup=20=E8=B6=8A=E7=95=8C?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- total_steps 除以 accumulation_steps,匹配 optimizer.step() 频率
- warmup_steps 用 min 截断,避免 lr_decay_steps 为负
---
 scripts/tools/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/scripts/tools/train.py b/scripts/tools/train.py
index 0208a15..95a549d 100644
--- a/scripts/tools/train.py
+++ b/scripts/tools/train.py
@@ -260,13 +260,13 @@ def train(
         },
     )
 
-    total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+    total_steps = len(dataset) * n_epoch // (batch_size * nprocs) // accumulation_steps
     scheduler_fn = partial(
         create_scheduler,
         **{
             "schedule_type": "cosine",
-            "warmup_steps": warmup_steps,
-            "lr_decay_steps": total_steps - warmup_steps,
+            "warmup_steps": min(warmup_steps, total_steps),
+            "lr_decay_steps": total_steps - min(warmup_steps, total_steps),
         },
     )
 