perf: DDP 加 gradient_as_bucket_view/static_graph/broadcast_buffers,AdamW fused
- gradient_as_bucket_view=True 零拷贝梯度归并 - static_graph=True 跳过每轮 bucket 重建 - broadcast_buffers=False 省 buffer 广播 - AdamW fused=True 融合优化器 kernel
This commit is contained in:
parent
08dde46778
commit
9d5e9fa6c4
|
|
@ -155,18 +155,20 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(dtype=torch.bfloat16)
|
|
||||||
ddp_model = DDP(
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
output_device=local_rank,
|
output_device=local_rank,
|
||||||
|
static_graph=True,
|
||||||
find_unused_parameters=False,
|
find_unused_parameters=False,
|
||||||
|
gradient_as_bucket_view=True,
|
||||||
|
broadcast_buffers=False,
|
||||||
)
|
)
|
||||||
return ddp_model
|
return ddp_model
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|
@ -231,6 +233,8 @@ def train(
|
||||||
state_dict = st.load_file(weights_path)
|
state_dict = st.load_file(weights_path)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
model = model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"dpo_beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue