fix : resume_dir 无权重文件时不强制加载,支持仅配置训练
- Checkpoint.load_any 统一处理 meta.json / model.safetensors / 无文件三种情况 - train_context.py 调用简化为单一路径,移除 load_model_weights 直接依赖
This commit is contained in:
parent
457e16ea3c
commit
a2512f8a5a
|
|
@ -3,7 +3,7 @@ import json
|
|||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -180,3 +180,22 @@ class Checkpoint:
|
|||
extra=extra,
|
||||
config=config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_any(cls, save_dir: str, broadcast: bool = False) -> Optional["Checkpoint"]:
|
||||
save_path = Path(save_dir)
|
||||
meta_path = save_path / _META_FILE
|
||||
weights_path = save_path / _WEIGHTS_FILE
|
||||
|
||||
if meta_path.exists():
|
||||
return cls.load(save_dir, broadcast=broadcast)
|
||||
|
||||
if weights_path.exists():
|
||||
state_dict = load_state_dict(weights_path, broadcast=broadcast)
|
||||
config = {}
|
||||
config_path = save_path / _CONFIG_FILE
|
||||
if config_path.exists():
|
||||
config = load_json(config_path, broadcast)
|
||||
return cls(state_dict=state_dict, config=config)
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from astrai.model.components.lora import inject_lora
|
|||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||
from astrai.serialization import Checkpoint, load_json
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
|
|
@ -83,20 +83,14 @@ class TrainContextBuilder:
|
|||
executor=executor,
|
||||
)
|
||||
|
||||
if self._resume_dir is not None:
|
||||
resume_path = Path(self._resume_dir)
|
||||
if (resume_path / "meta.json").exists():
|
||||
checkpoint = Checkpoint.load(self._resume_dir)
|
||||
state_dict = checkpoint.state_dict
|
||||
if self._resume_dir:
|
||||
checkpoint = Checkpoint.load_any(self._resume_dir)
|
||||
if checkpoint is not None:
|
||||
model.load_state_dict(checkpoint.state_dict, strict=False)
|
||||
if checkpoint.config:
|
||||
context.model_config = checkpoint.config
|
||||
else:
|
||||
checkpoint = None
|
||||
state_dict = load_model_weights(self._resume_dir)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if checkpoint is not None:
|
||||
context.epoch = cfg.start_epoch
|
||||
context.iteration = cfg.start_batch
|
||||
context.epoch = checkpoint.epoch or cfg.start_epoch
|
||||
context.iteration = checkpoint.iteration or cfg.start_batch
|
||||
context.checkpoint = checkpoint
|
||||
|
||||
if cfg.lora is not None:
|
||||
|
|
|
|||
Loading…
Reference in New Issue