AstrAI/scripts/tools/train.py

328 lines
8.8 KiB
Python

import argparse
import os
from functools import partial
import safetensors.torch as st
import torch
import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.trainer import SchedulerFactory, Trainer
def parse_args() -> argparse.Namespace:
    """Parse the training script's command-line options.

    The options are declared as a table of ``(flag, add_argument kwargs)``
    pairs and registered in that order, which is also the order in which
    they are listed by ``--help``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()`` for the
        current process arguments. ``--no_pin_memory`` is stored under the
        attribute ``pin_memory`` (``store_false``).
    """
    option_table = [
        # Required: what to train, on which data, starting from which params.
        (
            "--train_type",
            dict(
                type=str,
                required=True,
                choices=["seq", "sft", "dpo", "grpo"],
                help="Train type.",
            ),
        ),
        (
            "--data_root_path",
            dict(
                type=str,
                required=True,
                help="Path to the root directory of the dataset.",
            ),
        ),
        (
            "--param_path",
            dict(
                type=str,
                required=True,
                help="Path to the model parameters or resume checkpoint.",
            ),
        ),
        # Schedule and optimizer hyper-parameters.
        ("--n_epoch", dict(type=int, default=1, help="Number of epochs to train.")),
        ("--batch_per_device", dict(type=int, default=1, help="Batch size per GPU.")),
        (
            "--grad_accum_steps",
            dict(
                type=int,
                default=1,
                help="Number of iterations between each optimizer step.",
            ),
        ),
        (
            "--warmup_ratio",
            dict(
                type=float,
                default=0.05,
                help="Fraction of total steps used for LR warmup.",
            ),
        ),
        (
            "--max_lr",
            dict(type=float, default=3e-4, help="Max learning rate for training."),
        ),
        (
            "--max_grad_norm",
            dict(type=float, default=1.0, help="Max gradient norm for clipping."),
        ),
        (
            "--adamw_beta1",
            dict(type=float, default=0.9, help="Beta1 for AdamW optimizer."),
        ),
        (
            "--adamw_beta2",
            dict(type=float, default=0.95, help="Beta2 for AdamW optimizer."),
        ),
        (
            "--adamw_weight_decay",
            dict(type=float, default=0.01, help="Weight decay for AdamW optimizer."),
        ),
        # Reproducibility and data loading.
        (
            "--random_seed",
            dict(type=int, default=3407, help="Random seed for reproducibility."),
        ),
        (
            "--num_workers",
            dict(type=int, default=4, help="Number of workers for data loading."),
        ),
        (
            "--no_pin_memory",
            dict(action="store_false", dest="pin_memory", help="Disable pin memory"),
        ),
        (
            "--window_size",
            dict(type=int, default=None, help="Max length of the input sequence."),
        ),
        (
            "--stride",
            dict(type=int, default=None, help="Step size of the input sequence."),
        ),
        # Strategy-specific knobs.
        ("--dpo_beta", dict(type=float, default=0.1, help="DPO beta value.")),
        ("--group_size", dict(type=int, default=4, help="GRPO group size.")),
        (
            "--grpo_clip_eps",
            dict(type=float, default=0.2, help="GRPO clipping epsilon."),
        ),
        (
            "--grpo_kl_coef",
            dict(type=float, default=0.01, help="GRPO KL penalty coefficient."),
        ),
        (
            "--label_smoothing",
            dict(
                type=float,
                default=0.05,
                help="cross_entropy function label smoothing parameter",
            ),
        ),
        # Checkpointing and resuming.
        (
            "--ckpt_interval",
            dict(type=int, default=5000, help="Number of iters between checkpoints."),
        ),
        (
            "--ckpt_dir",
            dict(type=str, default="checkpoint", help="Directory to save checkpoints."),
        ),
        (
            "--grpo_sync_interval",
            dict(
                type=int,
                default=200,
                help="GRPO ref model sync interval (steps).",
            ),
        ),
        ("--start_epoch", dict(type=int, default=0, help="Start epoch for training.")),
        ("--start_batch", dict(type=int, default=0, help="Start batch for training.")),
        # Process / device layout.
        ("--nprocs", dict(type=int, default=1, help="Number of GPUs to use.")),
        (
            "--parallel_mode",
            dict(
                type=str,
                default="none",
                choices=["none", "ddp"],
                help="Parallel training strategy.",
            ),
        ),
        ("--device_type", dict(type=str, default="cuda", help="Device type to use.")),
        (
            "--start_method",
            dict(
                type=str,
                default="spawn",
                choices=["spawn", "fork", "forkserver"],
                help="Multiprocessing start method.",
            ),
        ),
    ]

    cli = argparse.ArgumentParser(description="Train the AutoRegressiveLM model.")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def create_optimizer(model, **kwargs) -> optim.Optimizer:
    """Build an AdamW optimizer over all parameters of ``model``.

    The fused implementation is requested by default, but it is only a
    default: callers may pass ``fused=False`` (or ``fused=None``) to opt out,
    e.g. on a device/PyTorch build where the fused kernel is unavailable.
    Previously ``fused=True`` was hard-coded next to ``**kwargs``, so passing
    ``fused`` raised ``TypeError`` for a duplicated keyword argument.

    Args:
        model: Object exposing ``parameters()`` (the parameters to optimize).
        **kwargs: Forwarded to ``torch.optim.AdamW`` (``lr``, ``betas``,
            ``weight_decay``, optionally ``fused``, ...).

    Returns:
        The constructed ``torch.optim.AdamW`` instance.
    """
    # ``kwargs`` is a fresh dict for every call, so mutating it is safe.
    kwargs.setdefault("fused", True)
    return optim.AdamW(model.parameters(), **kwargs)
