From a2512f8a5aa019fd8f8250dfb7339f7f9d3e5f04 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 13 Jun 2026 15:40:14 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20resume=5Fdir=20=E6=97=A0=E6=9D=83?= =?UTF-8?q?=E9=87=8D=E6=96=87=E4=BB=B6=E6=97=B6=E4=B8=8D=E5=BC=BA=E5=88=B6?= =?UTF-8?q?=E5=8A=A0=E8=BD=BD=EF=BC=8C=E6=94=AF=E6=8C=81=E4=BB=85=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Checkpoint.load_any 统一处理 meta.json / model.safetensors / 无文件三种情况 - train_context.py 调用简化为单一路径,移除 load_model_weights 直接依赖 --- astrai/serialization.py | 21 ++++++++++++++++++++- astrai/trainer/train_context.py | 22 ++++++++-------------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/astrai/serialization.py b/astrai/serialization.py index e73b51b..9537fe9 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -3,7 +3,7 @@ import json import time from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import safetensors.torch as st import torch @@ -180,3 +180,22 @@ class Checkpoint: extra=extra, config=config, ) + + @classmethod + def load_any(cls, save_dir: str, broadcast: bool = False) -> Optional["Checkpoint"]: + save_path = Path(save_dir) + meta_path = save_path / _META_FILE + weights_path = save_path / _WEIGHTS_FILE + + if meta_path.exists(): + return cls.load(save_dir, broadcast=broadcast) + + if weights_path.exists(): + state_dict = load_state_dict(weights_path, broadcast=broadcast) + config = {} + config_path = save_path / _CONFIG_FILE + if config_path.exists(): + config = load_json(config_path, broadcast) + return cls(state_dict=state_dict, config=config) + + return None diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 8993716..71b4000 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -12,7 +12,7 @@ from astrai.model.components.lora import inject_lora from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.protocols import OptimizerProtocol, SchedulerProtocol -from astrai.serialization import Checkpoint, load_json, load_model_weights +from astrai.serialization import Checkpoint, load_json from astrai.trainer.strategy import BaseStrategy, StrategyFactory @@ -83,21 +83,15 @@ class TrainContextBuilder: executor=executor, ) - if self._resume_dir is not None: - resume_path = Path(self._resume_dir) - if (resume_path / "meta.json").exists(): - checkpoint = Checkpoint.load(self._resume_dir) - state_dict = checkpoint.state_dict + if self._resume_dir: + checkpoint = Checkpoint.load_any(self._resume_dir) + if checkpoint is not None: + model.load_state_dict(checkpoint.state_dict, strict=False) if checkpoint.config: context.model_config = checkpoint.config - else: - checkpoint = None - state_dict = load_model_weights(self._resume_dir) - model.load_state_dict(state_dict, strict=False) - if checkpoint is not None: - context.epoch = cfg.start_epoch - context.iteration = cfg.start_batch - context.checkpoint = checkpoint + context.epoch = checkpoint.epoch or cfg.start_epoch + context.iteration = checkpoint.iteration or cfg.start_batch + context.checkpoint = checkpoint if cfg.lora is not None: inject_lora(