feat: 训练脚本新增 gradient_checkpointing 与多机 DDP 参数

This commit is contained in:
yegroup001 2026-06-02 01:01:00 +08:00
parent 746a1475b2
commit 0deee48602
1 changed file with 35 additions and 0 deletions

View File

@@ -8,6 +8,7 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace:
default=0.05, default=0.05,
help="cross_entropy function label smoothing parameter", help="cross_entropy function label smoothing parameter",
) )
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument( parser.add_argument(
"--ckpt_interval", "--ckpt_interval",
@@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training." "--start_batch", type=int, default=0, help="Start batch for training."
) )
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument( parser.add_argument(
"--parallel_mode", "--parallel_mode",
@@ -222,11 +247,15 @@ def train(
random_seed: int, random_seed: int,
num_workers: int, num_workers: int,
pin_memory: bool, pin_memory: bool,
gradient_checkpointing: bool,
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int, nprocs: int,
parallel_mode: str, parallel_mode: str,
device_type: str, device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str, start_method: str,
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
@@ -303,7 +332,13 @@ def train(
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
gradient_checkpointing_modules=[DecoderBlock]
if gradient_checkpointing
else [],
nprocs=nprocs, nprocs=nprocs,
backend=backend,
master_addr=master_addr,
master_port=master_port,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device_type=device_type, device_type=device_type,
start_method=start_method, start_method=start_method,