Compare commits

..

No commits in common. "0594ce101727988ede827f398de0b5729eeed922" and "785d65436c53de0fcbc28416bad6a7a4b9d8607f" have entirely different histories.

3 changed files with 35 additions and 76 deletions

View File

@ -156,15 +156,11 @@ class InferenceScheduler:
t.output_ids.append(ntok) t.output_ids.append(ntok)
t.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens pos = t.input_tokens + t.output_tokens
extend_ok = self._page_cache.task_extend(t.task_id, pos) self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback: if t.stream_callback:
t.stream_callback( t.stream_callback(
self._task_mgr.tokenizer.decode([ntok]) self._task_mgr.tokenizer.decode([ntok])
) )
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid: for t in valid:
if t.is_finished(stop_ids): if t.is_finished(stop_ids):
@ -177,9 +173,6 @@ class InferenceScheduler:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise raise

View File

@ -193,10 +193,6 @@ class TaskManager:
with self._lock: with self._lock:
return list(self.active_tasks) return list(self.active_tasks)
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self) -> None: def clear_queues(self) -> None:
with self._lock: with self._lock:
self.waiting_queue.clear() self.waiting_queue.clear()

View File

@ -4,17 +4,17 @@ from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5): def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2 assert G.ndim == 2
X = G X = G.bfloat16()
scale = max(1, G.size(0) / G.size(1)) ** 0.5 scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale X = X / (X.norm() + 1e-7) * scale
if steps == 0: if steps == 0:
return X return X.type_as(G)
a, b, c = (3.4445, -4.7750, 2.0315) a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps): for _ in range(steps):
A = X @ X.T A = X @ X.T
B = A @ X B = A @ X
X = a * X + b * B + c * (A @ B) X = a * X + b * B + c * (A @ B)
return X return X.type_as(G)
class Muon(Optimizer): class Muon(Optimizer):
@ -50,94 +50,64 @@ class Muon(Optimizer):
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
loss = closure() loss = closure()
for group in self.param_groups: for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
if p.grad.is_sparse: grad = p.grad
if grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients") raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2: if p.ndim >= 2:
params_2d.append(p) self._muon_update(p, grad, group)
grads_2d.append(p.grad)
else: else:
params_1d.append(p) self._adamw_update(p, grad, group)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss return loss
def _muon_update_foreach(self, params_2d, grads_2d, group): def _muon_update(self, p, grad, group):
lr = group["lr"] lr = group["lr"]
momentum = group["momentum"] momentum = group["momentum"]
wd = group["weight_decay"] wd = group["weight_decay"]
nesterov = group["nesterov"] nesterov = group["nesterov"]
ns_steps = group["ns_steps"] ns_steps = group["ns_steps"]
state = self.state[p]
if wd != 0: p.mul_(1 - lr * wd)
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov: if nesterov:
grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd) grad = grad.add(p, alpha=wd)
bufs = [] if "momentum_buffer" not in state:
for p, grad in zip(params_2d, grads_2d): state["momentum_buffer"] = torch.zeros_like(grad)
state = self.state[p] buf = state["momentum_buffer"]
if "momentum_buffer" not in state: buf.lerp_(grad, 1 - momentum)
state["momentum_buffer"] = torch.zeros_like(grad)
bufs.append(state["momentum_buffer"])
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum) update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
for p, buf in zip(params_2d, bufs): def _adamw_update(self, p, grad, group):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale)
def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"] lr = group["adamw_lr"]
betas = group["adamw_betas"] betas = group["adamw_betas"]
eps = group["adamw_eps"] eps = group["adamw_eps"]
wd = group["adamw_wd"] wd = group["adamw_wd"]
state = self.state[p]
steps: list[int] = [] if not state:
exp_avgs, exp_avg_sqs = [], [] state["step"] = 0
has_state = [] state["exp_avg"] = torch.zeros_like(p)
for p in params_1d: state["exp_avg_sq"] = torch.zeros_like(p)
state = self.state[p]
if not state:
state["step"] = 0
state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1
steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = betas beta1, beta2 = betas
torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1) exp_avg.lerp_(grad, 1 - beta1)
grads_sq = torch._foreach_mul(grads_1d, grads_1d) exp_avg_sq.lerp_(grad.square(), 1 - beta2)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
bias_correction1 = [1 - beta1**s for s in steps] step = state["step"]
bias_correction2 = [1 - beta2**s for s in steps] bias1 = 1 - beta1**step
bias2 = 1 - beta2**step
if wd != 0: p.mul_(1 - lr * wd)
torch._foreach_mul_(params_1d, 1 - lr * wd) denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)