def create_scheduler(
    optimizer: optim.Optimizer, **kwargs
) -> optim.lr_scheduler.LRScheduler:
    """Create a learning-rate scheduler for ``optimizer``.

    Thin wrapper that delegates to ``SchedulerFactory.create`` and forwards
    every keyword argument unchanged.

    Args:
        optimizer: Optimizer the scheduler is created for.
        **kwargs: Scheduler options passed straight to the factory.

    Returns:
        Whatever ``SchedulerFactory.create`` returns.
    """
    scheduler = SchedulerFactory.create(optimizer, **kwargs)
    return scheduler
def compute_total_steps(
    dataset_len: int,
    n_epoch: int,
    batch_per_device: int,
    nprocs: int,
    grad_accum_steps: int,
) -> int:
    """Compute the number of scheduler steps for the whole run.

    The dataset is split across ``nprocs`` replicas and then into batches of
    ``batch_per_device`` (both rounded up). Batches are grouped into
    accumulation groups of ``grad_accum_steps`` (rounded down, so a trailing
    partial group is not counted), and the per-epoch count is multiplied by
    ``n_epoch``.

    Args:
        dataset_len: Number of samples in the dataset.
        n_epoch: Number of epochs.
        batch_per_device: Batch size on each replica.
        nprocs: Number of replicas.
        grad_accum_steps: Batches per accumulation group.

    Returns:
        ``(batches per replica // grad_accum_steps) * n_epoch``.
    """
    # Ceiling division written inline as (a + b - 1) // b.
    per_replica_samples = (dataset_len + nprocs - 1) // nprocs
    per_replica_batches = (
        per_replica_samples + batch_per_device - 1
    ) // batch_per_device
    steps_per_epoch = per_replica_batches // grad_accum_steps
    return steps_per_epoch * n_epoch
