feat : 训练脚本新增 gradient_checkpointing 与多机 DDP 参数
This commit is contained in:
parent
746a1475b2
commit
0deee48602
|
|
@ -8,6 +8,7 @@ import torch.optim as optim
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace:
|
||||||
default=0.05,
|
default=0.05,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_checkpointing",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=False,
|
||||||
|
help="Enable activation checkpointing for DecoderBlock modules.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt_interval",
|
"--ckpt_interval",
|
||||||
|
|
@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--start_batch", type=int, default=0, help="Start batch for training."
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_addr",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Master node address for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_port",
|
||||||
|
type=str,
|
||||||
|
default="29500",
|
||||||
|
help="Master node port for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="nccl",
|
||||||
|
help="Distributed training backend.",
|
||||||
|
)
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--parallel_mode",
|
"--parallel_mode",
|
||||||
|
|
@ -222,11 +247,15 @@ def train(
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
|
gradient_checkpointing: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
parallel_mode: str,
|
parallel_mode: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
start_method: str,
|
start_method: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
|
|
@ -303,7 +332,13 @@ def train(
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
|
gradient_checkpointing_modules=[DecoderBlock]
|
||||||
|
if gradient_checkpointing
|
||||||
|
else [],
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
|
backend=backend,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
parallel_mode=parallel_mode,
|
parallel_mode=parallel_mode,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue