feat: Checkpoint 支持 extra 通用扩展数据,用户通过函数自定义保存/恢复优化器等状态

- serialization.py: Checkpoint 新增 extra: dict 字段,
  save() 写入 extra.pt,load() 自动恢复
- train_callback.py: CheckpointCallback 新增 save_extra_fn
  参数,用户传入 (context) -> dict 决定保存哪些额外状态
- train_context.py: TrainContextBuilder 新增 load_extra_fn
  参数,用户传入 (extra, context) 从 checkpoint 恢复状态
This commit is contained in:
ViperEkura 2026-05-09 15:50:38 +08:00
parent db99d8b254
commit ca4e6b907c
3 changed files with 28 additions and 4 deletions

View File

@ -1,7 +1,7 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import h5py
import safetensors.torch as st
@ -54,10 +54,12 @@ class Checkpoint:
state_dict: Dict[str, Any],
epoch: int = 0,
iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
def save(
self,
@ -77,6 +79,8 @@ class Checkpoint:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
torch.save(self.extra, save_path / "extra.pt")
@classmethod
def load(
@ -99,8 +103,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
extra=extra,
)

View File

@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback):
interval: int,
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn
self.last_ckpt_iter = 0
@only_on_rank(0)
@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict()
)
extra = self.save_extra_fn(context) if self.save_extra_fn else None
context.checkpoint = Checkpoint(
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
)
context.checkpoint.save(save_path)

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, Self
from typing import Callable, Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
@ -32,9 +32,14 @@ class TrainContext:
class TrainContextBuilder:
def __init__(self, config: TrainConfig):
def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
@ -66,6 +71,9 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(