diff --git a/astrai/serialization.py b/astrai/serialization.py
index e73b51b..9537fe9 100644
--- a/astrai/serialization.py
+++ b/astrai/serialization.py
@@ -3,7 +3,7 @@ import json
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, Union
+from typing import Any, Dict, Optional, Union
 
 import safetensors.torch as st
 import torch
@@ -180,3 +180,33 @@ class Checkpoint:
             extra=extra,
             config=config,
         )
+
+    @classmethod
+    def load_any(cls, save_dir: str, broadcast: bool = False) -> Optional["Checkpoint"]:
+        """Load a checkpoint from ``save_dir`` in whichever layout is present.
+
+        If the metadata file exists, the directory is loaded through
+        ``load``. Otherwise, if only the weights file exists, a checkpoint
+        is built from the weights alone, plus the config file when there
+        is one. ``broadcast`` is forwarded to the underlying loaders.
+
+        Returns:
+            The checkpoint, or ``None`` when neither file exists. Callers
+            must handle ``None`` explicitly.
+        """
+        save_path = Path(save_dir)
+        meta_path = save_path / _META_FILE
+        weights_path = save_path / _WEIGHTS_FILE
+
+        if meta_path.exists():
+            return cls.load(save_dir, broadcast=broadcast)
+
+        if weights_path.exists():
+            state_dict = load_state_dict(weights_path, broadcast=broadcast)
+            config = {}
+            config_path = save_path / _CONFIG_FILE
+            if config_path.exists():
+                config = load_json(config_path, broadcast)
+            return cls(state_dict=state_dict, config=config)
+
+        return None
diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py
index 8993716..71b4000 100644
--- a/astrai/trainer/train_context.py
+++ b/astrai/trainer/train_context.py
@@ -12,7 +12,7 @@ from astrai.model.components.lora import inject_lora
 from astrai.parallel.executor import BaseExecutor, ExecutorFactory
 from astrai.parallel.setup import get_current_device, get_rank, get_world_size
 from astrai.protocols import OptimizerProtocol, SchedulerProtocol
-from astrai.serialization import Checkpoint, load_json, load_model_weights
+from astrai.serialization import Checkpoint, load_json
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory
 
 
@@ -83,21 +83,20 @@ class TrainContextBuilder:
             executor=executor,
         )
 
-        if self._resume_dir is not None:
-            resume_path = Path(self._resume_dir)
-            if (resume_path / "meta.json").exists():
-                checkpoint = Checkpoint.load(self._resume_dir)
-                state_dict = checkpoint.state_dict
-                if checkpoint.config:
-                    context.model_config = checkpoint.config
-            else:
-                checkpoint = None
-                state_dict = load_model_weights(self._resume_dir)
-            model.load_state_dict(state_dict, strict=False)
-            if checkpoint is not None:
-                context.epoch = cfg.start_epoch
-                context.iteration = cfg.start_batch
-                context.checkpoint = checkpoint
+        if self._resume_dir:
+            checkpoint = Checkpoint.load_any(self._resume_dir)
+            if checkpoint is None:
+                # An explicit resume request must not silently fall back to
+                # training from scratch when the directory holds no checkpoint.
+                raise FileNotFoundError(
+                    f"No checkpoint or weights found in {self._resume_dir!r}"
+                )
+            model.load_state_dict(checkpoint.state_dict, strict=False)
+            if checkpoint.config:
+                context.model_config = checkpoint.config
+            context.epoch = checkpoint.epoch or cfg.start_epoch
+            context.iteration = checkpoint.iteration or cfg.start_batch
+            context.checkpoint = checkpoint
 
         if cfg.lora is not None:
             inject_lora(