From 0594ce101727988ede827f398de0b5729eeed922 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 23 May 2026 19:50:12 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20Muon=20step=20=E6=94=B9=E7=94=A8=20torc?= =?UTF-8?q?h.=5Fforeach=5F*=20=E6=89=B9=E5=A4=84=E7=90=86=E5=B9=B6?= =?UTF-8?q?=E7=A7=BB=E9=99=A4=20NS=20=E8=BF=AD=E4=BB=A3=E7=9A=84=E5=86=97?= =?UTF-8?q?=E4=BD=99=20bf16=20=E8=BD=AC=E6=8D=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/trainer/optim.py | 98 +++++++++++++++++++++++++++-------------- 1 file changed, 64 insertions(+), 34 deletions(-) diff --git a/astrai/trainer/optim.py b/astrai/trainer/optim.py index 6e810a9..27a2991 100644 --- a/astrai/trainer/optim.py +++ b/astrai/trainer/optim.py @@ -4,17 +4,17 @@ from torch.optim import Optimizer def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5): assert G.ndim == 2 - X = G.bfloat16() + X = G scale = max(1, G.size(0) / G.size(1)) ** 0.5 X = X / (X.norm() + 1e-7) * scale if steps == 0: - return X.type_as(G) + return X a, b, c = (3.4445, -4.7750, 2.0315) for _ in range(steps): A = X @ X.T B = A @ X X = a * X + b * B + c * (A @ B) - return X.type_as(G) + return X class Muon(Optimizer): @@ -50,64 +50,94 @@ class Muon(Optimizer): if closure is not None: with torch.enable_grad(): loss = closure() + for group in self.param_groups: + params_2d, params_1d = [], [] + grads_2d, grads_1d = [], [] + for p in group["params"]: if p.grad is None: continue - grad = p.grad - if grad.is_sparse: + if p.grad.is_sparse: raise RuntimeError("Muon does not support sparse gradients") if p.ndim >= 2: - self._muon_update(p, grad, group) + params_2d.append(p) + grads_2d.append(p.grad) else: - self._adamw_update(p, grad, group) + params_1d.append(p) + grads_1d.append(p.grad) + + if params_2d: + self._muon_update_foreach(params_2d, grads_2d, group) + if params_1d: + self._adamw_update_foreach(params_1d, grads_1d, group) + return loss - def _muon_update(self, p, grad, group): + def _muon_update_foreach(self, params_2d, grads_2d, group): lr = group["lr"] momentum = group["momentum"] wd = group["weight_decay"] nesterov = group["nesterov"] ns_steps = group["ns_steps"] - state = self.state[p] - p.mul_(1 - lr * wd) + if wd != 0: + torch._foreach_mul_(params_2d, 1 - lr * wd) if nesterov: - grad = grad.add(p, alpha=wd) + grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd) - if "momentum_buffer" not in state: - state["momentum_buffer"] = torch.zeros_like(grad) - buf = state["momentum_buffer"] - buf.lerp_(grad, 1 - momentum) + bufs = [] + for p, grad in zip(params_2d, grads_2d): + state = self.state[p] + if "momentum_buffer" not in state: + state["momentum_buffer"] = torch.zeros_like(grad) + bufs.append(state["momentum_buffer"]) - update = _zeropower_via_newtonschulz(buf, steps=ns_steps) - scale = max(1, p.size(0) / p.size(1)) ** 0.5 - p.add_(update, alpha=-lr * scale) + torch._foreach_lerp_(bufs, grads_2d, 1 - momentum) - def _adamw_update(self, p, grad, group): + for p, buf in zip(params_2d, bufs): + update = _zeropower_via_newtonschulz(buf, steps=ns_steps) + scale = max(1, p.size(0) / p.size(1)) ** 0.5 + p.add_(update, alpha=-lr * scale) + + def _adamw_update_foreach(self, params_1d, grads_1d, group): lr = group["adamw_lr"] betas = group["adamw_betas"] eps = group["adamw_eps"] wd = group["adamw_wd"] - state = self.state[p] - if not state: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like(p) - state["exp_avg_sq"] = torch.zeros_like(p) + steps: list[int] = [] + exp_avgs, exp_avg_sqs = [], [] + has_state = [] + for p in params_1d: + state = self.state[p] + if not state: + state["step"] = 0 + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + has_state.append(False) + else: + has_state.append(True) + state["step"] += 1 + steps.append(state["step"]) + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) - state["step"] += 1 - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] beta1, beta2 = betas - exp_avg.lerp_(grad, 1 - beta1) - exp_avg_sq.lerp_(grad.square(), 1 - beta2) + torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1) + grads_sq = torch._foreach_mul(grads_1d, grads_1d) + torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2) - step = state["step"] - bias1 = 1 - beta1**step - bias2 = 1 - beta2**step + bias_correction1 = [1 - beta1**s for s in steps] + bias_correction2 = [1 - beta2**s for s in steps] - p.mul_(1 - lr * wd) - denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps) - p.addcdiv_(exp_avg / bias1, denom, value=-lr) + if wd != 0: + torch._foreach_mul_(params_1d, 1 - lr * wd) + + exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1) + denom = torch._foreach_div(exp_avg_sqs, bias_correction2) + denom = torch._foreach_sqrt(denom) + torch._foreach_add_(denom, eps) + torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)