diff --git a/astrai/serialization.py b/astrai/serialization.py index 87b4272..dfd9cf1 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,4 +1,5 @@ import json +import time from pathlib import Path from typing import Any, Dict, Optional @@ -35,13 +36,15 @@ class Checkpoint: meta = { "epoch": self.epoch, "iteration": self.iteration, + "timestamp": time.time(), } with open(save_path / "meta.json", "w") as f: json.dump(meta, f, indent=2) st.save_file(self.state_dict, save_path / "state_dict.safetensors") if self.extra: - torch.save(self.extra, save_path / "extra.pt") + for key, value in self.extra.items(): + torch.save(value, save_path / f"{key}.pt") @classmethod def load( @@ -64,14 +67,14 @@ class Checkpoint: state_dict = st.load_file(save_path / "state_dict.safetensors") - extra = None - extra_path = save_path / "extra.pt" - if extra_path.exists(): - extra = torch.load(extra_path, map_location="cpu", weights_only=False) + extra = {} + for f in save_path.iterdir(): + if f.suffix == ".pt" and f.stem not in ("meta",): + extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False) return cls( state_dict=state_dict, epoch=meta["epoch"], iteration=meta["iteration"], - extra=extra, + extra=extra or None, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 034ca93..07ab4eb 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -90,6 +90,8 @@ class CheckpointCallback(TrainCallback): Checkpoint callback for trainer. """ + extra_keys = ("optimizer", "scheduler") + def __init__( self, save_dir: str, @@ -97,12 +99,14 @@ class CheckpointCallback(TrainCallback): weight_only: bool = False, state_dict_fn: Optional[Callable[[nn.Module], dict]] = None, save_extra_fn: Optional[Callable[["TrainContext"], dict]] = None, + load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, ): self.save_dir = save_dir self.interval = interval self.weight_only = weight_only self.state_dict_fn = state_dict_fn - self.save_extra_fn = save_extra_fn + self.save_extra_fn = save_extra_fn or CheckpointCallback.save_extra + self.load_extra_fn = load_extra_fn or CheckpointCallback.load_extra self.last_ckpt_iter = 0 @only_on_rank(0) @@ -116,7 +120,7 @@ class CheckpointCallback(TrainCallback): else context.model.state_dict() ) - extra = self.save_extra_fn(context) if self.save_extra_fn else None + extra = self.save_extra_fn(context) context.checkpoint = Checkpoint( state_dict=state_dict, epoch=context.epoch, @@ -127,6 +131,10 @@ class CheckpointCallback(TrainCallback): context.checkpoint.save(save_path) self.last_ckpt_iter = context.iteration + def on_train_begin(self, context: TrainContext): + if context.checkpoint and context.checkpoint.extra: + self.load_extra_fn(context.checkpoint.extra, context) + def on_batch_end(self, context: TrainContext): if context.iteration - self.last_ckpt_iter >= self.interval: self._save_checkpoint(context) @@ -138,6 +146,21 @@ class CheckpointCallback(TrainCallback): def on_error(self, context: TrainContext): self._save_checkpoint(context) + @staticmethod + def save_extra(context: TrainContext) -> dict: + extra = {} + for name in CheckpointCallback.extra_keys: + obj = getattr(context, name, None) + if obj: + extra[name] = obj.state_dict() + return extra + + @staticmethod + def load_extra(extra: dict, context: TrainContext): + for name in CheckpointCallback.extra_keys: + if name in extra: + getattr(context, name).load_state_dict(extra[name]) + @CallbackFactory.register("progress_bar") class ProgressBarCallback(TrainCallback): diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 95a798d..a81d23a 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, Optional, Self +from typing import Optional, Self import torch.nn as nn from torch.optim import Optimizer @@ -35,11 +35,9 @@ class TrainContextBuilder: def __init__( self, config: TrainConfig, - load_extra_fn: Optional[Callable[[dict, "TrainContext"], None]] = None, ): self.config = config self._checkpoint: Optional[Checkpoint] = None - self._load_extra_fn = load_extra_fn def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self: self._checkpoint = checkpoint @@ -71,9 +69,6 @@ class TrainContextBuilder: context.optimizer = self.config.optimizer_fn(context.model) context.scheduler = self.config.scheduler_fn(context.optimizer) - if self._checkpoint and self._checkpoint.extra and self._load_extra_fn: - self._load_extra_fn(self._checkpoint.extra, context) - cfg = self.config sampler_offset = context.iteration * cfg.batch_size sampler = ResumableDistributedSampler( diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index ce68447..f3556db 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -35,6 +35,33 @@ def test_single_process(): assert loaded_checkpoint.iteration == 30 +def test_checkpoint_with_extra(): + """Verify extra keys are saved as individual .pt files and loaded back.""" + model = torch.nn.Linear(10, 5) + optimizer = AdamW(model.parameters(), lr=1e-3) + optimizer.step() + + extra = { + "optimizer": optimizer.state_dict(), + "scheduler": {"last_epoch": 5}, + } + checkpoint = Checkpoint( + state_dict=model.state_dict(), epoch=1, iteration=10, extra=extra + ) + + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint.save(tmpdir) + + import os + + assert os.path.exists(os.path.join(tmpdir, "optimizer.pt")) + assert os.path.exists(os.path.join(tmpdir, "scheduler.pt")) + + loaded = Checkpoint.load(tmpdir) + assert loaded.extra["scheduler"]["last_epoch"] == 5 + assert "state" in loaded.extra["optimizer"] + + def simple_training(): model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3)