feat: 训练脚本新增 gradient_checkpointing 与多机 DDP 参数

This commit is contained in:
yegroup001 2026-06-02 01:01:00 +08:00
parent 746a1475b2
commit 0deee48602
1 changed file with 35 additions and 0 deletions

View File

@@ -8,6 +8,7 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer import SchedulerFactory, Trainer
@@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace:
default=0.05,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument(
"--ckpt_interval",
@@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training."
)
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument(
"--parallel_mode",
@@ -222,11 +247,15 @@ def train(
random_seed: int,
num_workers: int,
pin_memory: bool,
gradient_checkpointing: bool,
window_size: int,
stride: int,
nprocs: int,
parallel_mode: str,
device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
@@ -303,7 +332,13 @@ def train(
random_seed=random_seed,
num_workers=num_workers,
pin_memory=pin_memory,
gradient_checkpointing_modules=[DecoderBlock]
if gradient_checkpointing
else [],
nprocs=nprocs,
backend=backend,
master_addr=master_addr,
master_port=master_port,
parallel_mode=parallel_mode,
device_type=device_type,
start_method=start_method,