Compare commits
3 Commits
513f1f7826
...
9096e413c3
| Author | SHA1 | Date |
|---|---|---|
|
|
9096e413c3 | |
|
|
9d5e9fa6c4 | |
|
|
08dde46778 |
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -25,11 +25,13 @@ def get_rotary_emb(
|
||||||
max_len: int,
|
max_len: int,
|
||||||
base: float = 10000,
|
base: float = 10000,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
) -> Tuple[Tensor, Tensor]:
|
) -> Tensor:
|
||||||
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
theta = base ** (-torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim)
|
||||||
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
t = torch.arange(0, max_len, dtype=torch.float64, device=device)
|
||||||
freqs = torch.outer(t, theta)
|
freqs = torch.outer(t, theta).float()
|
||||||
return torch.cos(freqs).float(), torch.sin(freqs).float()
|
cos = torch.cos(freqs)
|
||||||
|
sin = torch.sin(freqs)
|
||||||
|
return torch.complex(cos, sin)
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
@ -50,10 +52,10 @@ class RotaryEmbedding(nn.Module):
|
||||||
self.base = base
|
self.base = base
|
||||||
self._set_rotary_buffer(self.max_len)
|
self._set_rotary_buffer(self.max_len)
|
||||||
|
|
||||||
def _set_rotary_buffer(self, max_len: int, device: Optional[torch.device] = None):
|
def _set_rotary_buffer(self, max_len: int):
|
||||||
cos_cached, sin_cached = get_rotary_emb(self.dim, max_len, self.base, device)
|
rotary_emb = get_rotary_emb(self.dim, max_len, self.base)
|
||||||
self.register_buffer("cos_cached", cos_cached, persistent=False)
|
freqs_cis = torch.view_as_real(rotary_emb)
|
||||||
self.register_buffer("sin_cached", sin_cached, persistent=False)
|
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
|
||||||
|
|
||||||
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
def forward(self, x: Tensor, position_ids: Optional[Tensor] = None) -> Tensor:
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
|
|
@ -62,9 +64,8 @@ class RotaryEmbedding(nn.Module):
|
||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.expand(x.size(0), -1)
|
.expand(x.size(0), -1)
|
||||||
)
|
)
|
||||||
cos = self.cos_cached[position_ids].float()
|
position_freq_cis = self.freqs_cis[position_ids].float()
|
||||||
sin = self.sin_cached[position_ids].float()
|
return torch.view_as_complex(position_freq_cis)
|
||||||
return torch.complex(cos, sin)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1,75 +1,42 @@
|
||||||
from typing import Dict
|
from typing import Any, Callable, Dict
|
||||||
|
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
def _grad_stat(
|
||||||
"""Compute gradient norm for each parameter in the model."""
|
model: nn.Module, fn: Callable[[torch.Tensor], Any], default: Any
|
||||||
norms = {}
|
) -> dict:
|
||||||
|
results = {}
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
norms[name] = 0.0
|
results[name] = default
|
||||||
if param.grad:
|
if param.grad is not None:
|
||||||
norm = param.grad.data.norm(norm_type).item()
|
results[name] = fn(param.grad.data)
|
||||||
norms[name] = norm
|
return results
|
||||||
return norms
|
|
||||||
|
|
||||||
|
def grad_norm(model: nn.Module, norm_type: int = 2) -> Dict[str, float]:
|
||||||
|
return _grad_stat(model, lambda g: g.norm(norm_type).item(), 0.0)
|
||||||
|
|
||||||
|
|
||||||
def grad_std(model: nn.Module) -> Dict[str, float]:
|
def grad_std(model: nn.Module) -> Dict[str, float]:
|
||||||
"""Compute standard deviation of gradients for each parameter."""
|
return _grad_stat(model, lambda g: g.std().item(), 0.0)
|
||||||
stds = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
stds[name] = 0.0
|
|
||||||
if param.grad:
|
|
||||||
std = param.grad.data.std().item()
|
|
||||||
stds[name] = std
|
|
||||||
return stds
|
|
||||||
|
|
||||||
|
|
||||||
def grad_max(model: nn.Module) -> Dict[str, float]:
|
def grad_max(model: nn.Module) -> Dict[str, float]:
|
||||||
"""Find the maximum absolute gradient value for each parameter."""
|
return _grad_stat(model, lambda g: g.max().item(), -float("inf"))
|
||||||
max_vals = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
max_vals[name] = -float("inf")
|
|
||||||
if param.grad:
|
|
||||||
max_val = param.grad.data.max().item()
|
|
||||||
max_vals[name] = max_val
|
|
||||||
|
|
||||||
return max_vals
|
|
||||||
|
|
||||||
|
|
||||||
def grad_min(model: nn.Module) -> Dict[str, float]:
|
def grad_min(model: nn.Module) -> Dict[str, float]:
|
||||||
"""Find the minimum absolute gradient value for each parameter."""
|
return _grad_stat(model, lambda g: g.min().item(), float("inf"))
|
||||||
min_vals = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
min_vals[name] = float("inf")
|
|
||||||
if param.grad:
|
|
||||||
min_val = param.grad.data.min().item()
|
|
||||||
min_vals[name] = min_val
|
|
||||||
|
|
||||||
return min_vals
|
|
||||||
|
|
||||||
|
|
||||||
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
def grad_mean(model: nn.Module) -> Dict[str, float]:
|
||||||
"""Compute mean of gradients for each parameter."""
|
return _grad_stat(model, lambda g: g.mean().item(), 0.0)
|
||||||
means = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
means[name] = 0.0
|
|
||||||
if param.grad:
|
|
||||||
mean = param.grad.data.mean().item()
|
|
||||||
means[name] = mean
|
|
||||||
|
|
||||||
return means
|
|
||||||
|
|
||||||
|
|
||||||
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
def grad_nan_num(model: nn.Module) -> Dict[str, int]:
|
||||||
"""Count the number of NaNs in gradients for each parameter."""
|
return _grad_stat(model, lambda g: g.isnan().sum().item(), 0)
|
||||||
nan_nums = {}
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
nan_nums[name] = 0
|
|
||||||
if param.grad:
|
|
||||||
nan_num = param.grad.isnan().sum().item()
|
|
||||||
nan_nums[name] = nan_num
|
|
||||||
return nan_nums
|
|
||||||
|
|
||||||
|
|
||||||
def ctx_get_loss(ctx):
|
def ctx_get_loss(ctx):
|
||||||
|
|
|
||||||
|
|
@ -79,30 +79,11 @@ class GradientClippingCallback(TrainCallback):
|
||||||
def __init__(self, max_grad_norm: float):
|
def __init__(self, max_grad_norm: float):
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
def on_step_begin(self, context: TrainContext):
|
def on_step_end(self, context: TrainContext):
|
||||||
_ = context
|
_ = context
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("scheduler")
|
|
||||||
class SchedulerCallback(TrainCallback):
|
|
||||||
"""
|
|
||||||
Scheduler callback for trainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_train_begin(self, context: TrainContext):
|
|
||||||
for group in context.optimizer.param_groups:
|
|
||||||
if "initial_lr" not in group:
|
|
||||||
group["initial_lr"] = group["lr"]
|
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
|
||||||
if context.scheduler:
|
|
||||||
context.scheduler.step()
|
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("checkpoint")
|
@CallbackFactory.register("checkpoint")
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
from itertools import batched
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from astrai.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
|
|
@ -30,7 +31,6 @@ class Trainer:
|
||||||
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
CallbackFactory.create("scheduler"),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
|
|
@ -62,32 +62,33 @@ class Trainer:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
context.model.train()
|
context.model.train()
|
||||||
# 1.epoch
|
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
||||||
|
|
||||||
for epoch in range(context.epoch, self.train_config.n_epoch):
|
for epoch in range(context.epoch, self.train_config.n_epoch):
|
||||||
context.epoch = epoch
|
context.epoch = epoch
|
||||||
self._call_callbacks("on_epoch_begin", context)
|
self._call_callbacks("on_epoch_begin", context)
|
||||||
|
|
||||||
accumulation_steps = max(self.train_config.accumulation_steps, 1)
|
for steps in batched(context.dataloader, accumulation_steps):
|
||||||
for batch in context.dataloader:
|
|
||||||
if context.iteration % accumulation_steps == 0:
|
|
||||||
# 2. step
|
|
||||||
self._call_callbacks("on_step_begin", context)
|
self._call_callbacks("on_step_begin", context)
|
||||||
context.optimizer.step()
|
|
||||||
context.optimizer.zero_grad()
|
|
||||||
self._call_callbacks("on_step_end", context)
|
|
||||||
|
|
||||||
# 3. batch
|
step_batch_nums = len(steps)
|
||||||
|
for batch in steps:
|
||||||
self._call_callbacks("on_batch_begin", context)
|
self._call_callbacks("on_batch_begin", context)
|
||||||
loss = context.strategy(batch)
|
loss = context.strategy(batch)
|
||||||
context.loss = loss.item()
|
context.loss = loss.item()
|
||||||
context.iteration += 1
|
context.iteration += 1
|
||||||
|
|
||||||
# to make the loss normalized by accumulation steps
|
stand_loss = loss / step_batch_nums
|
||||||
stand_loss = loss / accumulation_steps
|
|
||||||
stand_loss.backward()
|
stand_loss.backward()
|
||||||
|
|
||||||
self._call_callbacks("on_batch_end", context)
|
self._call_callbacks("on_batch_end", context)
|
||||||
|
|
||||||
|
self._call_callbacks("on_step_end", context)
|
||||||
|
context.optimizer.step()
|
||||||
|
context.optimizer.zero_grad()
|
||||||
|
|
||||||
|
if context.scheduler:
|
||||||
|
context.scheduler.step()
|
||||||
|
|
||||||
self._call_callbacks("on_epoch_end", context)
|
self._call_callbacks("on_epoch_end", context)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -155,18 +155,20 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(dtype=torch.bfloat16)
|
|
||||||
ddp_model = DDP(
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
output_device=local_rank,
|
output_device=local_rank,
|
||||||
|
static_graph=True,
|
||||||
find_unused_parameters=False,
|
find_unused_parameters=False,
|
||||||
|
gradient_as_bucket_view=True,
|
||||||
|
broadcast_buffers=False,
|
||||||
)
|
)
|
||||||
return ddp_model
|
return ddp_model
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
def create_optimizer(model: nn.Module, **kwargs) -> optim.Optimizer:
|
||||||
return optim.AdamW(model.parameters(), **kwargs)
|
return optim.AdamW(model.parameters(), fused=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def create_scheduler(
|
def create_scheduler(
|
||||||
|
|
@ -231,6 +233,8 @@ def train(
|
||||||
state_dict = st.load_file(weights_path)
|
state_dict = st.load_file(weights_path)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
|
model = model.to(dtype=torch.bfloat16)
|
||||||
|
|
||||||
strategy_kwargs = {
|
strategy_kwargs = {
|
||||||
"dpo_beta": dpo_beta,
|
"dpo_beta": dpo_beta,
|
||||||
"label_smoothing": label_smoothing,
|
"label_smoothing": label_smoothing,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue