From 8a11a7d4448b692e329ce0e77f87f31933de9f3d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 11:04:40 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E8=84=9A=E6=9C=AC=E4=B8=A4=E5=A4=84=E5=8F=82=E6=95=B0=E4=BC=A0?= =?UTF-8?q?=E9=80=92=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - prepare_checkpoint 增加 DDP 判断,单卡时不访问 .module - dpo_beta 改为 beta,对齐 DPOStrategy 参数名 --- scripts/tools/train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 30ca44a..f04ed3b 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -180,7 +180,9 @@ def create_scheduler( def prepare_checkpoint(model: nn.Module) -> dict: - return model.module.state_dict() + if isinstance(model, DDP): + return model.module.state_dict() + return model.state_dict() def compute_total_steps( @@ -253,7 +255,7 @@ def train( model = model.to(dtype=torch.bfloat16) strategy_kwargs = { - "dpo_beta": dpo_beta, + "beta": dpo_beta, "label_smoothing": label_smoothing, "clip_eps": grpo_clip_eps, "kl_coef": grpo_kl_coef,