Compare commits

..

No commits in common. "9096e413c3bb08edfbf42ed3d381f8fecb212edd" and "513f1f78269b66cdba550455276657557e394f40" have entirely different histories.

5 changed files with 104 additions and 58 deletions

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -25,13 +25,11 @@ def get_rotary_emb(
max_len: int, max_len: int,
base: float = 10000, base: float = 10000,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
) -> Tensor: ) -> Tuple[Tensor, Tensor]:
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim) theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
t = torch.arange(0, max_len, dtype=torch.float64, device=device) t = torch.arange(0, max_len, dtype=torch.float64, device=device)
freqs = torch.outer(t, theta).float() freqs = torch.outer(t, theta)
cos = torch.cos(freqs) return torch.cos(freqs).float(), torch.sin(freqs).float()
sin = torch.sin(freqs)
return torch.complex(cos, sin)
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
@ -52,10 +50,10 @@ class RotaryEmbedding(nn.Module):
self.base = base self.base = base
self._set_rotary_buffer(self.max_len) self._set_rotary_buffer(self.max_len)
def _set_rotary_buffer(self, max_len: int): def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
rotary_emb = get_rotary_emb(self.dim, max_len, self.base) cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
freqs_cis = torch.view_as_real(rotary_emb) self.register_buffer("cos_cached", cos_cached, persistent=False)
self.register_buffer("freqs_cis", freqs_cis, persistent=False) self.register_buffer("sin_cached", sin_cached, persistent=False)
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor: def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
if position_ids is None: if position_ids is None:
@ -64,8 +62,9 @@ class RotaryEmbedding(nn.Module):
.unsqueeze(0) .unsqueeze(0)
.expand(x.size(0), -1) .expand(x.size(0), -1)
) )
position_freq_cis = self.freqs_cis[position_ids].float() cos = self.cos_cached[position_ids].float()
return torch.view_as_complex(position_freq_cis) sin = self.sin_cached[position_ids].float()
return torch.complex(cos, sin)
class Linear(nn.Module): class Linear(nn.Module):

View File

@ -1,42 +1,75 @@
from typing import Any, Callable, Dict from typing import Dict
import torch
import torch.nn as nn import torch.nn as nn
def _grad_stat(
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
) -> dict:
results = {}
for name, param in model.named_parameters():
results[name] = default
if param.grad is not None:
results[name] = fn(param.grad.data)
return results
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]: def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0) """Compute gradient norm for each parameter in the model."""
norms = {}
for name, param in model.named_parameters():
norms[name] = 0.0
if param.grad:
norm = param.grad.data.norm(norm_type).item()
norms[name] = norm
return norms
def grad_std(model: nn.Module) -> Dict[str, float]: def grad_std(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.std().item(), 0.0) """Compute standard deviation of gradients for each parameter."""
stds = {}
for name, param in model.named_parameters():
stds[name] = 0.0
if param.grad:
std = param.grad.data.std().item()
stds[name] = std
return stds
def grad_max(model: nn.Module) -> Dict[str, float]: def grad_max(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.max().item(), -float("inf")) """Find the maximum absolute gradient value for each parameter."""
max_vals = {}
for name, param in model.named_parameters():
max_vals[name] = -float("inf")
if param.grad:
max_val = param.grad.data.max().item()
max_vals[name] = max_val
return max_vals
def grad_min(model: nn.Module) -> Dict[str, float]: def grad_min(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.min().item(), float("inf")) """Find the minimum absolute gradient value for each parameter."""
min_vals = {}
for name, param in model.named_parameters():
min_vals[name] = float("inf")
if param.grad:
min_val = param.grad.data.min().item()
min_vals[name] = min_val
return min_vals
def grad_mean(model: nn.Module) -> Dict[str, float]: def grad_mean(model: nn.Module) -> Dict[str, float]:
return _grad_stat(model, lambda g: g.mean().item(), 0.0) """Compute mean of gradients for each parameter."""
means = {}
for name, param in model.named_parameters():
means[name] = 0.0
if param.grad:
mean = param.grad.data.mean().item()
means[name] = mean
return means
def grad_nan_num(model: nn.Module) -> Dict[str, int]: def grad_nan_num(model: nn.Module) -> Dict[str, int]:
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0) """Count the number of NaNs in gradients for each parameter."""
nan_nums = {}
for name, param in model.named_parameters():
nan_nums[name] = 0
if param.grad:
nan_num = param.grad.isnan().sum().item()
nan_nums[name] = nan_num
return nan_nums
def ctx_get_loss(ctx): def ctx_get_loss(ctx):

