feat: Checkpoint 支持 extra 通用扩展数据,用户通过函数自定义保存/恢复优化器等状态
- serialization.py: Checkpoint 新增 extra: dict 字段, save() 写入 extra.pt,load() 自动恢复
- train_callback.py: CheckpointCallback 新增 save_extra_fn 参数,用户传入 (context) -> dict 决定保存哪些额外状态
- train_context.py: TrainContextBuilder 新增 load_extra_fn 参数,用户传入 (extra, context) 从 checkpoint 恢复状态
This commit is contained in:
parent
db99d8b254
commit
ca4e6b907c
|
|
@ -1,7 +1,7 @@
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
|
|
@ -54,10 +54,12 @@ class Checkpoint:
|
||||||
state_dict: Dict[str, Any],
|
state_dict: Dict[str, Any],
|
||||||
epoch: int = 0,
|
epoch: int = 0,
|
||||||
iteration: int = 0,
|
iteration: int = 0,
|
||||||
|
extra: Optional[Dict[str, Any]] = None,
|
||||||
):
|
):
|
||||||
self.state_dict = state_dict
|
self.state_dict = state_dict
|
||||||
self.epoch = epoch
|
self.epoch = epoch
|
||||||
self.iteration = iteration
|
self.iteration = iteration
|
||||||
|
self.extra = extra or {}
|
||||||
|
|
||||||
def save(
|
def save(
|
||||||
self,
|
self,
|
||||||
|
|
@ -77,6 +79,8 @@ class Checkpoint:
|
||||||
json.dump(meta, f, indent=2)
|
json.dump(meta, f, indent=2)
|
||||||
|
|
||||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||||
|
if self.extra:
|
||||||
|
torch.save(self.extra, save_path / "extra.pt")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(
|
def load(
|
||||||
|
|
@ -99,8 +103,14 @@ class Checkpoint:
|
||||||
|
|
||||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||||
|
|
||||||
|
extra = None
|
||||||
|
extra_path = save_path / "extra.pt"
|
||||||
|
if extra_path.exists():
|
||||||
|
extra = torch.load(extra_path, map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
epoch=meta["epoch"],
|
epoch=meta["epoch"],
|
||||||
iteration=meta["iteration"],
|
iteration=meta["iteration"],
|
||||||
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback):
|
||||||
interval: int,
|
interval: int,
|
||||||
weight_only: bool = False,
|
weight_only: bool = False,
|
||||||
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
state_dict_fn: Optional[Callable[[nn.Module], dict]] = None,
|
||||||
|
save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None,
|
||||||
):
|
):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.weight_only = weight_only
|
self.weight_only = weight_only
|
||||||
self.state_dict_fn = state_dict_fn
|
self.state_dict_fn = state_dict_fn
|
||||||
|
self.save_extra_fn = save_extra_fn
|
||||||
self.last_ckpt_iter = 0
|
self.last_ckpt_iter = 0
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback):
|
||||||
else context.model.state_dict()
|
else context.model.state_dict()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
extra = self.save_extra_fn(context) if self.save_extra_fn else None
|
||||||
context.checkpoint = Checkpoint(
|
context.checkpoint = Checkpoint(
|
||||||
state_dict=state_dict, epoch=context.epoch, iteration=context.iteration
|
state_dict=state_dict,
|
||||||
|
epoch=context.epoch,
|
||||||
|
iteration=context.iteration,
|
||||||
|
extra=extra,
|
||||||
)
|
)
|
||||||
|
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self
|
from typing import Callable, Optional, Self
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -32,9 +32,14 @@ class TrainContext:
|
||||||
|
|
||||||
|
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, config: TrainConfig):
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: TrainConfig,
|
||||||
|
load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None,
|
||||||
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self._checkpoint: Optional[Checkpoint] = None
|
self._checkpoint: Optional[Checkpoint] = None
|
||||||
|
self._load_extra_fn = load_extra_fn
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
self._checkpoint = checkpoint
|
self._checkpoint = checkpoint
|
||||||
|
|
@ -66,6 +71,9 @@ class TrainContextBuilder:
|
||||||
context.optimizer = self.config.optimizer_fn(context.model)
|
context.optimizer = self.config.optimizer_fn(context.model)
|
||||||
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
context.scheduler = self.config.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
if self._checkpoint and self._checkpoint.extra and self._load_extra_fn:
|
||||||
|
self._load_extra_fn(self._checkpoint.extra, context)
|
||||||
|
|
||||||
cfg = self.config
|
cfg = self.config
|
||||||
sampler_offset = context.iteration * cfg.batch_size
|
sampler_offset = context.iteration * cfg.batch_size
|
||||||
sampler = ResumableDistributedSampler(
|
sampler = ResumableDistributedSampler(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue