"""Muon optimizer with an AdamW fallback (module ``astrai/trainer/optim.py``).

``astrai/trainer/__init__.py`` re-exports :class:`Muon`
(``from astrai.trainer.optim import Muon`` plus a ``"Muon"`` entry in
``__all__``).
"""

from typing import Optional, Tuple

import torch
from torch.optim import Optimizer


def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Approximately orthogonalize a 2-D matrix with a quintic Newton-Schulz iteration.

    The input is cast to bfloat16, divided by its Frobenius norm and then
    iterated ``steps`` times with ``X <- a*X + b*(X X^T) X + c*(X X^T)^2 X``.

    Args:
        G: 2-D tensor to orthogonalize.
        steps: Number of iterations; ``0`` returns the normalized input.

    Returns:
        A tensor with the shape and dtype of ``G``.

    Raises:
        ValueError: If ``G`` is not 2-D.
    """
    if G.ndim != 2:
        raise ValueError(
            f"Newton-Schulz orthogonalization expects a 2-D tensor, got {G.ndim}-D"
        )

    X = G.bfloat16()
    # Work on the wide orientation so the Gram matrix X @ X.T has the smaller
    # dimension; the iteration commutes with transposition.
    transposed = G.size(0) > G.size(1)
    if transposed:
        X = X.T

    # Frobenius normalization bounds every singular value by 1. No extra
    # scaling is applied here: inflating the norm before iterating can push
    # singular values to where the polynomial below diverges.
    X = X / (X.norm() + 1e-7)

    a, b, c = (3.4445, -4.7750, 2.0315)
    for _ in range(steps):
        A = X @ X.T
        B = A @ X
        X = a * X + b * B + c * (A @ B)

    if transposed:
        X = X.T
    return X.type_as(G)


class Muon(Optimizer):
    """Optimizer applying orthogonalized momentum updates to matrix parameters.

    Parameters with two or more dimensions receive the Muon update (momentum,
    then Newton-Schulz orthogonalization, with dimensions after the first
    flattened into columns). All other parameters are updated with an
    AdamW-style rule using the ``adamw_*`` hyper-parameters.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate for the Muon update.
        momentum: Momentum coefficient for the Muon buffer.
        weight_decay: Decoupled weight decay for Muon-updated parameters.
        nesterov: Use the Nesterov-style blend of gradient and buffer.
        ns_steps: Number of Newton-Schulz iterations.
        adamw_lr: Learning rate for the AdamW fallback; defaults to
            ``lr * 0.1``. It is stored separately in each group, so code that
            only modifies ``group["lr"]`` leaves it unchanged.
        adamw_betas: ``(beta1, beta2)`` for the AdamW fallback.
        adamw_eps: Denominator epsilon for the AdamW fallback.
        adamw_wd: Decoupled weight decay for the AdamW fallback.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative, ``momentum`` is
            outside ``[0, 1)`` or ``ns_steps`` is negative.
    """

    def __init__(
        self,
        params,
        lr: float = 2e-3,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
        adamw_lr: Optional[float] = None,
        adamw_betas: Tuple[float, float] = (0.9, 0.95),
        adamw_eps: float = 1e-8,
        adamw_wd: float = 0.0,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if ns_steps < 0:
            raise ValueError(f"Invalid ns_steps value: {ns_steps}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
            adamw_wd=adamw_wd,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is called with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: If a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError("Muon does not support sparse gradients")
                if p.ndim >= 2:
                    self._muon_update(p, grad, group)
                else:
                    self._adamw_update(p, grad, group)
        return loss

    def _muon_update(self, p, grad, group):
        """Apply the Muon update to ``p`` in place (``p.ndim >= 2``)."""
        if p.numel() == 0:
            # Nothing to update; also avoids a zero column count below.
            return

        lr = group["lr"]
        momentum = group["momentum"]
        wd = group["weight_decay"]
        state = self.state[p]

        if "momentum_buffer" not in state:
            state["momentum_buffer"] = torch.zeros_like(grad)
        buf = state["momentum_buffer"]
        buf.lerp_(grad, 1 - momentum)

        # Nesterov blends the fresh gradient with the updated buffer. The
        # out-of-place lerp leaves ``p.grad`` untouched.
        direction = grad.lerp(buf, momentum) if group["nesterov"] else buf

        # Treat tensors with more than two dims as (dim0, everything else).
        rows = direction.size(0)
        matrix = direction.reshape(rows, -1)
        cols = matrix.size(1)

        update = _zeropower_via_newtonschulz(matrix, steps=group["ns_steps"])
        # Shape-dependent step scaling, applied exactly once.
        scale = max(1, rows / cols) ** 0.5

        # Decoupled weight decay; the update above does not depend on ``p``.
        p.mul_(1 - lr * wd)
        p.add_(update.reshape(p.shape), alpha=-lr * scale)

    def _adamw_update(self, p, grad, group):
        """Apply a bias-corrected AdamW update to ``p`` in place."""
        lr = group["adamw_lr"]
        beta1, beta2 = group["adamw_betas"]
        eps = group["adamw_eps"]
        wd = group["adamw_wd"]
        state = self.state[p]

        if not state:
            state["step"] = 0
            state["exp_avg"] = torch.zeros_like(p)
            state["exp_avg_sq"] = torch.zeros_like(p)

        state["step"] += 1
        step = state["step"]
        exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

        # Exponential moving averages of the gradient and its square.
        exp_avg.lerp_(grad, 1 - beta1)
        exp_avg_sq.lerp_(grad.square(), 1 - beta2)

        bias1 = 1 - beta1**step
        bias2 = 1 - beta2**step

        # Decoupled weight decay, then the bias-corrected Adam step.
        p.mul_(1 - lr * wd)
        denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
        p.addcdiv_(exp_avg / bias1, denom, value=-lr)