From c241a5dcef0c3f520543363ed8aa5f6f341d8257 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sun, 17 May 2026 12:32:15 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=B9=B6?= =?UTF-8?q?=E8=A1=8C=E8=AE=AD=E7=BB=83=E9=85=8D=E7=BD=AE=E4=B8=8E=E5=90=AF?= =?UTF-8?q?=E5=8A=A8=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 配置新增 start_method 支持 spawn/fork/forkserver 选择 - 启动方式 mp.spawn 改为 mp.start_processes,支持 daemon=True - validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表 - CLI 新增 --start_method 参数 --- astrai/config/train_config.py | 36 ++++++++++++++++++----------------- astrai/parallel/setup.py | 10 ++++++++-- astrai/trainer/trainer.py | 1 + scripts/tools/train.py | 9 +++++++++ 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 801edd6..0a60de0 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import Callable, Optional import torch.nn as nn @@ -9,17 +9,21 @@ from torch.utils.data import Dataset from astrai.config.base import BaseConfig +def required(**kw): + return {"required": True, **kw} + + @dataclass class TrainConfig(BaseConfig): # basic setting - model: nn.Module = field(default=None, metadata={"help": "Model for training."}) - strategy: str = field(default=None, metadata={"help": "Training strategy."}) - dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."}) + model: nn.Module = field(default=None, metadata=required(help="Model for training.")) + strategy: str = field(default=None, metadata=required(help="Training strategy.")) + dataset: Dataset = field(default=None, metadata=required(help="Dataset for training.")) optimizer_fn: Callable[[nn.Module], Optimizer] = field( - default=None, metadata={"help": "Optimizer factory for training."} + default=None, metadata=required(help="Optimizer factory for training.") ) scheduler_fn: Callable[[Optimizer], LRScheduler] = field( - default=None, metadata={"help": "Scheduler factory for training."} + default=None, metadata=required(help="Scheduler factory for training.") ) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) batch_per_device: int = field( @@ -76,6 +80,10 @@ class TrainConfig(BaseConfig): state_dict_fn: Optional[Callable] = field( default=None, metadata={"help": "Parallel function for state dict saving."} ) + start_method: str = field( + default="spawn", + metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."}, + ) # others device_type: str = field( @@ -89,14 +97,8 @@ class TrainConfig(BaseConfig): self.validate() def validate(self): - required_fields = [ - "model", - "strategy", - "dataset", - "optimizer_fn", - "scheduler_fn", - ] - - for field_name in required_fields: - if getattr(self, field_name) is None: - raise ValueError(f"{field_name} is required.") + for fld in fields(self): + if fld.metadata.get("required") and getattr(self, fld.name) is None: + raise ValueError( + f"TrainConfig.{fld.name} is required but got None." + ) diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index 3128a67..7fee102 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -123,6 +123,7 @@ def spawn_parallel_fn( master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", + start_method: str = "spawn", **kwargs, ): # clear environment variables @@ -156,6 +157,11 @@ def spawn_parallel_fn( kwargs, ) - mp.spawn( - wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True + mp.start_processes( + wrapper_spawn_func, + args=wrapper_spawn_func_args, + nprocs=world_size, + start_method=start_method, + join=True, + daemon=True, ) diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 2eb565f..4264c47 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -52,6 +52,7 @@ class Trainer: master_addr=cfg.master_addr, master_port=cfg.master_port, device_type=cfg.device_type, + start_method=cfg.start_method, checkpoint=checkpoint, ) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index f04ed3b..db74745 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--device_type", type=str, default="cuda", help="Device type to use." ) + parser.add_argument( + "--start_method", + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method.", + ) args = parser.parse_args() @@ -232,6 +239,7 @@ def train( stride: int, nprocs: int, device_type: str, + start_method: str, ): assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) @@ -314,6 +322,7 @@ def train( parallel_wrapper=ddp_wrap, state_dict_fn=prepare_checkpoint, device_type=device_type, + start_method=start_method, extra_kwargs=strategy_kwargs, )