refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复

- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
This commit is contained in:
ViperEkura 2026-05-16 18:21:19 +08:00
parent 026d1fc33d
commit 7dea929788
4 changed files with 62 additions and 14 deletions

View File

@ -1,4 +1,5 @@
import json
import time
from pathlib import Path
from typing import Any, Dict, Optional
@ -35,13 +36,15 @@ class Checkpoint:
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.time(),
}
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra:
torch.save(self.extra, save_path / "extra.pt")
for key, value in self.extra.items():
torch.save(value, save_path / f"{key}.pt")
@classmethod
def load(
@ -64,14 +67,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None
extra_path = save_path / "extra.pt"
if extra_path.exists():
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
extra = {}
for f in save_path.iterdir():
if f.suffix == ".pt" and f.stem not in ("meta",):
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
return cls(
state_dict=state_dict,
epoch=meta["epoch"],
iteration=meta["iteration"],
extra=extra,
extra=extra or None,
)

View File

@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer.
"""
extra_keys = ("optimizer", "scheduler")
def __init__(
self,
save_dir: str,
@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.save_dir = save_dir
self.interval = interval
self.weight_only = weight_only
self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0
@only_on_rank(0)
@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict()
)
extra = self.save_extra_fn(context) if self.save_extra_fn else None
extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint(
state_dict=state_dict,
epoch=context.epoch,
@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback):
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
if context.checkpoint and context.checkpoint.extra:
self.load_extra_fn(context.checkpoint.extra, context)
def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context)
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext):
self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
extra = {}
for name in CheckpointCallback.extra_keys:
obj = getattr(context, name, None)
if obj:
extra[name] = obj.state_dict()
return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
for name in CheckpointCallback.extra_keys:
if name in extra:
getattr(context, name).load_state_dict(extra[name])
@CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, Optional, Self
from typing import Optional, Self
import torch.nn as nn
from torch.optim import Optimizer
@ -35,11 +35,9 @@ class TrainContextBuilder:
def __init__(
self,
config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
):
self.config = config
self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint
@ -71,9 +69,6 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config
sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler(

View File

@ -35,6 +35,33 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step()
extra = {
"optimizer": optimizer.state_dict(),
"scheduler": {"last_epoch": 5},
}
checkpoint = Checkpoint(
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
)
with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
loaded = Checkpoint.load(tmpdir)
assert loaded.extra["scheduler"]["last_epoch"] == 5
assert "state" in loaded.extra["optimizer"]
def simple_training():
model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3)