diff --git a/astrai/serialization.py b/astrai/serialization.py index d5a99c6..ba0aab4 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,7 +1,7 @@ import json import os from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import h5py import safetensors.torch as st @@ -54,10 +54,12 @@ class Checkpoint: state_dict: Dict[str, Any], epoch: int = 0, iteration: int = 0, + extra: Optional[Dict[str, Any]] = None, ): self.state_dict = state_dict self.epoch = epoch self.iteration = iteration + self.extra = extra or {} def save( self, @@ -77,6 +79,8 @@ class Checkpoint: json.dump(meta, f, indent=2) st.save_file(self.state_dict, save_path / "state_dict.safetensors") + if self.extra: + torch.save(self.extra, save_path / "extra.pt") @classmethod def load( @@ -99,8 +103,14 @@ class Checkpoint: state_dict = st.load_file(save_path / "state_dict.safetensors") + extra = None + extra_path = save_path / "extra.pt" + if extra_path.exists(): + extra = torch.load(extra_path, map_location="cpu", weights_only=False) + return cls( state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], + extra=extra, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6381b31..623cda6 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -121,11 +121,13 @@ class CheckpointCallback(TrainCallback): interval: int, weight_only: bool = False, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, + save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn + self.save_extra_fn = save_extra_fn self.last_ckpt_iter = 0 @only_on_rank(0) @@ -139,8 +141,12 @@ class CheckpointCallback(TrainCallback): else context.model.state_dict() ) + extra = self.save_extra_fn(context) if self.save_extra_fn else None context.checkpoint = Checkpoint( - state_dict=state_dict, epoch=context.epoch, iteration=context.iteration + state_dict=state_dict, + epoch=context.epoch, + iteration=context.iteration, + extra=extra, ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 04eeb27..95a798d 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Self +from typing import Callable, Optional, Self import torch.nn as nn from torch.optim import Optimizer @@ -32,9 +32,14 @@ class TrainContext: class TrainContextBuilder: - def __init__(self, config: TrainConfig): + def __init__( + self, + config: TrainConfig, + load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, + ): self.config = config self._checkpoint: Optional[Checkpoint] = None + self._load_extra_fn = load_extra_fn def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: self._checkpoint = checkpoint @@ -66,6 +71,9 @@ class TrainContextBuilder: context.optimizer = self.config.optimizer_fn(context.model) context.scheduler = self.config.scheduler_fn(context.optimizer) + if self._checkpoint and self._checkpoint.extra and self._load_extra_fn: + self._load_extra_fn(self._checkpoint.extra, context) + cfg = self.config sampler_offset = context.iteration * cfg.batch_size sampler = ResumableDistributedSampler(