refactor: 检查点加载重构,路径替代对象传递

- model: nn.Module -> model_fn 工厂函数,spawn 边界只传字符串
- Trainer.train(resume_dir=path) — Checkpoint 不再通过 pickle 传递
- TrainContextBuilder.with_resume_dir(path) — 自动检测 meta.json 分流 resume/from-scratch
- CheckpointCallback: 拆分 state_dict 收集(全 rank)与磁盘写入(rank-0),修复 FSDP 死锁
- serialization: load_torch 支持 broadcast,消除 _load_extra/_load_torch_broadcast
- optimizer/scheduler 恢复逻辑内联到 build(),在 executor.prepare() 之后执行
- pyproject.toml: ruff exclude build/ 避免 CI 扫描构建产物
This commit is contained in:
ViperEkura 2026-05-27 20:06:44 +08:00
parent 34c6c45bd6
commit 4145d35e3c
10 changed files with 170 additions and 116 deletions

View File

@ -17,8 +17,8 @@ def required(**kw):
@dataclass
class TrainConfig(BaseConfig):
# basic setting
model: nn.Module = field(
default=None, metadata=required(help="Model for training.")
model_fn: Callable[[], nn.Module] = field(
default=None, metadata=required(help="Model factory for training.")
)
strategy: str = field(default=None, metadata=required(help="Training strategy."))
dataset: Dataset = field(

View File

@ -1,8 +1,9 @@
import io
import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, List, Tuple
import safetensors.torch as st
import torch
@ -11,8 +12,8 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank
_META_FILE = "meta.json"
_CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path) -> None:
@ -37,8 +38,87 @@ def save_torch(obj: Any, path: str | Path) -> None:
torch.save(obj, str(path))
def load_torch(path: str | Path) -> Any:
return torch.load(str(path), map_location="cpu", weights_only=False)
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
path = Path(path)
rank = get_rank()
if rank == 0:
with open(path, "rb") as f:
raw = f.read()
data_tensor = torch.frombuffer(bytearray(raw), dtype=torch.uint8)
num_bytes = torch.tensor([len(raw)], dtype=torch.long)
else:
num_bytes = torch.tensor([0], dtype=torch.long)
dist.broadcast(num_bytes, src=0)
if rank != 0:
data_tensor = torch.empty(num_bytes.item(), dtype=torch.uint8)
dist.broadcast(data_tensor, src=0)
buf = io.BytesIO(data_tensor.numpy().tobytes())
return torch.load(buf, map_location="cpu", weights_only=False)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def _get_meta(save_path: Path) -> dict:
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs: List[Tuple[str, List[int], str]] = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
else:
state_dict = {}
specs = []
specs_list = [specs]
dist.broadcast_object_list(specs_list, src=0)
specs = specs_list[0]
for key, shape, dtype_name in specs:
dtype = getattr(torch, dtype_name)
if rank != 0:
tensor = torch.empty(shape, dtype=dtype, device="cpu")
else:
tensor = state_dict[key].contiguous().cpu()
dist.broadcast(tensor, src=0)
if rank != 0:
state_dict[key] = tensor
return state_dict
@dataclass
@ -68,24 +148,16 @@ class Checkpoint:
save_torch(value, save_path / f"{key}.pt")
@classmethod
def load(cls, save_dir: str) -> "Checkpoint":
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir)
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
meta = _get_meta(save_path)
state_dict = _load_state_dict(save_path, broadcast=broadcast)
extra = {}
for f in save_path.iterdir():
for f in sorted(save_path.iterdir()):
if f.suffix == ".pt":
extra[f.stem] = load_torch(f)
extra[f.stem] = load_torch(f, broadcast=broadcast)
return cls(
state_dict=state_dict,
@ -93,18 +165,3 @@ class Checkpoint:
iteration=meta.get("iteration", 0),
extra=extra,
)
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
save_path = Path(save_directory)
save_path.mkdir(parents=True, exist_ok=True)
save_json(config, save_path / _MODEL_CONFIG_FILE)
save_safetensors(state_dict, save_path / _WEIGHTS_FILE)
def load_model_config(save_directory: str) -> dict:
return load_json(Path(save_directory) / _MODEL_CONFIG_FILE)
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)

View File

@ -15,7 +15,7 @@ from tqdm import tqdm
from astrai.factory import BaseFactory
from astrai.parallel import only_on_rank
from astrai.parallel.setup import get_current_device
from astrai.parallel.setup import get_current_device, get_rank
from astrai.serialization import Checkpoint
from astrai.trainer.metric_util import (
ctx_get_grad_max,
@ -139,42 +139,36 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
def _save_checkpoint(self, context: TrainContext):
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
# All ranks gather state_dict — collective for FSDP, local for DDP
state_dict = (
self.state_dict_fn(context.model)
if self.state_dict_fn
else context.model.state_dict()
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
if get_rank() == 0:
save_path = os.path.join(
self.save_dir, f"epoch_{context.epoch}_iter_{context.iteration}"
)
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
meta=context.config.to_dict(),
)
context.checkpoint.save(save_path)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
@ -196,12 +190,6 @@ class CheckpointCallback(TrainCallback):
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Self
import torch.nn as nn
@ -10,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint
from astrai.serialization import Checkpoint, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -42,10 +43,10 @@ class TrainContextBuilder:
config: TrainConfig,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._resume_dir: Optional[str] = None
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
def with_resume_dir(self, resume_dir: Optional[str]) -> Self:
self._resume_dir = resume_dir
return self
def build(self) -> TrainContext:
@ -58,36 +59,40 @@ class TrainContextBuilder:
**cfg.executor_kwargs,
)
model = cfg.model_fn()
model = model.to(device=device)
context = TrainContext(
model=cfg.model,
model=model,
world_size=get_world_size(),
rank=get_rank(),
config=cfg,
executor=executor,
)
context.model = context.model.to(device=device)
if self._checkpoint is not None:
context.epoch = max(self._checkpoint.epoch, cfg.start_epoch)
context.iteration = max(self._checkpoint.iteration, cfg.start_batch)
if self._checkpoint.state_dict:
context.model.load_state_dict(self._checkpoint.state_dict)
context.checkpoint = self._checkpoint
else:
context.checkpoint = Checkpoint(
state_dict=context.model.state_dict(),
)
if self._resume_dir is not None:
resume_path = Path(self._resume_dir)
if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict
else:
checkpoint = None
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = max(checkpoint.epoch, cfg.start_epoch)
context.iteration = max(checkpoint.iteration, cfg.start_batch)
context.checkpoint = checkpoint
if cfg.lora is not None:
inject_lora(
context.model,
model,
r=cfg.lora.r,
alpha=cfg.lora.alpha,
target_modules=set(cfg.lora.target_modules),
)
context.optimizer = cfg.optimizer_fn(context.model)
context.optimizer = cfg.optimizer_fn(model)
context.scheduler = cfg.scheduler_fn(context.optimizer)
sampler_offset = context.iteration * cfg.batch_per_device
@ -125,13 +130,21 @@ class TrainContextBuilder:
context.model, context.optimizer, context.dataloader, context.scheduler = (
executor.prepare(
context.model,
model,
context.optimizer,
context.dataloader,
context.scheduler,
)
)
if context.checkpoint and context.checkpoint.extra:
extra = context.checkpoint.extra
for name in ("optimizer", "scheduler"):
if name in extra:
obj = getattr(context, name, None)
if obj is not None:
obj.load_state_dict(extra[name])
context.strategy = StrategyFactory.create(
model=context.model,
train_type=cfg.strategy,

View File

@ -3,7 +3,6 @@ from typing import List, Optional
from astrai.config import TrainConfig
from astrai.parallel.setup import spawn_parallel_fn
from astrai.serialization import Checkpoint
from astrai.trainer.train_callback import (
CallbackFactory,
TrainCallback,
@ -54,9 +53,9 @@ class Trainer:
if method:
method(context)
def _trainer_loop(self, checkpoint: Optional[Checkpoint] = None):
def _trainer_loop(self, resume_dir: Optional[str] = None):
context = (
TrainContextBuilder(self.train_config).with_checkpoint(checkpoint).build()
TrainContextBuilder(self.train_config).with_resume_dir(resume_dir).build()
)
executor = context.executor
self._call_callbacks("on_train_begin", context)
@ -90,13 +89,13 @@ class Trainer:
self._call_callbacks("on_epoch_end", context)
except Exception as e:
logger.error(f"Training failed: {str(e)}", exc_info=True)
logger.error("Training failed: %s", str(e), exc_info=True)
self._call_callbacks("on_error", context)
raise
finally:
self._call_callbacks("on_train_end", context)
def train(self, checkpoint: Optional[Checkpoint] = None):
def train(self, resume_dir: Optional[str] = None):
cfg = self.train_config
spawn_parallel_fn(
self._trainer_loop,
@ -106,5 +105,5 @@ class Trainer:
master_port=cfg.master_port,
device_type=cfg.device_type,
start_method=cfg.start_method,
checkpoint=checkpoint,
resume_dir=resume_dir,
)

View File

@ -8,7 +8,6 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM
from astrai.serialization import Checkpoint
from astrai.trainer import SchedulerFactory, Trainer
@ -166,6 +165,10 @@ def parse_args() -> argparse.Namespace:
return args
def create_model(config):
return AutoRegressiveLM(config).to(dtype=torch.bfloat16)
def create_optimizer(model, **kwargs) -> optim.Optimizer:
return optim.AdamW(model.parameters(), fused=True, **kwargs)
@ -238,15 +241,6 @@ def train(
if window_size is None:
window_size = config.max_len
# Create model and load full checkpoint (state_dict + optimizer + scheduler + meta)
checkpoint = Checkpoint.load(param_path)
model = AutoRegressiveLM(config).to(dtype=torch.bfloat16)
model.load_state_dict(checkpoint.state_dict, strict=False)
# Strip state_dict to avoid pickling ~7GB through mp.spawn pipe
# (model weights already loaded into model above)
checkpoint.state_dict = {}
strategy_kwargs = {
"beta": dpo_beta,
"label_smoothing": label_smoothing,
@ -261,6 +255,7 @@ def train(
"broadcast_buffers": False,
}
model_fn = partial(create_model, config)
dataset = DatasetFactory.load(
train_type=train_type,
load_path=data_root_path,
@ -292,7 +287,7 @@ def train(
)
train_config = TrainConfig(
model=model,
model_fn=model_fn,
strategy=train_type,
dataset=dataset,
optimizer_fn=optimizer_fn,
@ -317,7 +312,7 @@ def train(
)
trainer = Trainer(train_config)
trainer.train(checkpoint=checkpoint)
trainer.train(resume_dir=param_path)
if __name__ == "__main__":

View File

@ -27,7 +27,7 @@ class TrainerDataset(Dataset):
def create_train_config(
model: torch.nn.Module,
model_fn,
dataset: Dataset,
test_dir: str,
device: str,
@ -43,7 +43,7 @@ def create_train_config(
"""Factory function to create common TrainConfig for tests.
Args:
model: The model to train
model_fn: Model factory (callable returning nn.Module)
dataset: Training dataset
test_dir: Checkpoint directory
device: Device type ("cuda" or "cpu")
@ -70,7 +70,7 @@ def create_train_config(
return TrainConfig(
strategy=strategy,
model=model,
model_fn=model_fn,
dataset=dataset,
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,

View File

@ -106,7 +106,7 @@ def test_gradient_checkpointing_trainer_integration(base_test_env, random_datase
)
train_config = TrainConfig(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,
@ -140,7 +140,7 @@ def test_callback_integration(base_test_env, random_dataset):
)
train_config = TrainConfig(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
strategy="seq",
dataset=random_dataset,
optimizer_fn=optimizer_fn,

View File

@ -4,7 +4,6 @@ import numpy as np
import torch
from astrai.config.train_config import TrainConfig
from astrai.serialization import Checkpoint
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer
@ -24,7 +23,7 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
strategy="seq",
optimizer_fn=optimizer_fn,
scheduler_fn=scheduler_fn,
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=early_stopping_dataset,
ckpt_dir=base_test_env["test_dir"],
log_dir=os.path.join(base_test_env["test_dir"], "logs"),
@ -39,17 +38,20 @@ def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
trainer = Trainer(train_config)
# Should handle early stopping gracefully
checkpoint = None
try:
checkpoint = trainer.train()
trainer.train()
except Exception:
# Handle any exceptions
pass
# Resume from latest checkpoint
load_dir = os.path.join(base_test_env["test_dir"], "epoch_0_iter_2")
checkpoint = Checkpoint.load(load_dir)
trainer.train(checkpoint)
trainer = Trainer(train_config)
trainer.train(resume_dir=load_dir)
# Verify checkpoint was saved at expected iteration
load_dir = os.path.join(base_test_env["test_dir"], "epoch_1_iter_10")
checkpoint = Checkpoint.load(load_dir)
assert checkpoint.iteration == 10
import json
with open(os.path.join(load_dir, "meta.json")) as f:
meta = json.load(f)
assert meta["iteration"] == 10

View File

@ -9,7 +9,7 @@ def test_different_batch_sizes(base_test_env, random_dataset, train_config_facto
for batch_per_device in batch_sizes:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -25,7 +25,7 @@ def test_gradient_accumulation(base_test_env, random_dataset, train_config_facto
for grad_accum_steps in grad_accum_steps_list:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],
@ -50,7 +50,7 @@ def test_memory_efficient_training(base_test_env, random_dataset, train_config_f
for config in small_batch_configs:
train_config = train_config_factory(
model=base_test_env["model"],
model_fn=lambda: base_test_env["model"],
dataset=random_dataset,
test_dir=base_test_env["test_dir"],
device=base_test_env["device"],