refactor: checkpoint 按 HF 方式存独立 .pt 文件,callback 接管恢复
- Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt
- meta.json 新增 timestamp
- CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性
- on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn
This commit is contained in:
parent
026d1fc33d
commit
7dea929788
|
|
@ -1,4 +1,5 @@
|
||||||
import json
|
import json
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
@ -35,13 +36,15 @@ class Checkpoint:
|
||||||
meta = {
|
meta = {
|
||||||
"epoch": self.epoch,
|
"epoch": self.epoch,
|
||||||
"iteration": self.iteration,
|
"iteration": self.iteration,
|
||||||
|
"timestamp": time.time(),
|
||||||
}
|
}
|
||||||
with open(save_path / "meta.json", "w") as f:
|
with open(save_path / "meta.json", "w") as f:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
if self.extra:
|
if self.extra:
|
||||||
torch.save(self.extra, save_path / "extra.pt")
|
for key, value in self.extra.items():
|
||||||
|
torch.save(value, save_path / f"{key}.pt")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -64,14 +67,14 @@ class Checkpoint:
|
||||||
|
|
||||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
extra = None
|
extra = {}
|
||||||
extra_path = save_path / "extra.pt"
|
for f in save_path.iterdir():
|
||||||
if extra_path.exists():
|
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||||
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
extra=extra,
|
extra=extra or None,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback):
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
extra_keys = ("optimizer", "scheduler")
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
|
|
@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback):
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
|
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
self.save_extra_fn = save_extra_fn
|
self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra
|
||||||
|
self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
else context.model.state_dict()
|
else context.model.state_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
extra = self.save_extra_fn(context)
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
|
|
@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback):
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
self.last_ckpt_iter = context.iteration
|
self.last_ckpt_iter = context.iteration
|
||||||
|
|
||||||
|
def on_train_begin(self, context: TrainContext):
    """Hand the checkpoint's extra state to ``load_extra_fn`` at train start.

    Does nothing when the context carries no checkpoint, or when that
    checkpoint has no (or empty) ``extra`` payload.

    Args:
        context: Training context; its ``checkpoint.extra`` mapping is
            passed to ``self.load_extra_fn`` together with the context.
    """
    ckpt = context.checkpoint
    if not ckpt or not ckpt.extra:
        return
    self.load_extra_fn(ckpt.extra, context)
|
||||||
|
|
||||||
def on_batch_end(self, context: TrainContext):
|
def on_batch_end(self, context: TrainContext):
|
||||||
if context.iteration - self.last_ckpt_iter >= self.interval:
|
if context.iteration - self.last_ckpt_iter >= self.interval:
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback):
|
||||||
def on_error(self, context: TrainContext):
|
def on_error(self, context: TrainContext):
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
@staticmethod
def save_extra(context: TrainContext) -> dict:
    """Collect the state dicts of the objects named in ``extra_keys``.

    Args:
        context: Object on which each name in
            ``CheckpointCallback.extra_keys`` is looked up as an attribute.

    Returns:
        A dict mapping each name to the result of ``obj.state_dict()``.
        Names whose attribute is missing or ``None`` are omitted.
    """
    extra = {}
    for name in CheckpointCallback.extra_keys:
        obj = getattr(context, name, None)
        # Test for presence explicitly: a plain truthiness check would also
        # drop an object that is set but happens to evaluate as falsy
        # (e.g. one defining __bool__ or __len__).
        if obj is not None:
            extra[name] = obj.state_dict()
    return extra
|
||||||
|
|
||||||
|
@staticmethod
def load_extra(extra: dict, context: TrainContext):
    """Restore the objects named in ``extra_keys`` from a saved mapping.

    Args:
        extra: Mapping from name to a previously saved state dict. Names
            from ``CheckpointCallback.extra_keys`` that are absent here
            are skipped.
        context: Object whose attribute of the same name receives the
            state through ``load_state_dict``.
    """
    for key in CheckpointCallback.extra_keys:
        if key not in extra:
            continue
        target = getattr(context, key)
        target.load_state_dict(extra[key])
|
||||||
|
|
||||||
|
|
||||||
@CallbackFactory.register("progress_bar")
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -35,11 +35,9 @@ class TrainContextBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: TrainConfig,
|
config: TrainConfig,
|
||||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
self._load_extra_fn = load_extra_fn
|
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
|
|
@ -71,9 +69,6 @@ class TrainContextBuilder:
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
|
||||||
self._load_extra_fn(self._checkpoint.extra, context)
|
|
||||||
|
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_size
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,33 @@ def test_single_process():
|
||||||
assert loaded_checkpoint.iteration == 30
|
assert loaded_checkpoint.iteration == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_with_extra():
    """Round-trip a checkpoint whose extra entries are stored one .pt file per key."""
    import os

    net = torch.nn.Linear(10, 5)
    opt = AdamW(net.parameters(), lr=1e-3)
    opt.step()

    extra_state = {
        "optimizer": opt.state_dict(),
        "scheduler": {"last_epoch": 5},
    }
    ckpt = Checkpoint(
        state_dict=net.state_dict(),
        epoch=1,
        iteration=10,
        extra=extra_state,
    )

    with tempfile.TemporaryDirectory() as workdir:
        ckpt.save(workdir)

        # Every extra key is expected to land in a file of its own.
        for filename in ("optimizer.pt", "scheduler.pt"):
            assert os.path.exists(os.path.join(workdir, filename))

        # Loading must rebuild the extra mapping from those files.
        restored = Checkpoint.load(workdir)
        assert restored.extra["scheduler"]["last_epoch"] == 5
        assert "state" in restored.extra["optimizer"]
|
||||||
|
|
||||||
|
|
||||||
def simple_training():
|
def simple_training():
|
||||||
model = torch.nn.Linear(10, 5)
|
model = torch.nn.Linear(10, 5)
|
||||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue