Compare commits

..

2 Commits

3 changed files with 76 additions and 35 deletions

View File

@ -156,11 +156,15 @@ class InferenceScheduler:
t.output_ids.append(ntok) t.output_ids.append(ntok)
t.output_tokens += 1 t.output_tokens += 1
pos = t.input_tokens + t.output_tokens pos = t.input_tokens + t.output_tokens
self._page_cache.task_extend(t.task_id, pos) extend_ok = self._page_cache.task_extend(t.task_id, pos)
if t.stream_callback: if t.stream_callback:
t.stream_callback( t.stream_callback(
self._task_mgr.tokenizer.decode([ntok]) self._task_mgr.tokenizer.decode([ntok])
) )
if not extend_ok:
t.status = TaskStatus.ABORTED
if t.stream_callback:
t.stream_callback(STOP)
for t in valid: for t in valid:
if t.is_finished(stop_ids): if t.is_finished(stop_ids):
@ -173,6 +177,9 @@ class InferenceScheduler:
if task.stream_callback: if task.stream_callback:
task.stream_callback(STOP) task.stream_callback(STOP)
self._page_cache.task_free(task.task_id) self._page_cache.task_free(task.task_id)
for task in self._task_mgr.get_waiting_tasks():
if task.stream_callback:
task.stream_callback(STOP)
self._task_mgr.clear_queues() self._task_mgr.clear_queues()
raise raise

View File

@ -193,6 +193,10 @@ class TaskManager:
with self._lock: with self._lock:
return list(self.active_tasks) return list(self.active_tasks)
def get_waiting_tasks(self) -> List[Task]:
with self._lock:
return list(self.waiting_queue)
def clear_queues(self) -> None: def clear_queues(self) -> None:
with self._lock: with self._lock:
self.waiting_queue.clear() self.waiting_queue.clear()

View File

