Compare commits
2 Commits
785d65436c
...
0594ce1017
| Author | SHA1 | Date |
|---|---|---|
|
|
0594ce1017 | |
|
|
ff509ff39f |
|
|
@ -156,11 +156,15 @@ class InferenceScheduler:
|
|||
t.output_ids.append(ntok)
|
||||
t.output_tokens += 1
|
||||
pos = t.input_tokens + t.output_tokens
|
||||
self._page_cache.task_extend(t.task_id, pos)
|
||||
extend_ok = self._page_cache.task_extend(t.task_id, pos)
|
||||
if t.stream_callback:
|
||||
t.stream_callback(
|
||||
self._task_mgr.tokenizer.decode([ntok])
|
||||
)
|
||||
if not extend_ok:
|
||||
t.status = TaskStatus.ABORTED
|
||||
if t.stream_callback:
|
||||
t.stream_callback(STOP)
|
||||
|
||||
for t in valid:
|
||||
if t.is_finished(stop_ids):
|
||||
|
|
@ -173,6 +177,9 @@ class InferenceScheduler:
|
|||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._page_cache.task_free(task.task_id)
|
||||
for task in self._task_mgr.get_waiting_tasks():
|
||||
if task.stream_callback:
|
||||
task.stream_callback(STOP)
|
||||
self._task_mgr.clear_queues()
|
||||
raise
|
||||
|
||||
|
|
|
|||
|
|
@ -193,6 +193,10 @@ class TaskManager:
|
|||
with self._lock:
|
||||
return list(self.active_tasks)
|
||||
|
||||
def get_waiting_tasks(self) -> List[Task]:
|
||||
with self._lock:
|
||||
return list(self.waiting_queue)
|
||||
|
||||
def clear_queues(self) -> None:
|
||||
with self._lock:
|
||||
self.waiting_queue.clear()
|
||||
|
|
|
|||
|
|
@ -4,17 +4,17 @@ from torch.optim import Optimizer
|
|||
|
||||
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
|
||||
assert G.ndim == 2
|
||||
X = G.bfloat16()
|
||||
X = G
|
||||
scale = max(1, G.size(0) / G.size(1)) ** 0.5
|
||||
X = X / (X.norm() + 1e-7) * scale
|
||||
if steps == 0:
|
||||
return X.type_as(G)
|
||||
return X
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = A @ X
|
||||
X = a * X + b * B + c * (A @ B)
|
||||
return X.type_as(G)
|
||||
return X
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
|
|
@ -50,64 +50,94 @@ class Muon(Optimizer):
|
|||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_2d, params_1d = [], []
|
||||
grads_2d, grads_1d = [], []
|
||||
|
||||
for p in group["params"]:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError("Muon does not support sparse gradients")
|
||||
if p.ndim >= 2:
|
||||
self._muon_update(p, grad, group)
|
||||
params_2d.append(p)
|
||||
grads_2d.append(p.grad)
|
||||
else:
|
||||
self._adamw_update(p, grad, group)
|
||||
params_1d.append(p)
|
||||
grads_1d.append(p.grad)
|
||||
|
||||
if params_2d:
|
||||
self._muon_update_foreach(params_2d, grads_2d, group)
|
||||
if params_1d:
|
||||
self._adamw_update_foreach(params_1d, grads_1d, group)
|
||||
|
||||
return loss
|
||||
|
||||
def _muon_update(self, p, grad, group):
|
||||
def _muon_update_foreach(self, params_2d, grads_2d, group):
|
||||
lr = group["lr"]
|
||||
momentum = group["momentum"]
|
||||
wd = group["weight_decay"]
|
||||
nesterov = group["nesterov"]
|
||||
ns_steps = group["ns_steps"]
|
||||
state = self.state[p]
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_2d, 1 - lr * wd)
|
||||
|
||||
if nesterov:
|
||||
grad = grad.add(p, alpha=wd)
|
||||
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
|
||||
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
buf = state["momentum_buffer"]
|
||||
buf.lerp_(grad, 1 - momentum)
|
||||
bufs = []
|
||||
for p, grad in zip(params_2d, grads_2d):
|
||||
state = self.state[p]
|
||||
if "momentum_buffer" not in state:
|
||||
state["momentum_buffer"] = torch.zeros_like(grad)
|
||||
bufs.append(state["momentum_buffer"])
|
||||
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
|
||||
|
||||
def _adamw_update(self, p, grad, group):
|
||||
for p, buf in zip(params_2d, bufs):
|
||||
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
|
||||
scale = max(1, p.size(0) / p.size(1)) ** 0.5
|
||||
p.add_(update, alpha=-lr * scale)
|
||||
|
||||
def _adamw_update_foreach(self, params_1d, grads_1d, group):
|
||||
lr = group["adamw_lr"]
|
||||
betas = group["adamw_betas"]
|
||||
eps = group["adamw_eps"]
|
||||
wd = group["adamw_wd"]
|
||||
state = self.state[p]
|
||||
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
steps: list[int] = []
|
||||
exp_avgs, exp_avg_sqs = [], []
|
||||
has_state = []
|
||||
for p in params_1d:
|
||||
state = self.state[p]
|
||||
if not state:
|
||||
state["step"] = 0
|
||||
state["exp_avg"] = torch.zeros_like(p)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p)
|
||||
has_state.append(False)
|
||||
else:
|
||||
has_state.append(True)
|
||||
state["step"] += 1
|
||||
steps.append(state["step"])
|
||||
exp_avgs.append(state["exp_avg"])
|
||||
exp_avg_sqs.append(state["exp_avg_sq"])
|
||||
|
||||
state["step"] += 1
|
||||
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
|
||||
beta1, beta2 = betas
|
||||
|
||||
exp_avg.lerp_(grad, 1 - beta1)
|
||||
exp_avg_sq.lerp_(grad.square(), 1 - beta2)
|
||||
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
|
||||
grads_sq = torch._foreach_mul(grads_1d, grads_1d)
|
||||
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
|
||||
|
||||
step = state["step"]
|
||||
bias1 = 1 - beta1**step
|
||||
bias2 = 1 - beta2**step
|
||||
bias_correction1 = [1 - beta1**s for s in steps]
|
||||
bias_correction2 = [1 - beta2**s for s in steps]
|
||||
|
||||
p.mul_(1 - lr * wd)
|
||||
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
|
||||
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
|
||||
if wd != 0:
|
||||
torch._foreach_mul_(params_1d, 1 - lr * wd)
|
||||
|
||||
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
|
||||
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
|
||||
denom = torch._foreach_sqrt(denom)
|
||||
torch._foreach_add_(denom, eps)
|
||||
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)
|
||||
|
|
|
|||
Loading…
Reference in New Issue