perf: DDP 加 gradient_as_bucket_view/static_graph/broadcast_buffers,AdamW fused
- gradient_as_bucket_view=True:梯度直接作为通信 bucket 的视图,省去一次梯度拷贝
- static_graph=True:声明计算图在各轮迭代中保持不变,跳过每轮 bucket 重建
- broadcast_buffers=False:省去每次前向时的 buffer 广播
- AdamW fused=True:使用融合的优化器 kernel
This commit is contained in:
parent
08dde46778
commit
9d5e9fa6c4
|
|
@ -155,18 +155,20 @@ def parse_args() -> argparse.Namespace:
|
|||
|
||||
def ddp_wrap(model: nn.Module):
    """Cast *model* to bfloat16 and wrap it in DistributedDataParallel.

    The wrapper is pinned to the device index returned by ``get_rank()``,
    which is used for both ``device_ids`` and ``output_device``.

    Args:
        model: The module to wrap.

    Returns:
        The ``DDP``-wrapped module.
    """
    # NOTE(review): get_rank() is used as a CUDA device index; presumably it
    # returns the node-local rank -- confirm for multi-node runs.
    device_index = get_rank()

    # NOTE(review): the training entry point also appears to cast the model to
    # bfloat16; if it runs before this function, the cast here is a no-op.
    bf16_model = model.to(dtype=torch.bfloat16)

    return DDP(
        bf16_model,
        device_ids=[device_index],
        output_device=device_index,
        # Declares that the set of used parameters does not change between
        # iterations, so DDP can skip per-iteration bucket rebuilding.
        static_graph=True,
        find_unused_parameters=False,
        # Gradients become views into the all-reduce buckets, saving a copy.
        gradient_as_bucket_view=True,
        # Skips the buffer broadcast at the start of each forward pass.
        # NOTE(review): module buffers are then no longer synchronised across
        # ranks -- confirm the model has no buffers that need syncing.
        broadcast_buffers=False,
    )
|
||||
|
||||
|
||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
    """Build an AdamW optimizer over all parameters of *model*.

    The fused AdamW implementation is enabled by default. Callers may pass
    ``fused=False`` (or any other ``fused`` value) to override it; a
    hard-coded ``fused=True`` next to ``**kwargs`` would instead raise
    ``TypeError`` for a duplicated keyword argument.

    Args:
        model: Module whose ``parameters()`` are optimized.
        **kwargs: Forwarded unchanged to ``optim.AdamW`` (e.g. ``lr``,
            ``weight_decay``, ``betas``).

    Returns:
        The constructed ``optim.AdamW`` instance.
    """
    # NOTE(review): the fused kernels only support parameters on certain
    # devices (see the PyTorch AdamW docs); CPU-only callers may need to pass
    # fused=False explicitly.
    kwargs.setdefault("fused", True)
    return optim.AdamW(model.parameters(), **kwargs)
|
||||
|
||||
|
||||
def create_scheduler(
|
||||
|
|
@ -231,6 +233,8 @@ def train(
|
|||
state_dict = st.load_file(weights_path)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
|
||||
strategy_kwargs = {
|
||||
"dpo_beta": dpo_beta,
|
||||
"label_smoothing": label_smoothing,
|
||||
|
|
|
|||
Loading…
Reference in New Issue