def train(
    train_type: str,
    param_path: str,
    data_root_path: str,
    max_lr: float,
    n_epoch: int,
    batch_per_device: int,
    start_epoch: int,
    start_batch: int,
    grad_accum_steps: int,
    warmup_ratio: float,
    ckpt_interval: int,
    ckpt_dir: str,
    dpo_beta: float,
    grpo_clip_eps: float,
    grpo_kl_coef: float,
    group_size: int,
    grpo_sync_interval: int,
    adamw_beta1: float,
    adamw_beta2: float,
    adamw_weight_decay: float,
    max_grad_norm: float,
    label_smoothing: float,
    random_seed: int,
    num_workers: int,
    pin_memory: bool,
    window_size: int,
    stride: int,
    nprocs: int,
    parallel_mode: str,
    device_type: str,
    start_method: str,
):
    """Build the model, dataset, optimizer/scheduler factories and run training.

    Steps performed, in order:

    1. Validate ``train_type`` and ``param_path``.
    2. Read ``config.json`` from ``param_path`` and build an
       ``AutoRegressiveLM``; load ``model.safetensors`` from the same
       directory when that file exists; cast the model to ``bfloat16``.
    3. Load the dataset through ``DatasetFactory.load``.
    4. Prepare an AdamW factory and a cosine scheduler factory whose warmup
       length is ``warmup_ratio`` of the computed total steps.
    5. Assemble a ``TrainConfig`` and call ``Trainer(...).train()``.

    Args:
        train_type: One of ``"seq"``, ``"sft"``, ``"dpo"``, ``"grpo"``.
        param_path: Existing path holding ``config.json`` (and optionally
            ``model.safetensors``).
        data_root_path: Passed to ``DatasetFactory.load`` as ``load_path``.
        max_lr, adamw_beta1, adamw_beta2, adamw_weight_decay: AdamW settings.
        warmup_ratio: Fraction of total steps used as warmup.
        window_size: Dataset window size; ``None`` falls back to
            ``config.max_len``.
        stride: Dataset stride, forwarded as-is (may be ``None``).
        dpo_beta, label_smoothing, grpo_clip_eps, grpo_kl_coef, group_size,
        grpo_sync_interval: Bundled into the strategy ``extra_kwargs``.
        Remaining arguments are forwarded unchanged to ``TrainConfig``.

    Raises:
        ValueError: If ``train_type`` is not a supported value.
        FileNotFoundError: If ``param_path`` does not exist.
    """
    # Validate with real exceptions: ``assert`` statements are stripped under
    # ``python -O``, which would let bad input fail later and obscurely.
    valid_train_types = ("seq", "sft", "dpo", "grpo")
    if train_type not in valid_train_types:
        raise ValueError(
            f"train_type must be one of {valid_train_types}, got {train_type!r}"
        )
    if not os.path.exists(param_path):
        raise FileNotFoundError(f"param_path does not exist: {param_path!r}")

    # Load config
    config_path = os.path.join(param_path, "config.json")
    config = AutoRegressiveLMConfig.from_file(config_path)
    if window_size is None:
        window_size = config.max_len

    # Create bare AutoRegressiveLM (for training, no tokenizer needed)
    model = AutoRegressiveLM(config)

    # Load weights if available; otherwise the freshly constructed model is
    # trained as-is.
    weights_path = os.path.join(param_path, "model.safetensors")
    if os.path.exists(weights_path):
        state_dict = st.load_file(weights_path)
        # NOTE(review): strict=False presumably tolerates missing/unexpected
        # keys silently -- confirm this is intended for resume checkpoints.
        model.load_state_dict(state_dict, strict=False)
    model = model.to(dtype=torch.bfloat16)

    # Options handed to the training strategy via TrainConfig.extra_kwargs.
    strategy_kwargs = {
        "beta": dpo_beta,
        "label_smoothing": label_smoothing,
        "clip_eps": grpo_clip_eps,
        "kl_coef": grpo_kl_coef,
        "group_size": group_size,
        "sync_interval": grpo_sync_interval,
    }
    # Options handed to the executor via TrainConfig.executor_kwargs
    # (presumably the DDP wrapper's keyword arguments -- TODO confirm).
    executor_kwargs = {
        "static_graph": True,
        "find_unused_parameters": False,
        "gradient_as_bucket_view": True,
        "broadcast_buffers": False,
    }

    dataset = DatasetFactory.load(
        train_type=train_type,
        load_path=data_root_path,
        window_size=window_size,
        stride=stride,
    )

    # Factories rather than instances: TrainConfig receives callables that
    # still need the model / optimizer as their first argument.
    optimizer_fn = partial(
        create_optimizer,
        lr=max_lr,
        betas=(adamw_beta1, adamw_beta2),
        weight_decay=adamw_weight_decay,
    )

    total_steps = compute_total_steps(
        len(dataset), n_epoch, batch_per_device, nprocs, grad_accum_steps
    )
    # Clamp once so warmup never exceeds the run and decay is never negative.
    warmup_steps = min(int(warmup_ratio * total_steps), total_steps)
    scheduler_fn = partial(
        create_scheduler,
        schedule_type="cosine",
        warmup_steps=warmup_steps,
        lr_decay_steps=total_steps - warmup_steps,
    )

    train_config = TrainConfig(
        model=model,
        strategy=train_type,
        dataset=dataset,
        optimizer_fn=optimizer_fn,
        scheduler_fn=scheduler_fn,
        ckpt_dir=ckpt_dir,
        n_epoch=n_epoch,
        batch_per_device=batch_per_device,
        start_epoch=start_epoch,
        start_batch=start_batch,
        ckpt_interval=ckpt_interval,
        grad_accum_steps=grad_accum_steps,
        max_grad_norm=max_grad_norm,
        random_seed=random_seed,
        num_workers=num_workers,
        pin_memory=pin_memory,
        nprocs=nprocs,
        parallel_mode=parallel_mode,
        device_type=device_type,
        start_method=start_method,
        executor_kwargs=executor_kwargs,
        extra_kwargs=strategy_kwargs,
    )
    trainer = Trainer(train_config)
    trainer.train()
if __name__ == "__main__":
    # Script entry point: every parsed CLI option becomes a keyword argument
    # of train(), so option names must match train()'s parameter names.
    cli_options = vars(parse_args())
    train(**cli_options)