refactor: 优化并行训练配置与启动管理

- 配置新增 start_method 支持 spawn/fork/forkserver 选择
- 启动方式由 mp.spawn 改为 mp.start_processes,并固定以 daemon=True 启动子进程
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
This commit is contained in:
ViperEkura 2026-05-17 12:32:15 +08:00
parent 44dab27fdc
commit c241a5dcef
4 changed files with 37 additions and 19 deletions

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field, fields
from typing import Callable, Optional from typing import Callable, Optional
import torch.nn as nn import torch.nn as nn
@ -9,17 +9,21 @@ from torch.utils.data import Dataset
from astrai.config.base import BaseConfig from astrai.config.base import BaseConfig
def required(**kw) -> dict:
    """Build dataclass field metadata that marks the field as required.

    Args:
        **kw: Extra metadata entries (for example ``help="..."``) to store
            alongside the marker.

    Returns:
        A new dict containing every entry of ``kw`` plus ``"required": True``.
    """
    # The marker is written last so a stray ``required=False`` in ``kw``
    # cannot silently turn a required field into an optional one.
    return {**kw, "required": True}
@dataclass @dataclass
class TrainConfig(BaseConfig): class TrainConfig(BaseConfig):
# basic setting # basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."}) model: nn.Module = field(default=None, metadata=required(help="Model for training."))
strategy: str = field(default=None, metadata={"help": "Training strategy."}) strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."}) dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
optimizer_fn: Callable[[nn.Module], Optimizer] = field( optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata={"help": "Optimizer factory for training."} default=None, metadata=required(help="Optimizer factory for training.")
) )
scheduler_fn: Callable[[Optimizer], LRScheduler] = field( scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata={"help": "Scheduler factory for training."} default=None, metadata=required(help="Scheduler factory for training.")
) )
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field( batch_per_device: int = field(
@ -76,6 +80,10 @@ class TrainConfig(BaseConfig):
state_dict_fn: Optional[Callable] = field( state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."} default=None, metadata={"help": "Parallel function for state dict saving."}
) )
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others # others
device_type: str = field( device_type: str = field(
@ -89,14 +97,8 @@ class TrainConfig(BaseConfig):
self.validate() self.validate()
def validate(self): def validate(self):
required_fields = [ for fld in fields(self):
"model", if fld.metadata.get("required") and getattr(self, fld.name) is None:
"strategy", raise ValueError(
"dataset", f"TrainConfig.{fld.name} is required but got None."
"optimizer_fn", )
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")

View File

@ -123,6 +123,7 @@ def spawn_parallel_fn(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
start_method: str = "spawn",
**kwargs, **kwargs,
): ):
# clear environment variables # clear environment variables
@ -156,6 +157,11 @@ def spawn_parallel_fn(
kwargs, kwargs,
) )
mp.spawn( mp.start_processes(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
daemon=True,
) )

View File

@ -52,6 +52,7 @@ class Trainer:
master_addr=cfg.master_addr, master_addr=cfg.master_addr,
master_port=cfg.master_port, master_port=cfg.master_port,
device_type=cfg.device_type, device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint, checkpoint=checkpoint,
) )

View File

@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument( parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use." "--device_type", type=str, default="cuda", help="Device type to use."
) )
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
args = parser.parse_args() args = parser.parse_args()
@ -232,6 +239,7 @@ def train(
stride: int, stride: int,
nprocs: int, nprocs: int,
device_type: str, device_type: str,
start_method: str,
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path) assert os.path.exists(param_path)
@ -314,6 +322,7 @@ def train(
parallel_wrapper=ddp_wrap, parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint, state_dict_fn=prepare_checkpoint,
device_type=device_type, device_type=device_type,
start_method=start_method,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )