From ca4e6b907cc634be345c6a8c80096e3d04b2ace5 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 9 May 2026 15:50:38 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Checkpoint=20=E6=94=AF=E6=8C=81=20extra?= =?UTF-8?q?=20=E9=80=9A=E7=94=A8=E6=89=A9=E5=B1=95=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=EF=BC=8C=E7=94=A8=E6=88=B7=E9=80=9A=E8=BF=87=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89=E4=BF=9D=E5=AD=98/=E6=81=A2?= =?UTF-8?q?=E5=A4=8D=E4=BC=98=E5=8C=96=E5=99=A8=E7=AD=89=E7=8A=B6=E6=80=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - serialization.py: Checkpoint 新增 extra: dict 字段, save() 写入 extra.pt,load() 自动恢复 - train_callback.py: CheckpointCallback 新增 save_extra_fn 参数,用户传入 (context) -> dict 决定保存哪些额外状态 - train_context.py: TrainContextBuilder 新增 load_extra_fn 参数,用户传入 (extra, context) 从 checkpoint 恢复状态 --- astrai/serialization.py | 12 +++++++++++- astrai/trainer/train_callback.py | 8 +++++++- astrai/trainer/train_context.py | 12 ++++++++++-- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/astrai/serialization.py b/astrai/serialization.py index d5a99c6..ba0aab4 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,7 +1,7 @@ import json import os from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import h5py import safetensors.torch as st @@ -54,10 +54,12 @@ class Checkpoint: state_dict: Dict[str, Any], epoch: int = 0, iteration: int = 0, + extra: Optional[Dict[str, Any]] = None, ): self.state_dict = state_dict self.epoch = epoch self.iteration = iteration + self.extra = extra or {} def save( self, @@ -77,6 +79,8 @@ class Checkpoint: json.dump(meta, f, indent=2) st.save_file(self.state_dict, save_path / "state_dict.safetensors") + if self.extra: + torch.save(self.extra, save_path / "extra.pt") @classmethod def load( @@ -99,8 +103,14 @@ class Checkpoint: state_dict = st.load_file(save_path / "state_dict.safetensors") + extra = None + extra_path = save_path / "extra.pt" + if extra_path.exists(): + extra = torch.load(extra_path, map_location="cpu", weights_only=False) + return cls( state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], + extra=extra, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6381b31..623cda6 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback): interval: int, weight_only: bool = False, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, + save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn + self.save_extra_fn = save_extra_fn self.last_ckpt_iter = 0 @only_on_rank(0) @@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback): else context.model.state_dict() ) + extra = self.save_extra_fn(context) if self.save_extra_fn else None context.checkpoint = Checkpoint( - state_dict=state_dict, epoch=context.epoch, iteration=context.iteration + state_dict=state_dict, + epoch=context.epoch, + iteration=context.iteration, + extra=extra, ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 04eeb27..95a798d 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Self +from typing import Callable, Optional, Self import torch.nn as nn from torch.optim import Optimizer @@ -32,9 +32,14 @@ class TrainContext: class TrainContextBuilder: - def __init__(self, config: TrainConfig): + def __init__( + self, + config: TrainConfig, + load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, + ): self.config = config self._checkpoint: Optional[Checkpoint] = None + self._load_extra_fn = load_extra_fn def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: self._checkpoint = checkpoint @@ -66,6 +71,9 @@ class TrainContextBuilder: context.optimizer = self.config.optimizer_fn(context.model) context.scheduler = self.config.scheduler_fn(context.optimizer) + if self._checkpoint and self._checkpoint.extra and self._load_extra_fn: + self._load_extra_fn(self._checkpoint.extra, context) + cfg = self.config sampler_offset = context.iteration * cfg.batch_size sampler = ResumableDistributedSampler(