From 9d5e9fa6c4b8cfa0c98910740cc1c285b8625f4d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Fri, 15 May 2026 15:30:24 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20DDP=20=E5=8A=A0=20gradient=5Fas=5Fbucke?= =?UTF-8?q?t=5Fview/static=5Fgraph/broadcast=5Fbuffers=EF=BC=8CAdamW=20fus?= =?UTF-8?q?ed?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - gradient_as_bucket_view=True 零拷贝梯度归并 - static_graph=True 跳过每轮 bucket 重建 - broadcast_buffers=False 省 buffer 广播 - AdamW fused=True 融合优化器 kernel --- scripts/tools/train.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index bdc4067..0208a15 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -155,18 +155,20 @@ def parse_args() -> argparse.Namespace: def ddp_wrap(model: nn.Module): local_rank = get_rank() - model = model.to(dtype=torch.bfloat16) ddp_model = DDP( model, device_ids=[local_rank], output_device=local_rank, + static_graph=True, find_unused_parameters=False, + gradient_as_bucket_view=True, + broadcast_buffers=False, ) return ddp_model def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: - return optim.AdamW(model.parameters(), **kwargs) + return optim.AdamW(model.parameters(), fused=True, **kwargs) def create_scheduler( @@ -231,6 +233,8 @@ def train( state_dict = st.load_file(weights_path) model.load_state_dict(state_dict, strict=False) + model = model.to(dtype=torch.bfloat16) + strategy_kwargs = { "dpo_beta": dpo_beta, "label_smoothing": label_smoothing,