diff --git a/scripts/tools/train.py b/scripts/tools/train.py index e5be30e..376a709 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,6 +8,7 @@ import torch.optim as optim from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM +from astrai.model.components.decoder_block import DecoderBlock from astrai.trainer import SchedulerFactory, Trainer @@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace: default=0.05, help="cross_entropy function label smoothing parameter", ) + parser.add_argument( + "--gradient_checkpointing", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable activation checkpointing for DecoderBlock modules.", + ) parser.add_argument( "--ckpt_interval", @@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace: "--start_batch", type=int, default=0, help="Start batch for training." ) + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="Master node address for distributed training.", + ) + parser.add_argument( + "--master_port", + type=str, + default="29500", + help="Master node port for distributed training.", + ) + parser.add_argument( + "--backend", + type=str, + default="nccl", + help="Distributed training backend.", + ) parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument( "--parallel_mode", @@ -222,11 +247,15 @@ def train( random_seed: int, num_workers: int, pin_memory: bool, + gradient_checkpointing: bool, window_size: int, stride: int, nprocs: int, parallel_mode: str, device_type: str, + backend: str, + master_addr: str, + master_port: str, start_method: str, ): assert train_type in ["seq", "sft", "dpo", "grpo"] @@ -303,7 +332,13 @@ def train( random_seed=random_seed, num_workers=num_workers, pin_memory=pin_memory, + gradient_checkpointing_modules=[DecoderBlock] + if gradient_checkpointing + else [], nprocs=nprocs, + backend=backend, + master_addr=master_addr, + master_port=master_port, parallel_mode=parallel_mode, device_type=device_type, start_method=start_method,