fix: resume_dir 无权重文件时不强制加载,支持仅配置训练

- Checkpoint.load_any 统一处理 meta.json / model.safetensors / 无文件三种情况
- train_context.py 调用简化为单一路径,移除 load_model_weights 直接依赖
This commit is contained in:
ViperEkura 2026-06-13 15:40:14 +08:00
parent 457e16ea3c
commit a2512f8a5a
2 changed files with 28 additions and 15 deletions

View File

@ -3,7 +3,7 @@ import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Union from typing import Any, Dict, Optional, Union
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -180,3 +180,22 @@ class Checkpoint:
extra=extra, extra=extra,
config=config, config=config,
) )
@classmethod
def load_any(cls, save_dir: str, broadcast: bool = False) -> Optional["Checkpoint"]:
    """Load whatever checkpoint material is present in ``save_dir``.

    Three situations are handled, checked in this order:

    1. The ``_META_FILE`` entry exists: the directory is loaded through
       ``cls.load`` and its result is returned unchanged.
    2. Only the ``_WEIGHTS_FILE`` entry exists: a checkpoint is built from
       the state dict read by ``load_state_dict``, plus the contents of
       ``_CONFIG_FILE`` when that file is present (an empty dict otherwise).
    3. Neither file exists: nothing is loaded.

    Args:
        save_dir: Directory to inspect.
        broadcast: Forwarded to ``cls.load``, ``load_state_dict`` and
            ``load_json``.

    Returns:
        The loaded checkpoint, or ``None`` when the directory contains
        neither the metadata file nor the weights file.
    """
    root = Path(save_dir)

    # The metadata file takes priority: defer to the regular loader.
    if (root / _META_FILE).exists():
        return cls.load(save_dir, broadcast=broadcast)

    weights_file = root / _WEIGHTS_FILE
    if not weights_file.exists():
        # No metadata and no weights: the caller decides how to proceed.
        return None

    weights = load_state_dict(weights_file, broadcast=broadcast)

    # The config file is optional next to bare weights.
    config_file = root / _CONFIG_FILE
    loaded_config = load_json(config_file, broadcast) if config_file.exists() else {}
    return cls(state_dict=weights, config=loaded_config)

View File

@ -12,7 +12,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_json, load_model_weights from astrai.serialization import Checkpoint, load_json
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -83,20 +83,14 @@ class TrainContextBuilder:
executor=executor, executor=executor,
) )
if self._resume_dir is not None: if self._resume_dir:
resume_path = Path(self._resume_dir) checkpoint = Checkpoint.load_any(self._resume_dir)
if (resume_path / "meta.json").exists(): if checkpoint is not None:
checkpoint = Checkpoint.load(self._resume_dir) model.load_state_dict(checkpoint.state_dict, strict=False)
state_dict = checkpoint.state_dict
if checkpoint.config: if checkpoint.config:
context.model_config = checkpoint.config context.model_config = checkpoint.config
else: context.epoch = checkpoint.epoch or cfg.start_epoch
checkpoint = None context.iteration = checkpoint.iteration or cfg.start_batch
state_dict = load_model_weights(self._resume_dir)
model.load_state_dict(state_dict, strict=False)
if checkpoint is not None:
context.epoch = cfg.start_epoch
context.iteration = cfg.start_batch
context.checkpoint = checkpoint context.checkpoint = checkpoint
if cfg.lora is not None: if cfg.lora is not None: