From 7dea9297881d9b2c97db4d204c6f6d61770019a3 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 18:21:19 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20checkpoint=20=E6=8C=89=20HF=20?= =?UTF-8?q?=E6=96=B9=E5=BC=8F=E5=AD=98=E7=8B=AC=E7=AB=8B=20.pt=20=E6=96=87?= =?UTF-8?q?=E4=BB=B6=EF=BC=8Ccallback=20=E6=8E=A5=E7=AE=A1=E6=81=A2?= =?UTF-8?q?=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Checkpoint.save/load: extra 逐 key 写为 {key}.pt 而非单个 extra.pt - meta.json 新增 timestamp - CheckpointCallback: save_extra/load_extra 静态方法 + extra_keys 类属性 - on_train_begin 接管 optimizer/scheduler 恢复,TrainContextBuilder 不再传 load_extra_fn --- astrai/serialization.py | 15 +++++++++------ astrai/trainer/train_callback.py | 27 +++++++++++++++++++++++++-- astrai/trainer/train_context.py | 7 +------ tests/data/test_checkpoint.py | 27 +++++++++++++++++++++++++++ 4 files changed, 62 insertions(+), 14 deletions(-) diff --git a/astrai/serialization.py b/astrai/serialization.py index 87b4272..dfd9cf1 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,4 +1,5 @@ import json +import time from pathlib import Path from typing import Any, Dict, Optional @@ -35,13 +36,15 @@ class Checkpoint: meta = { "epoch": self.epoch, "iteration": self.iteration, + "timestamp": time.time(), } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) st.save_file(self.state_dict, save_path / "state_dict.safetensors") if self.extra: - torch.save(self.extra, save_path / "extra.pt") + for key, value in self.extra.items(): + torch.save(value, save_path / f"{key}.pt") @classmethod def load( @@ -64,14 +67,14 @@ class Checkpoint: state_dict = st.load_file(save_path / "state_dict.safetensors") - extra = None - extra_path = save_path / "extra.pt" - if extra_path.exists(): - extra = torch.load(extra_path, map_location="cpu", weights_only=False) + extra = {} + for f in save_path.iterdir(): + if f.suffix == ".pt" and f.stem not in ("meta",): + extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False) return cls( state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], - extra=extra, + extra=extra or None, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 034ca93..07ab4eb 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback): Checkpoint callback for trainer. """ + extra_keys = ("optimizer", "scheduler") + def __init__( self, save_dir: str, @@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback): weight_only: bool = False, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, + load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn - self.save_extra_fn = save_extra_fn + self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra + self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra self.last_ckpt_iter = 0 @only_on_rank(0) @@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback): else context.model.state_dict() ) - extra = self.save_extra_fn(context) if self.save_extra_fn else None + extra = self.save_extra_fn(context) context.checkpoint = Checkpoint( state_dict=state_dict, epoch=context.epoch, @@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback): context.checkpoint.save(save_path) self.last_ckpt_iter = context.iteration + def on_train_begin(self, context: TrainContext): + if context.checkpoint and context.checkpoint.extra: + self.load_extra_fn(context.checkpoint.extra, context) + def on_batch_end(self, context: TrainContext): if context.iteration - self.last_ckpt_iter >= self.interval: self._save_checkpoint(context) @@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback): def on_error(self, context: TrainContext): self._save_checkpoint(context) + @staticmethod + def save_extra(context: TrainContext) -> dict: + extra = {} + for name in CheckpointCallback.extra_keys: + obj = getattr(context, name, None) + if obj: + extra[name] = obj.state_dict() + return extra + + @staticmethod + def load_extra(extra: dict, context: TrainContext): + for name in CheckpointCallback.extra_keys: + if name in extra: + getattr(context, name).load_state_dict(extra[name]) + @CallbackFactory.register("progress_bar") class ProgressBarCallback(TrainCallback): diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 95a798d..a81d23a 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, Optional, Self +from typing import Optional, Self import torch.nn as nn from torch.optim import Optimizer @@ -35,11 +35,9 @@ class TrainContextBuilder: def __init__( self, config: TrainConfig, - load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, ): self.config = config self._checkpoint: Optional[Checkpoint] = None - self._load_extra_fn = load_extra_fn def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: self._checkpoint = checkpoint @@ -71,9 +69,6 @@ class TrainContextBuilder: context.optimizer = self.config.optimizer_fn(context.model) context.scheduler = self.config.scheduler_fn(context.optimizer) - if self._checkpoint and self._checkpoint.extra and self._load_extra_fn: - self._load_extra_fn(self._checkpoint.extra, context) - cfg = self.config sampler_offset = context.iteration * cfg.batch_size sampler = ResumableDistributedSampler( diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index ce68447..f3556db 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -35,6 +35,33 @@ def test_single_process(): assert loaded_checkpoint.iteration == 30 +def test_checkpoint_with_extra(): + """Verify extra keys are saved as individual .pt files and loaded back.""" + model = torch.nn.Linear(10, 5) + optimizer = AdamW(model.parameters(), lr=1e-3) + optimizer.step() + + extra = { + "optimizer": optimizer.state_dict(), + "scheduler": {"last_epoch": 5}, + } + checkpoint = Checkpoint( + state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint.save(tmpdir) + + import os + + assert os.path.exists(os.path.join(tmpdir, "optimizer.pt")) + assert os.path.exists(os.path.join(tmpdir, "scheduler.pt")) + + loaded = Checkpoint.load(tmpdir) + assert loaded.extra["scheduler"]["last_epoch"] == 5 + assert "state" in loaded.extra["optimizer"] + + def simple_training(): model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3)