refactor: 优化并行训练配置与启动管理

- 配置新增 start_method 支持 spawn/fork/forkserver 选择
- 启动方式 mp.spawn 改为 mp.start_processes,支持 daemon=True
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
This commit is contained in:
ViperEkura 2026-05-17 12:32:15 +08:00
parent 44dab27fdc
commit c241a5dcef
4 changed files with 37 additions and 19 deletions

View File

@ -1,4 +1,4 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from typing import Callable, Optional
import torch.nn as nn
@ -9,17 +9,21 @@ from torch.utils.data import Dataset
from astrai.config.base import BaseConfig
def required(**kw):
return {"required": True, **kw}
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
strategy: str = field(default=None, metadata={"help": "Training strategy."})
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
model: nn.Module = field(default=None, metadata=required(help="Model for training."))
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
default=None, metadata={"help": "Optimizer factory for training."}
default=None, metadata=required(help="Optimizer factory for training.")
)
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
default=None, metadata={"help": "Scheduler factory for training."}
default=None, metadata=required(help="Scheduler factory for training.")
)
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
batch_per_device: int = field(
@ -76,6 +80,10 @@ class TrainConfig(BaseConfig):
state_dict_fn: Optional[Callable] = field(
default=None, metadata={"help": "Parallel function for state dict saving."}
)
start_method: str = field(
default="spawn",
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
)
# others
device_type: str = field(
@ -89,14 +97,8 @@ class TrainConfig(BaseConfig):
self.validate()
def validate(self):
required_fields = [
"model",
"strategy",
"dataset",
"optimizer_fn",
"scheduler_fn",
]
for field_name in required_fields:
if getattr(self, field_name) is None:
raise ValueError(f"{field_name} is required.")
for fld in fields(self):
if fld.metadata.get("required") and getattr(self, fld.name) is None:
raise ValueError(
f"TrainConfig.{fld.name} is required but got None."
)

View File

@ -123,6 +123,7 @@ def spawn_parallel_fn(
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
start_method: str = "spawn",
**kwargs,
):
# clear environment variables
@ -156,6 +157,11 @@ def spawn_parallel_fn(
kwargs,
)
mp.spawn(
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
mp.start_processes(
wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
daemon=True,
)

View File

@ -52,6 +52,7 @@ class Trainer:
master_addr=cfg.master_addr,
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
)

View File

@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--device_type", type=str, default="cuda", help="Device type to use."
)
parser.add_argument(
"--start_method",
type=str,
default="spawn",
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
args = parser.parse_args()
@ -232,6 +239,7 @@ def train(
stride: int,
nprocs: int,
device_type: str,
start_method: str,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
@ -314,6 +322,7 @@ def train(
parallel_wrapper=ddp_wrap,
state_dict_fn=prepare_checkpoint,
device_type=device_type,
start_method=start_method,
extra_kwargs=strategy_kwargs,
)