perf: Muon step 改用 torch._foreach_* 批处理并移除 NS 迭代的冗余 bf16 转换

This commit is contained in:
ViperEkura 2026-05-23 19:50:12 +08:00
parent ff509ff39f
commit 0594ce1017
1 changed files with 64 additions and 34 deletions

View File

@ -4,17 +4,17 @@ from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2
X = G.bfloat16()
X = G
scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale
if steps == 0:
return X.type_as(G)
return X
a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps):
A = X @ X.T
B = A @ X
X = a * X + b * B + c * (A @ B)
return X.type_as(G)
return X
class Muon(Optimizer):
@ -50,64 +50,94 @@ class Muon(Optimizer):
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
if grad.is_sparse:
if p.grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2:
self._muon_update(p, grad, group)
params_2d.append(p)
grads_2d.append(p.grad)
else:
self._adamw_update(p, grad, group)
params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss
def _muon_update(self, p, grad, group):
def _muon_update_foreach(self, params_2d, grads_2d, group):
lr = group["lr"]
momentum = group["momentum"]
wd = group["weight_decay"]
nesterov = group["nesterov"]
ns_steps = group["ns_steps"]
state = self.state[p]
p.mul_(1 - lr * wd)
if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov:
grad = grad.add(p, alpha=wd)
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad)
buf = state["momentum_buffer"]
buf.lerp_(grad, 1 - momentum)
bufs.append(state["momentum_buffer"])
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update(self, p, grad, group):
def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"]
betas = group["adamw_betas"]
eps = group["adamw_eps"]
wd = group["adamw_wd"]
state = self.state[p]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
beta1, beta2 = betas
exp_avg.lerp_(grad, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
step = state["step"]
bias1 = 1 - beta1**step
bias2 = 1 - beta2**step
bias_correction1 = [1 - beta1**s for s in steps]
bias_correction2 = [1 - beta2**s for s in steps]
p.mul_(1 - lr * wd)
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
if wd != 0:
torch._foreach_mul_(params_1d, 1 - lr * wd)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)