feat: Checkpoint 支持 extra 通用扩展数据,用户通过函数自定义保存/恢复优化器等状态
- serialization.py: Checkpoint 新增 extra: dict 字段, save() 写入 extra.pt,load() 自动恢复 - train_callback.py: CheckpointCallback 新增 save_extra_fn 参数,用户传入 (context) -> dict 决定保存哪些额外状态 - train_context.py: TrainContextBuilder 新增 load_extra_fn 参数,用户传入 (extra, context) 从 checkpoint 恢复状态
This commit is contained in:
parent
db99d8b254
commit
ca4e6b907c
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import h5py
|
||||
import safetensors.torch as st
|
||||
|
|
@ -54,10 +54,12 @@ class Checkpoint:
|
|||
state_dict: Dict[str, Any],
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
|
|
@ -77,6 +79,8 @@ class Checkpoint:
|
|||
json.dump(meta, f, indent=2)
|
||||
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
torch.save(self.extra, save_path / "extra.pt")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
|
|
@ -99,8 +103,14 @@ class Checkpoint:
|
|||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
|
||||
extra = None
|
||||
extra_path = save_path / "extra.pt"
|
||||
if extra_path.exists():
|
||||
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback):
|
|||
interval: int,
|
||||
weight_only: bool = False,
|
||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||
):
|
||||
self.save_dir = save_dir
|
||||
self.interval = interval
|
||||
self.weight_only = weight_only
|
||||
self.state_dict_fn = state_dict_fn
|
||||
self.save_extra_fn = save_extra_fn
|
||||
self.last_ckpt_iter = 0
|
||||
|
||||
@only_on_rank(0)
|
||||
|
|
@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback):
|
|||
else context.model.state_dict()
|
||||
)
|
||||
|
||||
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||
context.checkpoint = Checkpoint(
|
||||
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
||||
state_dict=state_dict,
|
||||
epoch=context.epoch,
|
||||
iteration=context.iteration,
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
context.checkpoint.save(save_path)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Self
|
||||
from typing import Callable, Optional, Self
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -32,9 +32,14 @@ class TrainContext:
|
|||
|
||||
|
||||
class TrainContextBuilder:
|
||||
def __init__(self, config: TrainConfig):
|
||||
def __init__(
|
||||
self,
|
||||
config: TrainConfig,
|
||||
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||
):
|
||||
self.config = config
|
||||
self._checkpoint: Optional[Checkpoint] = None
|
||||
self._load_extra_fn = load_extra_fn
|
||||
|
||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||
self._checkpoint = checkpoint
|
||||
|
|
@ -66,6 +71,9 @@ class TrainContextBuilder:
|
|||
context.optimizer = self.config.optimizer_fn(context.model)
|
||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||
|
||||
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||
self._load_extra_fn(self._checkpoint.extra, context)
|
||||
|
||||
cfg = self.config
|
||||
sampler_offset = context.iteration * cfg.batch_size
|
||||
sampler = ResumableDistributedSampler(
|
||||
|
|
|
|||
Loading…
Reference in New Issue