@ -4,17 +4,17 @@ from torch.optim import Optimizer
def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5): def _zeropower_via_newtonschulz(G: torch.Tensor, steps: int = 5):
assert G.ndim == 2 assert G.ndim == 2
X = G.bfloat16() X = G
scale = max(1, G.size(0) / G.size(1)) ** 0.5 scale = max(1, G.size(0) / G.size(1)) ** 0.5
X = X / (X.norm() + 1e-7) * scale X = X / (X.norm() + 1e-7) * scale
if steps == 0: if steps == 0:
return X.type_as(G) return X
a, b, c = (3.4445, -4.7750, 2.0315) a, b, c = (3.4445, -4.7750, 2.0315)
for _ in range(steps): for _ in range(steps):
A = X @ X.T A = X @ X.T
B = A @ X B = A @ X
X = a * X + b * B + c * (A @ B) X = a * X + b * B + c * (A @ B)
return X.type_as(G) return X
class Muon(Optimizer): class Muon(Optimizer):
@ -50,64 +50,94 @@ class Muon(Optimizer):
if closure is not None: if closure is not None:
with torch.enable_grad(): with torch.enable_grad():
loss = closure() loss = closure()
for group in self.param_groups: for group in self.param_groups:
params_2d, params_1d = [], []
grads_2d, grads_1d = [], []
for p in group["params"]: for p in group["params"]:
if p.grad is None: if p.grad is None:
continue continue
grad = p.grad if p.grad.is_sparse:
if grad.is_sparse:
raise RuntimeError("Muon does not support sparse gradients") raise RuntimeError("Muon does not support sparse gradients")
if p.ndim >= 2: if p.ndim >= 2:
self._muon_update(p, grad, group) params_2d.append(p)
grads_2d.append(p.grad)
else: else:
self._adamw_update(p, grad, group) params_1d.append(p)
grads_1d.append(p.grad)
if params_2d:
self._muon_update_foreach(params_2d, grads_2d, group)
if params_1d:
self._adamw_update_foreach(params_1d, grads_1d, group)
return loss return loss
def _muon_update(self, p, grad, group): def _muon_update_foreach(self, params_2d, grads_2d, group):
lr = group["lr"] lr = group["lr"]
momentum = group["momentum"] momentum = group["momentum"]
wd = group["weight_decay"] wd = group["weight_decay"]
nesterov = group["nesterov"] nesterov = group["nesterov"]
ns_steps = group["ns_steps"] ns_steps = group["ns_steps"]
state = self.state[p]
p.mul_(1 - lr * wd) if wd != 0:
torch._foreach_mul_(params_2d, 1 - lr * wd)
if nesterov: if nesterov:
grad = grad.add(p, alpha=wd) grads_2d = torch._foreach_add(grads_2d, params_2d, alpha=wd)
bufs = []
for p, grad in zip(params_2d, grads_2d):
state = self.state[p]
if "momentum_buffer" not in state: if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(grad) state["momentum_buffer"] = torch.zeros_like(grad)
buf = state["momentum_buffer"] bufs.append(state["momentum_buffer"])
buf.lerp_(grad, 1 - momentum)
torch._foreach_lerp_(bufs, grads_2d, 1 - momentum)
for p, buf in zip(params_2d, bufs):
update = _zeropower_via_newtonschulz(buf, steps=ns_steps) update = _zeropower_via_newtonschulz(buf, steps=ns_steps)
scale = max(1, p.size(0) / p.size(1)) ** 0.5 scale = max(1, p.size(0) / p.size(1)) ** 0.5
p.add_(update, alpha=-lr * scale) p.add_(update, alpha=-lr * scale)
def _adamw_update(self, p, grad, group): def _adamw_update_foreach(self, params_1d, grads_1d, group):
lr = group["adamw_lr"] lr = group["adamw_lr"]
betas = group["adamw_betas"] betas = group["adamw_betas"]
eps = group["adamw_eps"] eps = group["adamw_eps"]
wd = group["adamw_wd"] wd = group["adamw_wd"]
state = self.state[p]
steps: list[int] = []
exp_avgs, exp_avg_sqs = [], []
has_state = []
for p in params_1d:
state = self.state[p]
if not state: if not state:
state["step"] = 0 state["step"] = 0
state["exp_avg"] = torch.zeros_like(p) state["exp_avg"] = torch.zeros_like(p)
state["exp_avg_sq"] = torch.zeros_like(p) state["exp_avg_sq"] = torch.zeros_like(p)
has_state.append(False)
else:
has_state.append(True)
state["step"] += 1 state["step"] += 1
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] steps.append(state["step"])
exp_avgs.append(state["exp_avg"])
exp_avg_sqs.append(state["exp_avg_sq"])
beta1, beta2 = betas beta1, beta2 = betas
exp_avg.lerp_(grad, 1 - beta1) torch._foreach_lerp_(exp_avgs, grads_1d, 1 - beta1)
exp_avg_sq.lerp_(grad.square(), 1 - beta2) grads_sq = torch._foreach_mul(grads_1d, grads_1d)
torch._foreach_lerp_(exp_avg_sqs, grads_sq, 1 - beta2)
step = state["step"] bias_correction1 = [1 - beta1**s for s in steps]
bias1 = 1 - beta1**step bias_correction2 = [1 - beta2**s for s in steps]
bias2 = 1 - beta2**step
p.mul_(1 - lr * wd) if wd != 0:
denom = exp_avg_sq.sqrt().div_(bias2**0.5).add_(eps) torch._foreach_mul_(params_1d, 1 - lr * wd)
p.addcdiv_(exp_avg / bias1, denom, value=-lr)
exp_avg_corrected = torch._foreach_div(exp_avgs, bias_correction1)
denom = torch._foreach_div(exp_avg_sqs, bias_correction2)
denom = torch._foreach_sqrt(denom)
torch._foreach_add_(denom, eps)
torch._foreach_addcdiv_(params_1d, exp_avg_corrected, denom, value=-lr)