refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复

- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
This commit is contained in:
ViperEkura 2026-05-16 18:21:19 +08:00
parent 026d1fc33d
commit 7dea929788
4 changed files with 62 additions and 14 deletions

View File

@ -1,4 +1,5 @@
import json import json
import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
@ -35,13 +36,15 @@ class Checkpoint:
meta = { meta = {
"epoch": self.epoch, "epoch": self.epoch,
"iteration": self.iteration, "iteration": self.iteration,
"timestamp": time.time(),
} }
with open(save_path / "meta.json", "w") as f: with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2) json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors") st.save_file(self.state_dict, save_path / "state_dict.safetensors")
if self.extra: if self.extra:
torch.save(self.extra, save_path / "extra.pt") for key, value in self.extra.items():
torch.save(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load( def load(
@ -64,14 +67,14 @@ class Checkpoint:
state_dict = st.load_file(save_path / "state_dict.safetensors") state_dict = st.load_file(save_path / "state_dict.safetensors")
extra = None extra = {}
extra_path = save_path / "extra.pt" for f in save_path.iterdir():
if extra_path.exists(): if f.suffix == ".pt" and f.stem not in ("meta",):
extra = torch.load(extra_path, map_location="cpu", weights_only=False) extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta["epoch"],
iteration=meta["iteration"], iteration=meta["iteration"],
extra=extra, extra=extra or None,
) )

View File

@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback):
Checkpoint callback for trainer. Checkpoint callback for trainer.
""" """
extra_keys = ("optimizer", "scheduler")
def __init__( def __init__(
self, self,
save_dir: str, save_dir: str,
@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback):
weight_only: bool = False, weight_only: bool = False,
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.save_dir = save_dir self.save_dir = save_dir
self.interval = interval self.interval = interval
self.weight_only = weight_only self.weight_only = weight_only
self.state_dict_fn = state_dict_fn self.state_dict_fn = state_dict_fn
self.save_extra_fn = save_extra_fn self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
self.last_ckpt_iter = 0 self.last_ckpt_iter = 0
@only_on_rank(0) @only_on_rank(0)
@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback):
else context.model.state_dict() else context.model.state_dict()
) )
extra = self.save_extra_fn(context) if self.save_extra_fn else None extra = self.save_extra_fn(context)
context.checkpoint = Checkpoint( context.checkpoint = Checkpoint(
state_dict=state_dict, state_dict=state_dict,
epoch=context.epoch, epoch=context.epoch,
@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback):
context.checkpoint.save(save_path) context.checkpoint.save(save_path)
self.last_ckpt_iter = context.iteration self.last_ckpt_iter = context.iteration
def on_train_begin(self, context: TrainContext):
    """Hand the checkpoint's extra payload to ``load_extra_fn`` at train start.

    Does nothing when the context carries no checkpoint, or when that
    checkpoint has no (or an empty) ``extra`` payload.

    Args:
        context: Training context; its ``checkpoint`` attribute is inspected
            and the context itself is passed on to ``load_extra_fn``.
    """
    checkpoint = context.checkpoint
    if not checkpoint or not checkpoint.extra:
        return
    self.load_extra_fn(checkpoint.extra, context)
def on_batch_end(self, context: TrainContext): def on_batch_end(self, context: TrainContext):
if context.iteration - self.last_ckpt_iter >= self.interval: if context.iteration - self.last_ckpt_iter >= self.interval:
self._save_checkpoint(context) self._save_checkpoint(context)
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
def on_error(self, context: TrainContext): def on_error(self, context: TrainContext):
self._save_checkpoint(context) self._save_checkpoint(context)
@staticmethod
def save_extra(context: TrainContext) -> dict:
    """Collect the state dicts of the components named in ``extra_keys``.

    Args:
        context: Training context; every name in
            ``CheckpointCallback.extra_keys`` is looked up on it as an
            attribute.

    Returns:
        Mapping from component name to the result of that component's
        ``state_dict()``. Components that are absent from the context or
        set to ``None`` are left out.
    """
    extra = {}
    for name in CheckpointCallback.extra_keys:
        obj = getattr(context, name, None)
        # Compare against None instead of relying on truthiness: a component
        # that happens to evaluate falsy (e.g. it defines __len__ or
        # __bool__) must still have its state saved.
        if obj is not None:
            extra[name] = obj.state_dict()
    return extra
@staticmethod
def load_extra(extra: dict, context: TrainContext):
    """Feed saved state back into the components named in ``extra_keys``.

    Args:
        extra: Mapping from component name to its saved state; names in
            ``CheckpointCallback.extra_keys`` that are not present are
            skipped.
        context: Training context whose attributes of the same names
            receive the state through ``load_state_dict``.
    """
    present = (key for key in CheckpointCallback.extra_keys if key in extra)
    for key in present:
        component = getattr(context, key)
        component.load_state_dict(extra[key])
@CallbackFactory.register("progress_bar") @CallbackFactory.register("progress_bar")
class ProgressBarCallback(TrainCallback): class ProgressBarCallback(TrainCallback):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Optional, Self from typing import Optional, Self
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -35,11 +35,9 @@ class TrainContextBuilder:
def __init__( def __init__(
self, self,
config: TrainConfig, config: TrainConfig,
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
): ):
self.config = config self.config = config
self._checkpoint: Optional[Checkpoint] = None self._checkpoint: Optional[Checkpoint] = None
self._load_extra_fn = load_extra_fn
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
self._checkpoint = checkpoint self._checkpoint = checkpoint
@ -71,9 +69,6 @@ class TrainContextBuilder:
context.optimizer = self.config.optimizer_fn(context.model) context.optimizer = self.config.optimizer_fn(context.model)
context.scheduler = self.config.scheduler_fn(context.optimizer) context.scheduler = self.config.scheduler_fn(context.optimizer)
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
self._load_extra_fn(self._checkpoint.extra, context)
cfg = self.config cfg = self.config
sampler_offset = context.iteration * cfg.batch_size sampler_offset = context.iteration * cfg.batch_size
sampler = ResumableDistributedSampler( sampler = ResumableDistributedSampler(

View File

@ -35,6 +35,33 @@ def test_single_process():
assert loaded_checkpoint.iteration == 30 assert loaded_checkpoint.iteration == 30
def test_checkpoint_with_extra():
    """Round-trip a checkpoint whose extra payload has two keys.

    Saving must produce one ``.pt`` file per extra key, and loading must
    bring both entries back under ``extra``.
    """
    import os

    net = torch.nn.Linear(10, 5)
    opt = AdamW(net.parameters(), lr=1e-3)
    opt.step()

    payload = {
        "optimizer": opt.state_dict(),
        "scheduler": {"last_epoch": 5},
    }
    original = Checkpoint(
        state_dict=net.state_dict(),
        epoch=1,
        iteration=10,
        extra=payload,
    )

    with tempfile.TemporaryDirectory() as workdir:
        original.save(workdir)

        # One file per extra key is expected next to the checkpoint data.
        for filename in ("optimizer.pt", "scheduler.pt"):
            assert os.path.exists(os.path.join(workdir, filename))

        restored_extra = Checkpoint.load(workdir).extra
        assert restored_extra["scheduler"]["last_epoch"] == 5
        assert "state" in restored_extra["optimizer"]
def simple_training(): def simple_training():
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)