diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 30ca44a..f04ed3b 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -180,7 +180,9 @@ def create_scheduler( def prepare_checkpoint(model: nn.Module) -> dict: - return model.module.state_dict() + if isinstance(model, DDP): + return model.module.state_dict() + return model.state_dict() def compute_total_steps( @@ -253,7 +255,7 @@ def train( model = model.to(dtype=torch.bfloat16) strategy_kwargs = { - "dpo_beta": dpo_beta, + "beta": dpo_beta, "label_smoothing": label_smoothing, "clip_eps": grpo_clip_eps, "kl_coef": grpo_kl_coef,