144 lines
4.5 KiB
Python
144 lines
4.5 KiB
Python
import torch
|
|
from torch.optim import Optimizer
|
|
|
|
|
|
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
|
assert G.ndim == 2
|
|
X = G
|
|
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
|
X = X / (X.norm() + 1e-7) * scale
|
|
if steps == 0:
|
|
return X
|
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
|
for _ in range(steps):
|
|
A = X @ X.T
|
|
B = A @ X
|
|
X = a * X + b * B + c * (A @ B)
|
|
return X
|
|
|
|
|
|
class Muon(Optimizer):
    """Optimizer that mixes an orthogonalized-momentum update with an Adam-style update.

    Within every parameter group, parameters with two or more dimensions
    receive the Muon update: a momentum buffer is updated, the (optionally
    Nesterov) direction is flattened to ``(size(0), -1)``, orthogonalized
    with Newton-Schulz iterations and applied with an aspect-ratio scale.
    All other parameters receive a bias-corrected Adam update with decoupled
    weight decay.

    Args:
        params: Iterable of parameters or parameter-group dicts.
        lr: Learning rate of the Muon update.
        momentum: Interpolation factor that keeps the previous momentum
            buffer (the gradient is blended in with weight ``1 - momentum``).
        weight_decay: Decoupled weight decay for the Muon update.
        nesterov: If true, the update direction is the gradient interpolated
            towards the freshly updated momentum buffer; otherwise the buffer
            itself is used.
        ns_steps: Number of Newton-Schulz iterations.
        adamw_lr: Learning rate of the Adam-style update; defaults to
            ``lr * 0.1`` when ``None``.
        adamw_betas: ``(beta1, beta2)`` averaging coefficients.
        adamw_eps: Constant added to the Adam denominator.
        adamw_wd: Decoupled weight decay for the Adam-style update.

    Raises:
        ValueError: If ``lr`` is negative, ``momentum`` is outside
            ``[0, 1)`` or ``ns_steps`` is negative.
    """

    def __init__(
        self,
        params,
        lr: float = 2e-3,
        momentum: float = 0.95,
        weight_decay: float = 0.0,
        nesterov: bool = True,
        ns_steps: int = 5,
        adamw_lr: "float | None" = None,
        adamw_betas: tuple = (0.9, 0.95),
        adamw_eps: float = 1e-8,
        adamw_wd: float = 0.0,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if ns_steps < 0:
            raise ValueError(f"Invalid ns_steps value: {ns_steps}")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=nesterov,
            ns_steps=ns_steps,
            adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
            adamw_betas=adamw_betas,
            adamw_eps=adamw_eps,
            adamw_wd=adamw_wd,
        )
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without one.

        Raises:
            RuntimeError: If any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_2d, params_1d = [], []
            grads_2d, grads_1d = [], []

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    raise RuntimeError("Muon does not support sparse gradients")
                # Matrices and higher-rank tensors take the Muon path;
                # vectors and scalars take the Adam path.
                if p.ndim >= 2:
                    params_2d.append(p)
                    grads_2d.append(grad)
                else:
                    params_1d.append(p)
                    grads_1d.append(grad)

            if params_2d:
                self._muon_update_foreach(params_2d, grads_2d, group)
            if params_1d:
                self._adamw_update_foreach(params_1d, grads_1d, group)

        return loss

    def _muon_update_foreach(self, params_2d, grads_2d, group):
        """Apply the Muon update to parameters with two or more dimensions.

        Args:
            params_2d: Parameters to update in place.
            grads_2d: Their gradients, in the same order.
            group: Parameter-group dict supplying ``lr``, ``momentum``,
                ``weight_decay``, ``nesterov`` and ``ns_steps``.
        """
        lr = group["lr"]
        momentum = group["momentum"]
        wd = group["weight_decay"]
        nesterov = group["nesterov"]
        ns_steps = group["ns_steps"]

        # Decoupled weight decay: shrink the weights directly and keep the
        # decay term out of the gradient that feeds the momentum buffer.
        if wd != 0:
            torch._foreach_mul_(params_2d, 1 - lr * wd)

        bufs = []
        for p, grad in zip(params_2d, grads_2d):
            state = self.state[p]
            if "momentum_buffer" not in state:
                state["momentum_buffer"] = torch.zeros_like(grad)
            bufs.append(state["momentum_buffer"])

        # buf <- momentum * buf + (1 - momentum) * grad
        torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)

        for p, grad, buf in zip(params_2d, grads_2d, bufs):
            # Nesterov look-ahead blends the current gradient with the
            # updated buffer; plain momentum uses the buffer as is.
            direction = torch.lerp(grad, buf, momentum) if nesterov else buf
            # Tensors with more than two dimensions (e.g. conv kernels) are
            # orthogonalized as (size(0), everything-else) matrices.
            matrix = direction.reshape(direction.size(0), -1)
            update = _zeropower_via_newtonschulz(matrix, steps=ns_steps)
            scale = max(1, matrix.size(0) / matrix.size(1)) ** 0.5
            p.add_(update.reshape(p.shape), alpha=-lr * scale)

    def _adamw_update_foreach(self, params_1d, grads_1d, group):
        """Apply a bias-corrected Adam update with decoupled weight decay.

        Args:
            params_1d: Parameters to update in place.
            grads_1d: Their gradients, in the same order.
            group: Parameter-group dict supplying ``adamw_lr``,
                ``adamw_betas``, ``adamw_eps`` and ``adamw_wd``.
        """
        lr = group["adamw_lr"]
        beta1, beta2 = group["adamw_betas"]
        eps = group["adamw_eps"]
        wd = group["adamw_wd"]

        steps: list[int] = []
        exp_avgs, exp_avg_sqs = [], []
        for p in params_1d:
            state = self.state[p]
            if not state:
                state["step"] = 0
                state["exp_avg"] = torch.zeros_like(p)
                state["exp_avg_sq"] = torch.zeros_like(p)
            # Count this update before computing the bias corrections so the
            # very first step uses 1 - beta**1 rather than a zero divisor.
            state["step"] += 1
            steps.append(state["step"])
            exp_avgs.append(state["exp_avg"])
            exp_avg_sqs.append(state["exp_avg_sq"])

        # Exponential moving averages of the gradient and its square.
        torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
        grads_sq = torch._foreach_mul(grads_1d, grads_1d)
        torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)

        bias_correction1 = [1 - beta1**s for s in steps]
        bias_correction2 = [1 - beta2**s for s in steps]

        # Decoupled weight decay, applied before the Adam step.
        if wd != 0:
            torch._foreach_mul_(params_1d, 1 - lr * wd)

        exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
        denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
        denom = torch._foreach_sqrt(denom)
        torch._foreach_add_(denom, eps)
        torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|