From c424dfc29312e5122dff4adb162598fec46c40cf Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 20:17:49 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20checkpoint=20=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=20config.json?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Checkpoint.save 写入独立的 config.json(模型架构参数) - Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config - TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载 - BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig) - 删除 _get_meta/_get_config wrapper,直接使用 load_json --- astrai/config/base.py | 15 ++++++++++++--- astrai/serialization.py | 18 ++++++------------ astrai/trainer/train_callback.py | 2 +- astrai/trainer/train_context.py | 15 ++++++++++++++- 4 files changed, 33 insertions(+), 17 deletions(-) diff --git a/astrai/config/base.py b/astrai/config/base.py index 1d34295..b67507c 100644 --- a/astrai/config/base.py +++ b/astrai/config/base.py @@ -13,12 +13,21 @@ class BaseConfig: d[fld.name] = v elif v is None: d[fld.name] = None - elif isinstance(v, (dict, list)): + elif isinstance(v, (dict, list, tuple)): try: - json.dumps(v) - d[fld.name] = v + val = list(v) if isinstance(v, tuple) else v + json.dumps(val) + d[fld.name] = val except (TypeError, ValueError): pass + elif isinstance(v, BaseConfig): + d[fld.name] = v.to_dict() + elif hasattr(v, "__dataclass_fields__"): + sub = {} + for f in fields(v): + a = getattr(v, f.name) + sub[f.name] = list(a) if isinstance(a, tuple) else a + d[fld.name] = sub return d @classmethod diff --git a/astrai/serialization.py b/astrai/serialization.py index 2243d28..721780d 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -79,17 +79,6 @@ def load_model_weights(save_directory: str) -> dict: return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) -def _get_meta(save_path: Path) -> dict: - meta = {} - if get_rank() == 0: - meta = load_json(save_path / _META_FILE) - if dist.is_initialized(): - meta_list = [meta] - dist.broadcast_object_list(meta_list, src=0) - meta = meta_list[0] - return meta - - def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: if not broadcast or not dist.is_initialized(): return load_safetensors(save_path / _WEIGHTS_FILE) @@ -128,6 +117,7 @@ class Checkpoint: iteration: int = 0 extra: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict) + config: Dict[str, Any] = field(default_factory=dict) def save(self, save_dir: str): save_path = Path(save_dir) @@ -143,6 +133,7 @@ class Checkpoint: **self.meta, } save_json(meta, save_path / _META_FILE) + save_json(self.config, save_path / _CONFIG_FILE) save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) for key, value in self.extra.items(): save_torch(value, save_path / f"{key}.pt") @@ -151,8 +142,10 @@ class Checkpoint: def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": save_path = Path(save_dir) - meta = _get_meta(save_path) + meta = load_json(save_path / _META_FILE) state_dict = _load_state_dict(save_path, broadcast=broadcast) + config_path = save_path / _CONFIG_FILE + config = load_json(config_path) extra = {} for f in sorted(save_path.iterdir()): @@ -164,4 +157,5 @@ class Checkpoint: epoch=meta.get("epoch", 0), iteration=meta.get("iteration", 0), extra=extra, + config=config, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6ab65b5..31f2260 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -160,7 +160,7 @@ class CheckpointCallback(TrainCallback): epoch=context.epoch, iteration=context.iteration, extra=extra, - meta=context.config.to_dict(), + config=context.model_config, ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 9fe0ae4..879830b 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.protocols import OptimizerProtocol, SchedulerProtocol -from astrai.serialization import Checkpoint, load_model_weights +from astrai.serialization import Checkpoint, load_json, load_model_weights from astrai.trainer.strategy import BaseStrategy, StrategyFactory @@ -24,6 +24,7 @@ class TrainContext: scheduler: SchedulerProtocol = field(default=None) checkpoint: Checkpoint = field(default=None) config: TrainConfig = field(default=None) + model_config: dict = field(default_factory=dict) executor: BaseExecutor = field(default=None) epoch: int = field(default=0) @@ -62,11 +63,21 @@ class TrainContextBuilder: model = cfg.model_fn() model = model.to(device=device) + model_config = {} + if self._resume_dir: + config_path = Path(self._resume_dir) / "config.json" + if config_path.exists(): + model_config = load_json(config_path) + + if not model_config and hasattr(model, "config"): + model_config = model.config.to_dict() + context = TrainContext( model=model, world_size=get_world_size(), rank=get_rank(), config=cfg, + model_config=model_config, executor=executor, ) @@ -75,6 +86,8 @@ class TrainContextBuilder: if (resume_path / "meta.json").exists(): checkpoint = Checkpoint.load(self._resume_dir) state_dict = checkpoint.state_dict + if checkpoint.config: + context.model_config = checkpoint.config else: checkpoint = None state_dict = load_model_weights(self._resume_dir)