From 0deee48602f27c73dd7dbb2a934f7e940bbb42fb Mon Sep 17 00:00:00 2001 From: yegroup001 Date: Tue, 2 Jun 2026 01:01:00 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20=E8=AE=AD=E7=BB=83=E8=84=9A=E6=9C=AC?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=20gradient=5Fcheckpointing=20=E4=B8=8E?= =?UTF-8?q?=E5=A4=9A=E6=9C=BA=20DDP=20=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/tools/train.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index e5be30e..376a709 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,6 +8,7 @@ import torch.optim as optim from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM +from astrai.model.components.decoder_block import DecoderBlock from astrai.trainer import SchedulerFactory, Trainer @@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace: default=0.05, help="cross_entropy function label smoothing parameter", ) + parser.add_argument( + "--gradient_checkpointing", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable activation checkpointing for DecoderBlock modules.", + ) parser.add_argument( "--ckpt_interval", @@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace: "--start_batch", type=int, default=0, help="Start batch for training." ) + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="Master node address for distributed training.", + ) + parser.add_argument( + "--master_port", + type=str, + default="29500", + help="Master node port for distributed training.", + ) + parser.add_argument( + "--backend", + type=str, + default="nccl", + help="Distributed training backend.", + ) parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument( "--parallel_mode", @@ -222,11 +247,15 @@ def train( random_seed: int, num_workers: int, pin_memory: bool, + gradient_checkpointing: bool, window_size: int, stride: int, nprocs: int, parallel_mode: str, device_type: str, + backend: str, + master_addr: str, + master_port: str, start_method: str, ): assert train_type in ["seq", "sft", "dpo", "grpo"] @@ -303,7 +332,13 @@ def train( random_seed=random_seed, num_workers=num_workers, pin_memory=pin_memory, + gradient_checkpointing_modules=[DecoderBlock] + if gradient_checkpointing + else [], nprocs=nprocs, + backend=backend, + master_addr=master_addr, + master_port=master_port, parallel_mode=parallel_mode, device_type=device_type, start_method=start_method,