refactor: 优化并行训练配置与启动管理
- 配置新增 start_method,支持 spawn/fork/forkserver 选择
- 启动方式由 mp.spawn 改为 mp.start_processes,支持 daemon=True
- validate() 改为基于 metadata 的反射式校验,不再硬编码字段列表
- CLI 新增 --start_method 参数
This commit is contained in:
parent
44dab27fdc
commit
c241a5dcef
|
|
@ -1,4 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
|
|
@ -9,17 +9,21 @@ from torch.utils.data import Dataset
|
|||
from astrai.config.base import BaseConfig
|
||||
|
||||
|
||||
def required(**kw):
    """Build dataclass ``field`` metadata that marks the field as required.

    Args:
        **kw: Extra metadata entries (for example ``help="..."``) to store
            alongside the marker.

    Returns:
        A new dict holding every entry of ``kw`` plus ``"required": True``.
    """
    meta = {"required": True}
    meta.update(kw)
    # The marker is the whole point of this helper, so a stray ``required=...``
    # keyword must not be able to switch it off.
    meta["required"] = True
    return meta
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig(BaseConfig):
|
||||
# basic setting
|
||||
model: nn.Module = field(default=None, metadata={"help": "Model for training."})
|
||||
strategy: str = field(default=None, metadata={"help": "Training strategy."})
|
||||
dataset: Dataset = field(default=None, metadata={"help": "Dataset for training."})
|
||||
model: nn.Module = field(default=None, metadata=required(help="Model for training."))
|
||||
strategy: str = field(default=None, metadata=required(help="Training strategy."))
|
||||
dataset: Dataset = field(default=None, metadata=required(help="Dataset for training."))
|
||||
optimizer_fn: Callable[[nn.Module], Optimizer] = field(
|
||||
default=None, metadata={"help": "Optimizer factory for training."}
|
||||
default=None, metadata=required(help="Optimizer factory for training.")
|
||||
)
|
||||
scheduler_fn: Callable[[Optimizer], LRScheduler] = field(
|
||||
default=None, metadata={"help": "Scheduler factory for training."}
|
||||
default=None, metadata=required(help="Scheduler factory for training.")
|
||||
)
|
||||
n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
|
||||
batch_per_device: int = field(
|
||||
|
|
@ -76,6 +80,10 @@ class TrainConfig(BaseConfig):
|
|||
state_dict_fn: Optional[Callable] = field(
|
||||
default=None, metadata={"help": "Parallel function for state dict saving."}
|
||||
)
|
||||
start_method: str = field(
|
||||
default="spawn",
|
||||
metadata={"help": "Multiprocessing start method (spawn/fork/forkserver)."},
|
||||
)
|
||||
|
||||
# others
|
||||
device_type: str = field(
|
||||
|
|
@ -89,14 +97,8 @@ class TrainConfig(BaseConfig):
|
|||
self.validate()
|
||||
|
||||
def validate(self):
|
||||
required_fields = [
|
||||
"model",
|
||||
"strategy",
|
||||
"dataset",
|
||||
"optimizer_fn",
|
||||
"scheduler_fn",
|
||||
]
|
||||
|
||||
for field_name in required_fields:
|
||||
if getattr(self, field_name) is None:
|
||||
raise ValueError(f"{field_name} is required.")
|
||||
for fld in fields(self):
|
||||
if fld.metadata.get("required") and getattr(self, fld.name) is None:
|
||||
raise ValueError(
|
||||
f"TrainConfig.{fld.name} is required but got None."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ def spawn_parallel_fn(
|
|||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
start_method: str = "spawn",
|
||||
**kwargs,
|
||||
):
|
||||
# clear environment variables
|
||||
|
|
@ -156,6 +157,11 @@ def spawn_parallel_fn(
|
|||
kwargs,
|
||||
)
|
||||
|
||||
mp.spawn(
|
||||
wrapper_spawn_func, nprocs=world_size, args=wrapper_spawn_func_args, join=True
|
||||
mp.start_processes(
|
||||
wrapper_spawn_func,
|
||||
args=wrapper_spawn_func_args,
|
||||
nprocs=world_size,
|
||||
start_method=start_method,
|
||||
join=True,
|
||||
daemon=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class Trainer:
|
|||
master_addr=cfg.master_addr,
|
||||
master_port=cfg.master_port,
|
||||
device_type=cfg.device_type,
|
||||
start_method=cfg.start_method,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -149,6 +149,13 @@ def parse_args() -> argparse.Namespace:
|
|||
parser.add_argument(
|
||||
"--device_type", type=str, default="cuda", help="Device type to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start_method",
|
||||
type=str,
|
||||
default="spawn",
|
||||
choices=["spawn", "fork", "forkserver"],
|
||||
help="Multiprocessing start method.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -232,6 +239,7 @@ def train(
|
|||
stride: int,
|
||||
nprocs: int,
|
||||
device_type: str,
|
||||
start_method: str,
|
||||
):
|
||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||
assert os.path.exists(param_path)
|
||||
|
|
@ -314,6 +322,7 @@ def train(
|
|||
parallel_wrapper=ddp_wrap,
|
||||
state_dict_fn=prepare_checkpoint,
|
||||
device_type=device_type,
|
||||
start_method=start_method,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue