feat: 新增 Muon 优化器
- 2D 参数用 Newton-Schulz 正交化 + Nesterov 动量更新 - 1D 参数用 AdamW 更新 - 支持 lr/momentum/weight_decay/ns_steps 配置
This commit is contained in:
parent
42a391f0fb
commit
10ebd7211f
|
|
@ -1,3 +1,4 @@
|
|||
from astrai.trainer.optim import Muon
|
||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
from astrai.trainer.train_callback import (
|
||||
|
|
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
|
|||
__all__ = [
|
||||
# Main trainer
|
||||
"Trainer",
|
||||
# Optimizer
|
||||
"Muon",
|
||||
# Strategy factory
|
||||
"StrategyFactory",
|
||||
"BaseStrategy",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||
assert G.ndim == 2
|
||||
X = G.bfloat16()
|
||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||
X = X / (X.norm() + 1e-7) * scale
|
||||
if steps == 0:
|
||||
return X.type_as(G)
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = A @ X
|
||||
X = a * X + b * B + c * (A @ B)
|
||||
return X.type_as(G)
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr: float = 2e-3,
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.0,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
adamw_lr: float = None,
|
||||
adamw_betas: tuple = (0.9, 0.95),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_wd: float = 0.0,
|
||||
):
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
momentum=momentum,
|
||||
weight_decay=weight_decay,
|
||||
nesterov=nesterov,
|
||||
ns_steps=ns_steps,
|
||||
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||
adamw_betas=adamw_betas,
|
||||
adamw_eps=adamw_eps,
|
||||
adamw_wd=adamw_wd,
|
||||
)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
for group in self.param_groups:
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError("Muon does not support sparse gradients")
|
||||
if p.ndim >= 2:
|
||||
self._muon_update(p, grad, group)
|
||||
else:
|
||||
self._adamw_update(p, grad, group)
|
||||
return loss
|
||||
|
||||
def _muon_update(self, p, grad, group):
|
||||
lr = group["lr"]
|
||||
momentum = group["momentum"]
|
||||
wd = group["weight_decay"]
|
||||
nesterov = group["nesterov"]
|
||||
ns_steps = group["ns_steps"]
|
||||
state = self.state[p]
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
|
||||
if nesterov:
|
||||
grad = grad.add(p, alpha=wd)
|
||||
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
buf = state["momentum_buffer"]
|
||||
buf.lerp_(grad, 1 - momentum)
|
||||
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
def _adamw_update(self, p, grad, group):
|
||||
lr = group["adamw_lr"]
|
||||
betas = group["adamw_betas"]
|
||||
eps = group["adamw_eps"]
|
||||
wd = group["adamw_wd"]
|
||||
state = self.state[p]
|
||||
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
|
||||
state["step"] += 1
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = betas
|
||||
|
||||
exp_avg.lerp_(grad, 1 - beta1)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||
|
||||
step = state["step"]
|
||||
bias1 = 1 - beta1**step
|
||||
bias2 = 1 - beta2**step
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||
Loading…
Reference in New Issue