perf: DDP 加 gradient_as_bucket_view/static_graph/broadcast_buffers,AdamW fused

- gradient_as_bucket_view=True 梯度直接作为 allreduce bucket 的视图，省去梯度与 bucket 之间的拷贝并降低峰值显存
- static_graph=True 声明计算图及参数使用情况在各轮迭代间保持不变，DDP 无需每轮遍历图来检测未使用参数
- broadcast_buffers=False 关闭每次 forward 开始时的 buffer 广播同步
- AdamW fused=True 融合优化器 kernel
- model.to(dtype=torch.bfloat16) 从 ddp_wrap 移到 train 中加载权重之后
This commit is contained in:
ViperEkura 2026-05-15 15:30:24 +08:00
parent 08dde46778
commit 9d5e9fa6c4
1 changed file with 6 additions and 2 deletions

View File

@@ -155,18 +155,20 @@ def parse_args() -> argparse.Namespace:
def ddp_wrap(model: nn.Module): def ddp_wrap(model: nn.Module):
local_rank = get_rank() local_rank = get_rank()
model = model.to(dtype=torch.bfloat16)
ddp_model = DDP( ddp_model = DDP(
model, model,
device_ids=[local_rank], device_ids=[local_rank],
output_device=local_rank, output_device=local_rank,
static_graph=True,
find_unused_parameters=False, find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
) )
return ddp_model return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), **kwargs) return optim.AdamW(model.parameters(), fused=True, **kwargs)
def create_scheduler( def create_scheduler(
@@ -231,6 +233,8 @@ def train(
state_dict = st.load_file(weights_path) state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = { strategy_kwargs = {
"dpo_beta": dpo_beta, "dpo_beta": dpo_beta,
"label_smoothing": label_smoothing, "label_smoothing": label_smoothing,