feat: Checkpoint 支持 extra 通用扩展数据,用户通过函数自定义保存/恢复优化器等状态

- serialization.py: Checkpoint 新增可选字段 extra: Optional[Dict[str, Any]](默认为空 dict),
  extra 非空时 save() 将其写入 extra.pt,load() 在 extra.pt 存在时自动恢复
- train_callback.py: CheckpointCallback 新增可选参数 save_extra_fn,
  用户传入 (context) -> dict 函数,其返回值作为 extra 随 checkpoint 一起保存
- train_context.py: TrainContextBuilder 新增可选参数 load_extra_fn,
  用户传入 (extra, context) -> None 函数,在 checkpoint 含非空 extra 时被调用以恢复状态
This commit is contained in:
ViperEkura 2026-05-09 15:50:38 +08:00
parent db99d8b254
commit ca4e6b907c
3 changed files with 28 additions and 4 deletions

View File

@ -1,7 +1,7 @@
import json import json
import os import os
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
import h5py import h5py
import safetensors.torch as st import safetensors.torch as st
@ -54,10 +54,12 @@ class Checkpoint:
state_dict: Dict[str, Any], state_dict: Dict[str, Any],
epoch: int = 0, epoch: int = 0,
iteration: int = 0, iteration: int = 0,
extra: Optional[Dict[str, Any]] = None,
): ):
self.state_dict = state_dict self.state_dict = state_dict
self.epoch = epoch self.epoch = epoch
self.iteration = iteration self.iteration = iteration
self.extra = extra or {}
def save( def save(
self, self,
@ -77,6 +79,8 @@ class Checkpoint:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors") st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
torch.save(self.extra, save_path / "extra.pt")
@classmethod @classmethod
def load( def load(
@ -99,8 +103,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors") state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
extra=extra,
) )

View File

@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback):
interval: int, interval: int,
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0) @only_on_rank(0)
@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict() else context.model.state_dict()
) )
extra = self.save_extra_fn(context) if self.save_extra_fn else None
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration state_dict=state_dict,
epoch=context.epoch,
iteration=context.iteration,
extra=extra,
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional, Self from typing import Callable, Optional, Self
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -32,9 +32,14 @@ class TrainContext:
class TrainContextBuilder: class TrainContextBuilder:
def __init__(self, config: TrainConfig): def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint self._checkpoint = checkpoint
@ -66,6 +71,9 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model) context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer) context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config cfg = self.config
sampler_offset = context.iteration * cfg.batch_size sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler( sampler = ResumableDistributedSampler(