View File

@ -79,11 +79,30 @@ class GradientClippingCallback(TrainCallback):
def __init__(self, max_grad_norm: float): def __init__(self, max_grad_norm: float):
self.max_grad_norm = max_grad_norm self.max_grad_norm = max_grad_norm
def on_step_end(self, context: TrainContext): def on_step_begin(self, context: TrainContext):
_ = context _ = context
clip_grad_norm_(context.model.parameters(), self.max_grad_norm) clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
@CallbackFactory.register("scheduler")
class SchedulerCallback(TrainCallback):
"""
Scheduler callback for trainer.
"""
def __init__(self):
pass
def on_train_begin(self, context: TrainContext):
for group in context.optimizer.param_groups:
if "initial_lr" not in group:
group["initial_lr"] = group["lr"]
def on_batch_end(self, context: TrainContext):
if context.scheduler:
context.scheduler.step()
@CallbackFactory.register("checkpoint") @CallbackFactory.register("checkpoint")
class CheckpointCallback(TrainCallback): class CheckpointCallback(TrainCallback):
""" """

View File

@ -1,5 +1,4 @@
import logging import logging
from itertools import batched
from typing import List, Optional from typing import List, Optional
from astrai.config import TrainConfig from astrai.config import TrainConfig
@ -31,6 +30,7 @@ class Trainer:
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval), CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm), CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
CallbackFactory.create("scheduler"),
] ]
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext: def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
@ -62,32 +62,31 @@ class Trainer:
try: try:
context.model.train() context.model.train()
accumulation_steps = max(self.train_config.accumulation_steps, 1) # 1.epoch
for epoch in range(context.epoch, self.train_config.n_epoch): for epoch in range(context.epoch, self.train_config.n_epoch):
context.epoch = epoch context.epoch = epoch
self._call_callbacks("on_epoch_begin", context) self._call_callbacks("on_epoch_begin", context)
for steps in batched(context.dataloader, accumulation_steps): accumulation_steps = max(self.train_config.accumulation_steps, 1)
self._call_callbacks("on_step_begin", context) for batch in context.dataloader:
if context.iteration % accumulation_steps == 0:
# 2. step
self._call_callbacks("on_step_begin", context)
context.optimizer.step()
context.optimizer.zero_grad()
self._call_callbacks("on_step_end", context)
step_batch_nums = len(steps) # 3. batch
for batch in steps: self._call_callbacks("on_batch_begin", context)
self._call_callbacks("on_batch_begin", context) loss = context.strategy(batch)
loss = context.strategy(batch) context.loss = loss.item()
context.loss = loss.item() context.iteration += 1
context.iteration += 1
stand_loss = loss / step_batch_nums # to make the loss normalized by accumulation steps
stand_loss.backward() stand_loss = loss / accumulation_steps
self._call_callbacks("on_batch_end", context) stand_loss.backward()
self._call_callbacks("on_step_end", context) self._call_callbacks("on_batch_end", context)
context.optimizer.step()
context.optimizer.zero_grad()
if context.scheduler:
context.scheduler.step()
self._call_callbacks("on_epoch_end", context) self._call_callbacks("on_epoch_end", context)

View File

@ -155,20 +155,18 @@ def parse_args() -> argparse.Namespace:
def ddp_wrap(model: nn.Module): def ddp_wrap(model: nn.Module):
local_rank = get_rank() local_rank = get_rank()
model = model.to(dtype=torch.bfloat16)
ddp_model = DDP( ddp_model = DDP(
model, model,
device_ids=[local_rank], device_ids=[local_rank],
output_device=local_rank, output_device=local_rank,
static_graph=True,
find_unused_parameters=False, find_unused_parameters=False,
gradient_as_bucket_view=True,
broadcast_buffers=False,
) )
return ddp_model return ddp_model
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer: def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs) return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler( def create_scheduler(
@ -233,8 +231,6 @@ def train(
state_dict = st.load_file(weights_path) state_dict = st.load_file(weights_path)
model.load_state_dict(state_dict, strict=False) model.load_state_dict(state_dict, strict=False)
model = model.to(dtype=torch.bfloat16)
strategy_kwargs = { strategy_kwargs = {
"dpo_beta": dpo_beta, "dpo_beta": dpo_beta,
"label_smoothing": label_smoothing, "label_smoothing": label_smoothing,