refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复
- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
This commit is contained in:
parent
026d1fc33d
commit
7dea929788
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
|
@ -35,13 +36,15 @@ class Checkpoint:
|
|||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
torch.save(self.extra, save_path / "extra.pt")
|
||||
for key, value in self.extra.items():
|
||||
torch.save(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -64,14 +67,14 @@ class Checkpoint:
|
|||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
|
||||
extra = None
|
||||
extra_path = save_path / "extra.pt"
|
||||
if extra_path.exists():
|
||||
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra,
|
||||
extra=extra or None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback):
|
|||
Checkpoint callback for trainer.
|
||||
"""
|
||||
|
||||
extra_keys = ("optimizer", "scheduler")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str,
|
||||
|
|
@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback):
|
|||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn
|
||||
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback):
|
|||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||
extra = self.save_extra_fn(context)
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
|
|
@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback):
|
|||
context.checkpoint.save(save_path)
|
||||
self.last_ckpt_iter = context.iteration
|
||||
|
||||
def on_train_begin(self, context: TrainContext):
|
||||
if context.checkpoint and context.checkpoint.extra:
|
||||
self.load_extra_fn(context.checkpoint.extra, context)
|
||||
|
||||
def on_batch_end(self, context: TrainContext):
|
||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||
self._save_checkpoint(context)
|
||||
|
|
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
|
|||
def on_error(self, context: TrainContext):
|
||||
self._save_checkpoint(context)
|
||||
|
||||
@staticmethod
|
||||
def save_extra(context: TrainContext) -> dict:
|
||||
extra = {}
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
obj = getattr(context, name, None)
|
||||
if obj:
|
||||
extra[name] = obj.state_dict()
|
||||
return extra
|
||||
|
||||
@staticmethod
|
||||
def load_extra(extra: dict, context: TrainContext):
|
||||
for name in CheckpointCallback.extra_keys:
|
||||
if name in extra:
|
||||
getattr(context, name).load_state_dict(extra[name])
|
||||
|
||||
|
||||
@CallbackFactory.register("progress_bar")
|
||||
class ProgressBarCallback(TrainCallback):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Optional, Self
|
||||
from typing import Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -35,11 +35,9 @@ class TrainContextBuilder:
|
|||
def __init__(
|
||||
self,
|
||||
config: TrainConfig,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.config = config
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
self._load_extra_fn = load_extra_fn
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
|
|
@ -71,9 +69,6 @@ class TrainContextBuilder:
|
|||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||
self._load_extra_fn(self._checkpoint.extra, context)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_size
|
||||
sampler = ResumableDistributedSampler(
|
||||
|
|
|
|||
|
|
@ -35,6 +35,33 @@ def test_single_process():
|
|||
assert loaded_checkpoint.iteration == 30
|
||||
|
||||
|
||||
def test_checkpoint_with_extra():
|
||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
|
||||
extra = {
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": {"last_epoch": 5},
|
||||
}
|
||||
checkpoint = Checkpoint(
|
||||
state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
import os
|
||||
|
||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||
|
||||
loaded = Checkpoint.load(tmpdir)
|
||||
assert loaded.extra["scheduler"]["last_epoch"] == 5
|
||||
assert "state" in loaded.extra["optimizer"]
|
||||
|
||||
|
||||
def simple_training():
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
|
|
|
|||
Loading…
Reference in New Issue