From 4145d35e3c25f2c007ab70fb3e3a65be7482e1d8 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Wed, 27 May 2026 20:06:44 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=A3=80=E6=9F=A5=E7=82=B9?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=E9=87=8D=E6=9E=84=EF=BC=8C=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E6=9B=BF=E4=BB=A3=E5=AF=B9=E8=B1=A1=E4=BC=A0=E9=80=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串 - Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递 - TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch - CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁 - serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast - optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行 - pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物 --- astrai/config/train_config.py | 4 +- astrai/serialization.py | 121 ++++++++++++++++++++------- astrai/trainer/train_callback.py | 42 ++++------ astrai/trainer/train_context.py | 53 +++++++----- astrai/trainer/trainer.py | 11 ++- scripts/tools/train.py | 19 ++--- tests/trainer/conftest.py | 6 +- tests/trainer/test_callbacks.py | 4 +- tests/trainer/test_early_stopping.py | 20 +++-- tests/trainer/test_trainer.py | 6 +- 10 files changed, 170 insertions(+), 116 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index de8ebf6..c6e78d1 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -17,8 +17,8 @@ def required(**kw): @dataclass class TrainConfig(BaseConfig): # basic setting - model: nn.Module = field( - default=None, metadata=required(help="Model for training.") + model_fn: Callable[[], nn.Module] = field( + default=None, metadata=required(help="Model factory for training.") ) strategy: str = field(default=None, metadata=required(help="Training strategy.")) dataset: Dataset = field( diff --git a/astrai/serialization.py b/astrai/serialization.py index 1990ae5..a21d55b 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,8 +1,9 @@ +import io import json import time from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, List, Tuple import safetensors.torch as st import torch @@ -11,8 +12,8 @@ import torch.distributed as dist from astrai.parallel.setup import get_rank _META_FILE = "meta.json" +_CONFIG_FILE = "config.json" _WEIGHTS_FILE = "model.safetensors" -_MODEL_CONFIG_FILE = "config.json" def save_safetensors(state_dict: dict, path: str | Path) -> None: @@ -37,8 +38,87 @@ def save_torch(obj: Any, path: str | Path) -> None: torch.save(obj, str(path)) -def load_torch(path: str | Path) -> Any: - return torch.load(str(path), map_location="cpu", weights_only=False) +def load_torch(path: str | Path, broadcast: bool = False) -> Any: + if not broadcast or not dist.is_initialized(): + return torch.load(str(path), map_location="cpu", weights_only=False) + + path = Path(path) + rank = get_rank() + + if rank == 0: + with open(path, "rb") as f: + raw = f.read() + data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8) + num_bytes = torch.tensor([len(raw)], dtype=torch.long) + else: + num_bytes = torch.tensor([0], dtype=torch.long) + + dist.broadcast(num_bytes, src=0) + + if rank != 0: + data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8) + + dist.broadcast(data_tensor, src=0) + + buf = io.BytesIO(data_tensor.numpy().tobytes()) + return torch.load(buf, map_location="cpu", weights_only=False) + + +def save_model(config: dict, state_dict: dict, save_directory: str) -> None: + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + save_json(config, save_path / _CONFIG_FILE) + save_safetensors(state_dict, save_path / _WEIGHTS_FILE) + + +def load_model_config(save_directory: str) -> dict: + return load_json(Path(save_directory) / _CONFIG_FILE) + + +def load_model_weights(save_directory: str) -> dict: + return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) + + +def _get_meta(save_path: Path) -> dict: + meta = {} + if get_rank() == 0: + meta = load_json(save_path / _META_FILE) + if dist.is_initialized(): + meta_list = [meta] + dist.broadcast_object_list(meta_list, src=0) + meta = meta_list[0] + return meta + + +def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: + if not broadcast or not dist.is_initialized(): + return load_safetensors(save_path / _WEIGHTS_FILE) + + rank = get_rank() + if rank == 0: + state_dict = load_safetensors(save_path / _WEIGHTS_FILE) + specs: List[Tuple[str, List[int], str]] = [ + (k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1]) + for k in sorted(state_dict) + ] + else: + state_dict = {} + specs = [] + + specs_list = [specs] + dist.broadcast_object_list(specs_list, src=0) + specs = specs_list[0] + + for key, shape, dtype_name in specs: + dtype = getattr(torch, dtype_name) + if rank != 0: + tensor = torch.empty(shape, dtype=dtype, device="cpu") + else: + tensor = state_dict[key].contiguous().cpu() + dist.broadcast(tensor, src=0) + if rank != 0: + state_dict[key] = tensor + return state_dict @dataclass @@ -68,24 +148,16 @@ class Checkpoint: save_torch(value, save_path / f"{key}.pt") @classmethod - def load(cls, save_dir: str) -> "Checkpoint": + def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": save_path = Path(save_dir) - meta = {} - if get_rank() == 0: - meta = load_json(save_path / _META_FILE) - - if dist.is_initialized(): - meta_list = [meta] - dist.broadcast_object_list(meta_list, src=0) - meta = meta_list[0] - - state_dict = load_safetensors(save_path / _WEIGHTS_FILE) + meta = _get_meta(save_path) + state_dict = _load_state_dict(save_path, broadcast=broadcast) extra = {} - for f in save_path.iterdir(): + for f in sorted(save_path.iterdir()): if f.suffix == ".pt": - extra[f.stem] = load_torch(f) + extra[f.stem] = load_torch(f, broadcast=broadcast) return cls( state_dict=state_dict, @@ -93,18 +165,3 @@ class Checkpoint: iteration=meta.get("iteration", 0), extra=extra, ) - - -def save_model(config: dict, state_dict: dict, save_directory: str) -> None: - save_path = Path(save_directory) - save_path.mkdir(parents=True, exist_ok=True) - save_json(config, save_path / _MODEL_CONFIG_FILE) - save_safetensors(state_dict, save_path / _WEIGHTS_FILE) - - -def load_model_config(save_directory: str) -> dict: - return load_json(Path(save_directory) / _MODEL_CONFIG_FILE) - - -def load_model_weights(save_directory: str) -> dict: - return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index ee55a43..ba85b2d 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -15,7 +15,7 @@ from tqdm import tqdm from astrai.factory import BaseFactory from astrai.parallel import only_on_rank -from astrai.parallel.setup import get_current_device +from astrai.parallel.setup import get_current_device, get_rank from astrai.serialization import Checkpoint from astrai.trainer.metric_util import ( ctx_get_grad_max, @@ -139,42 +139,36 @@ class CheckpointCallback(TrainCallback): weight_only: bool = False, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, - load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra - self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra self.last_ckpt_iter = 0 - @only_on_rank(0) def _save_checkpoint(self, context: TrainContext): - save_path = os.path.join( - self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" - ) + # All ranks gather state_dict — collective for FSDP, local for DDP state_dict = ( self.state_dict_fn(context.model) if self.state_dict_fn else context.model.state_dict() ) - - extra = self.save_extra_fn(context) - context.checkpoint = Checkpoint( - state_dict=state_dict, - epoch=context.epoch, - iteration=context.iteration, - extra=extra, - meta=context.config.to_dict(), - ) - - context.checkpoint.save(save_path) self.last_ckpt_iter = context.iteration - def on_train_begin(self, context: TrainContext): - if context.checkpoint and context.checkpoint.extra: - self.load_extra_fn(context.checkpoint.extra, context) + if get_rank() == 0: + save_path = os.path.join( + self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}" + ) + extra = self.save_extra_fn(context) + context.checkpoint = Checkpoint( + state_dict=state_dict, + epoch=context.epoch, + iteration=context.iteration, + extra=extra, + meta=context.config.to_dict(), + ) + context.checkpoint.save(save_path) def on_batch_end(self, context: TrainContext): if context.iteration - self.last_ckpt_iter >= self.interval: @@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback): extra[name] = obj.state_dict() return extra - @staticmethod - def load_extra(extra: dict, context: TrainContext): - for name in CheckpointCallback.extra_keys: - if name in extra: - getattr(context, name).load_state_dict(extra[name]) - @CallbackFactory.register("progress_bar") class ProgressBarCallback(TrainCallback): diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 13caf24..d80dae8 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from pathlib import Path from typing import Optional, Self import torch.nn as nn @@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.protocols import OptimizerProtocol, SchedulerProtocol -from astrai.serialization import Checkpoint +from astrai.serialization import Checkpoint, load_model_weights from astrai.trainer.strategy import BaseStrategy, StrategyFactory @@ -42,10 +43,10 @@ class TrainContextBuilder: config: TrainConfig, ): self.config = config - self._checkpoint: Optional[Checkpoint] = None + self._resume_dir: Optional[str] = None - def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: - self._checkpoint = checkpoint + def with_resume_dir(self, resume_dir: Optional[str]) -> Self: + self._resume_dir = resume_dir return self def build(self) -> TrainContext: @@ -58,36 +59,40 @@ class TrainContextBuilder: **cfg.executor_kwargs, ) + model = cfg.model_fn() + model = model.to(device=device) + context = TrainContext( - model=cfg.model, + model=model, world_size=get_world_size(), rank=get_rank(), config=cfg, executor=executor, ) - context.model = context.model.to(device=device) - - if self._checkpoint is not None: - context.epoch = max(self._checkpoint.epoch, cfg.start_epoch) - context.iteration = max(self._checkpoint.iteration, cfg.start_batch) - if self._checkpoint.state_dict: - context.model.load_state_dict(self._checkpoint.state_dict) - context.checkpoint = self._checkpoint - else: - context.checkpoint = Checkpoint( - state_dict=context.model.state_dict(), - ) + if self._resume_dir is not None: + resume_path = Path(self._resume_dir) + if (resume_path / "meta.json").exists(): + checkpoint = Checkpoint.load(self._resume_dir) + state_dict = checkpoint.state_dict + else: + checkpoint = None + state_dict = load_model_weights(self._resume_dir) + model.load_state_dict(state_dict, strict=False) + if checkpoint is not None: + context.epoch = max(checkpoint.epoch, cfg.start_epoch) + context.iteration = max(checkpoint.iteration, cfg.start_batch) + context.checkpoint = checkpoint if cfg.lora is not None: inject_lora( - context.model, + model, r=cfg.lora.r, alpha=cfg.lora.alpha, target_modules=set(cfg.lora.target_modules), ) - context.optimizer = cfg.optimizer_fn(context.model) + context.optimizer = cfg.optimizer_fn(model) context.scheduler = cfg.scheduler_fn(context.optimizer) sampler_offset = context.iteration * cfg.batch_per_device @@ -125,13 +130,21 @@ class TrainContextBuilder: context.model, context.optimizer, context.dataloader, context.scheduler = ( executor.prepare( - context.model, + model, context.optimizer, context.dataloader, context.scheduler, ) ) + if context.checkpoint and context.checkpoint.extra: + extra = context.checkpoint.extra + for name in ("optimizer", "scheduler"): + if name in extra: + obj = getattr(context, name, None) + if obj is not None: + obj.load_state_dict(extra[name]) + context.strategy = StrategyFactory.create( model=context.model, train_type=cfg.strategy, diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 81e4044..aa8b467 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -3,7 +3,6 @@ from typing import List, Optional from astrai.config import TrainConfig from astrai.parallel.setup import spawn_parallel_fn -from astrai.serialization import Checkpoint from astrai.trainer.train_callback import ( CallbackFactory, TrainCallback, @@ -54,9 +53,9 @@ class Trainer: if method: method(context) - def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None): + def _trainer_loop(self, resume_dir: Optional[str] = None): context = ( - TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build() + TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build() ) executor = context.executor self._call_callbacks("on_train_begin", context) @@ -90,13 +89,13 @@ class Trainer: self._call_callbacks("on_epoch_end", context) except Exception as e: - logger.error(f"Training failed: {str(e)}", exc_info=True) + logger.error("Training failed: %s", str(e), exc_info=True) self._call_callbacks("on_error", context) raise finally: self._call_callbacks("on_train_end", context) - def train(self, checkpoint: Optional[Checkpoint] = None): + def train(self, resume_dir: Optional[str] = None): cfg = self.train_config spawn_parallel_fn( self._trainer_loop, @@ -106,5 +105,5 @@ class Trainer: master_port=cfg.master_port, device_type=cfg.device_type, start_method=cfg.start_method, - checkpoint=checkpoint, + resume_dir=resume_dir, ) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 044054a..e5be30e 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,7 +8,6 @@ import torch.optim as optim from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM -from astrai.serialization import Checkpoint from astrai.trainer import SchedulerFactory, Trainer @@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace: return args +def create_model(config): + return AutoRegressiveLM(config).to(dtype=torch.bfloat16) + + def create_optimizer(model, **kwargs) -> optim.Optimizer: return optim.AdamW(model.parameters(), fused=True, **kwargs) @@ -238,15 +241,6 @@ def train( if window_size is None: window_size = config.max_len - # Create model and load full checkpoint (state_dict + optimizer + scheduler + meta) - checkpoint = Checkpoint.load(param_path) - model = AutoRegressiveLM(config).to(dtype=torch.bfloat16) - model.load_state_dict(checkpoint.state_dict, strict=False) - - # Strip state_dict to avoid pickling ~7GB through mp.spawn pipe - # (model weights already loaded into model above) - checkpoint.state_dict = {} - strategy_kwargs = { "beta": dpo_beta, "label_smoothing": label_smoothing, @@ -261,6 +255,7 @@ def train( "broadcast_buffers": False, } + model_fn = partial(create_model, config) dataset = DatasetFactory.load( train_type=train_type, load_path=data_root_path, @@ -292,7 +287,7 @@ def train( ) train_config = TrainConfig( - model=model, + model_fn=model_fn, strategy=train_type, dataset=dataset, optimizer_fn=optimizer_fn, @@ -317,7 +312,7 @@ def train( ) trainer = Trainer(train_config) - trainer.train(checkpoint=checkpoint) + trainer.train(resume_dir=param_path) if __name__ == "__main__": diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 0b76ca6..5ce6c51 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -27,7 +27,7 @@ class TrainerDataset(Dataset): def create_train_config( - model: torch.nn.Module, + model_fn, dataset: Dataset, test_dir: str, device: str, @@ -43,7 +43,7 @@ def create_train_config( """Factory function to create common TrainConfig for tests. Args: - model: The model to train + model_fn: Model factory (callable returning nn.Module) dataset: Training dataset test_dir: Checkpoint directory device: Device type ("cuda" or "cpu") @@ -70,7 +70,7 @@ def create_train_config( return TrainConfig( strategy=strategy, - model=model, + model_fn=model_fn, dataset=dataset, optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index 5be6e6c..b604a9c 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase ) train_config = TrainConfig( - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], strategy="seq", dataset=random_dataset, optimizer_fn=optimizer_fn, @@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset): ) train_config = TrainConfig( - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], strategy="seq", dataset=random_dataset, optimizer_fn=optimizer_fn, diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 70a4301..2047d7f 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -4,7 +4,6 @@ import numpy as np import torch from astrai.config.train_config import TrainConfig -from astrai.serialization import Checkpoint from astrai.trainer.schedule import SchedulerFactory from astrai.trainer.trainer import Trainer @@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): strategy="seq", optimizer_fn=optimizer_fn, scheduler_fn=scheduler_fn, - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], dataset=early_stopping_dataset, ckpt_dir=base_test_env["test_dir"], log_dir=os.path.join(base_test_env["test_dir"], "logs"), @@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset): trainer = Trainer(train_config) # Should handle early stopping gracefully - checkpoint = None try: - checkpoint = trainer.train() + trainer.train() except Exception: - # Handle any exceptions pass + # Resume from latest checkpoint load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2") - checkpoint = Checkpoint.load(load_dir) - trainer.train(checkpoint) + trainer = Trainer(train_config) + trainer.train(resume_dir=load_dir) + # Verify checkpoint was saved at expected iteration load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10") - checkpoint = Checkpoint.load(load_dir) - assert checkpoint.iteration == 10 + import json + + with open(os.path.join(load_dir, "meta.json")) as f: + meta = json.load(f) + assert meta["iteration"] == 10 diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index f51cb40..d941dcd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto for batch_per_device in batch_sizes: train_config = train_config_factory( - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"], @@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto for grad_accum_steps in grad_accum_steps_list: train_config = train_config_factory( - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"], @@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f for config in small_batch_configs: train_config = train_config_factory( - model=base_test_env["model"], + model_fn=lambda: base_test_env["model"], dataset=random_dataset, test_dir=base_test_env["test_dir"], device=base_test_env["device"],