diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 801edd6..0a60de0 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import Callable, Optional import torch.nn as nn @@ -9,17 +9,21 @@ from torch.utils.data import Dataset from astrai.config.base import BaseConfig +def required(**kw): + return {"required": True, **kw} + + @dataclass class TrainConfig(BaseConfig): # basic setting - model: nn.Module = field(default=None, metadata={"help": "Model for training."}) - strategy: str = field(default=None, metadata={"help": "Training strategy."}) - dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."}) + model: nn.Module = field(default=None, metadata=required(help="Model for training.")) + strategy: str = field(default=None, metadata=required(help="Training strategy.")) + dataset: Dataset = field(default=None, metadata=required(help="Dataset for training.")) optimizer_fn: Callable[[nn.Module], Optimizer] = field( - default=None, metadata={"help": "Optimizer factory for training."} + default=None, metadata=required(help="Optimizer factory for training.") ) scheduler_fn: Callable[[Optimizer], LRScheduler] = field( - default=None, metadata={"help": "Scheduler factory for training."} + default=None, metadata=required(help="Scheduler factory for training.") ) n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."}) batch_per_device: int = field( @@ -76,6 +80,10 @@ class TrainConfig(BaseConfig): state_dict_fn: Optional[Callable] = field( default=None, metadata={"help": "Parallel function for state dict saving."} ) + start_method: str = field( + default="spawn", + metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."}, + ) # others device_type: str = field( @@ -89,14 +97,8 @@ class TrainConfig(BaseConfig): self.validate() def validate(self): - required_fields = [ - "model", - "strategy", - "dataset", - "optimizer_fn", - "scheduler_fn", - ] - - for field_name in required_fields: - if getattr(self, field_name) is None: - raise ValueError(f"{field_name} is required.") + for fld in fields(self): + if fld.metadata.get("required") and getattr(self, fld.name) is None: + raise ValueError( + f"TrainConfig.{fld.name} is required but got None." + ) diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index 3128a67..7fee102 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -123,6 +123,7 @@ def spawn_parallel_fn( master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", + start_method: str = "spawn", **kwargs, ): # clear environment variables @@ -156,6 +157,11 @@ def spawn_parallel_fn( kwargs, ) - mp.spawn( - wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True + mp.start_processes( + wrapper_spawn_func, + args=wrapper_spawn_func_args, + nprocs=world_size, + start_method=start_method, + join=True, + daemon=True, ) diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 2eb565f..4264c47 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -52,6 +52,7 @@ class Trainer: master_addr=cfg.master_addr, master_port=cfg.master_port, device_type=cfg.device_type, + start_method=cfg.start_method, checkpoint=checkpoint, ) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index f04ed3b..db74745 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--device_type", type=str, default="cuda", help="Device type to use." ) + parser.add_argument( + "--start_method", + type=str, + default="spawn", + choices=["spawn", "fork", "forkserver"], + help="Multiprocessing start method.", + ) args = parser.parse_args() @@ -232,6 +239,7 @@ def train( stride: int, nprocs: int, device_type: str, + start_method: str, ): assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) @@ -314,6 +322,7 @@ def train( parallel_wrapper=ddp_wrap, state_dict_fn=prepare_checkpoint, device_type=device_type, + start_method=start_method, extra_kwargs=strategy_kwargs, )