feat: 新增 Muon 优化器
- 2D 参数用 Newton-Schulz 正交化 + Nesterov 动量更新 - 1D 参数用 AdamW 更新 - 支持 lr/momentum/weight_decay/ns_steps 配置
This commit is contained in:
parent
42a391f0fb
commit
10ebd7211f
|
|
@ -1,3 +1,4 @@
|
||||||
|
from astrai.trainer.optim import Muon
|
||||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
|
|
@ -9,6 +10,8 @@ from astrai.trainer.trainer import Trainer
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Main trainer
|
# Main trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
|
# Optimizer
|
||||||
|
"Muon",
|
||||||
# Strategy factory
|
# Strategy factory
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"BaseStrategy",
|
"BaseStrategy",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,113 @@
|
||||||
|
import torch
|
||||||
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
|
|
||||||
|
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||||
|
assert G.ndim == 2
|
||||||
|
X = G.bfloat16()
|
||||||
|
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||||
|
X = X / (X.norm() + 1e-7) * scale
|
||||||
|
if steps == 0:
|
||||||
|
return X.type_as(G)
|
||||||
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
|
for _ in range(steps):
|
||||||
|
A = X @ X.T
|
||||||
|
B = A @ X
|
||||||
|
X = a * X + b * B + c * (A @ B)
|
||||||
|
return X.type_as(G)
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
params,
|
||||||
|
lr: float = 2e-3,
|
||||||
|
momentum: float = 0.95,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
nesterov: bool = True,
|
||||||
|
ns_steps: int = 5,
|
||||||
|
adamw_lr: float = None,
|
||||||
|
adamw_betas: tuple = (0.9, 0.95),
|
||||||
|
adamw_eps: float = 1e-8,
|
||||||
|
adamw_wd: float = 0.0,
|
||||||
|
):
|
||||||
|
defaults = dict(
|
||||||
|
lr=lr,
|
||||||
|
momentum=momentum,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
nesterov=nesterov,
|
||||||
|
ns_steps=ns_steps,
|
||||||
|
adamw_lr=adamw_lr if adamw_lr is not None else lr * 0.1,
|
||||||
|
adamw_betas=adamw_betas,
|
||||||
|
adamw_eps=adamw_eps,
|
||||||
|
adamw_wd=adamw_wd,
|
||||||
|
)
|
||||||
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def step(self, closure=None):
|
||||||
|
loss = None
|
||||||
|
if closure is not None:
|
||||||
|
with torch.enable_grad():
|
||||||
|
loss = closure()
|
||||||
|
for group in self.param_groups:
|
||||||
|
for p in group["params"]:
|
||||||
|
if p.grad is None:
|
||||||
|
continue
|
||||||
|
grad = p.grad
|
||||||
|
if grad.is_sparse:
|
||||||
|
raise RuntimeError("Muon does not support sparse gradients")
|
||||||
|
if p.ndim >= 2:
|
||||||
|
self._muon_update(p, grad, group)
|
||||||
|
else:
|
||||||
|
self._adamw_update(p, grad, group)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def _muon_update(self, p, grad, group):
|
||||||
|
lr = group["lr"]
|
||||||
|
momentum = group["momentum"]
|
||||||
|
wd = group["weight_decay"]
|
||||||
|
nesterov = group["nesterov"]
|
||||||
|
ns_steps = group["ns_steps"]
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
p.mul_(1 - lr * wd)
|
||||||
|
|
||||||
|
if nesterov:
|
||||||
|
grad = grad.add(p, alpha=wd)
|
||||||
|
|
||||||
|
if "momentum_buffer" not in state:
|
||||||
|
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||||
|
buf = state["momentum_buffer"]
|
||||||
|
buf.lerp_(grad, 1 - momentum)
|
||||||
|
|
||||||
|
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||||
|
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||||
|
p.add_(update, alpha=-lr * scale)
|
||||||
|
|
||||||
|
def _adamw_update(self, p, grad, group):
|
||||||
|
lr = group["adamw_lr"]
|
||||||
|
betas = group["adamw_betas"]
|
||||||
|
eps = group["adamw_eps"]
|
||||||
|
wd = group["adamw_wd"]
|
||||||
|
state = self.state[p]
|
||||||
|
|
||||||
|
if not state:
|
||||||
|
state["step"] = 0
|
||||||
|
state["exp_avg"] = torch.zeros_like(p)
|
||||||
|
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||||
|
|
||||||
|
state["step"] += 1
|
||||||
|
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||||
|
beta1, beta2 = betas
|
||||||
|
|
||||||
|
exp_avg.lerp_(grad, 1 - beta1)
|
||||||
|
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||||
|
|
||||||
|
step = state["step"]
|
||||||
|
bias1 = 1 - beta1**step
|
||||||
|
bias2 = 1 - beta2**step
|
||||||
|
|
||||||
|
p.mul_(1 - lr * wd)
|
||||||
|
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||||
|
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||||
Loading…
Reference in New Issue