feat : checkpoint 支持保存 config.json

- Checkpoint.save 写入独立的 config.json(模型架构参数)
- Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config
- TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载
- BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig)
- 删除 _get_meta/_get_config wrapper,直接使用 load_json
This commit is contained in:
ViperEkura 2026-05-28 20:17:49 +08:00
parent 3a28e52e98
commit c424dfc293
4 changed files with 33 additions and 17 deletions

View File

@ -13,12 +13,21 @@ class BaseConfig:
d[fld.name] = v d[fld.name] = v
elif v is None: elif v is None:
d[fld.name] = None d[fld.name] = None
elif isinstance(v, (dict, list)): elif isinstance(v, (dict, list, tuple)):
try: try:
json.dumps(v) val = list(v) if isinstance(v, tuple) else v
d[fld.name] = v json.dumps(val)
d[fld.name] = val
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
elif isinstance(v, BaseConfig):
d[fld.name] = v.to_dict()
elif hasattr(v, "__dataclass_fields__"):
sub = {}
for f in fields(v):
a = getattr(v, f.name)
sub[f.name] = list(a) if isinstance(a, tuple) else a
d[fld.name] = sub
return d return d
@classmethod @classmethod

View File

@ -79,17 +79,6 @@ def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
def _get_meta(save_path: Path) -> dict:
meta = {}
if get_rank() == 0:
meta = load_json(save_path / _META_FILE)
if dist.is_initialized():
meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0]
return meta
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE) return load_safetensors(save_path / _WEIGHTS_FILE)
@ -128,6 +117,7 @@ class Checkpoint:
iteration: int = 0 iteration: int = 0
extra: Dict[str, Any] = field(default_factory=dict) extra: Dict[str, Any] = field(default_factory=dict)
meta: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict)
config: Dict[str, Any] = field(default_factory=dict)
def save(self, save_dir: str): def save(self, save_dir: str):
save_path = Path(save_dir) save_path = Path(save_dir)
@ -143,6 +133,7 @@ class Checkpoint:
**self.meta, **self.meta,
} }
save_json(meta, save_path / _META_FILE) save_json(meta, save_path / _META_FILE)
save_json(self.config, save_path / _CONFIG_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items(): for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt") save_torch(value, save_path / f"{key}.pt")
@ -151,8 +142,10 @@ class Checkpoint:
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = _get_meta(save_path) meta = load_json(save_path / _META_FILE)
state_dict = _load_state_dict(save_path, broadcast=broadcast) state_dict = _load_state_dict(save_path, broadcast=broadcast)
config_path = save_path / _CONFIG_FILE
config = load_json(config_path)
extra = {} extra = {}
for f in sorted(save_path.iterdir()): for f in sorted(save_path.iterdir()):
@ -164,4 +157,5 @@ class Checkpoint:
epoch=meta.get("epoch", 0), epoch=meta.get("epoch", 0),
iteration=meta.get("iteration", 0), iteration=meta.get("iteration", 0),
extra=extra, extra=extra,
config=config,
) )

View File

@ -160,7 +160,7 @@ class CheckpointCallback(TrainCallback):
epoch=context.epoch, epoch=context.epoch,
iteration=context.iteration, iteration=context.iteration,
extra=extra, extra=extra,
meta=context.config.to_dict(), config=context.model_config,
) )
context.checkpoint.save(save_path) context.checkpoint.save(save_path)

View File

@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora
from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.executor import BaseExecutor, ExecutorFactory
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.protocols import OptimizerProtocol, SchedulerProtocol from astrai.protocols import OptimizerProtocol, SchedulerProtocol
from astrai.serialization import Checkpoint, load_model_weights from astrai.serialization import Checkpoint, load_json, load_model_weights
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory
@ -24,6 +24,7 @@ class TrainContext:
scheduler: SchedulerProtocol = field(default=None) scheduler: SchedulerProtocol = field(default=None)
checkpoint: Checkpoint = field(default=None) checkpoint: Checkpoint = field(default=None)
config: TrainConfig = field(default=None) config: TrainConfig = field(default=None)
model_config: dict = field(default_factory=dict)
executor: BaseExecutor = field(default=None) executor: BaseExecutor = field(default=None)
epoch: int = field(default=0) epoch: int = field(default=0)
@ -62,11 +63,21 @@ class TrainContextBuilder:
model = cfg.model_fn() model = cfg.model_fn()
model = model.to(device=device) model = model.to(device=device)
model_config = {}
if self._resume_dir:
config_path = Path(self._resume_dir) / "config.json"
if config_path.exists():
model_config = load_json(config_path)
if not model_config and hasattr(model, "config"):
model_config = model.config.to_dict()
context = TrainContext( context = TrainContext(
model=model, model=model,
world_size=get_world_size(), world_size=get_world_size(),
rank=get_rank(), rank=get_rank(),
config=cfg, config=cfg,
model_config=model_config,
executor=executor, executor=executor,
) )
@ -75,6 +86,8 @@ class TrainContextBuilder:
if (resume_path / "meta.json").exists(): if (resume_path / "meta.json").exists():
checkpoint = Checkpoint.load(self._resume_dir) checkpoint = Checkpoint.load(self._resume_dir)
state_dict = checkpoint.state_dict state_dict = checkpoint.state_dict
if checkpoint.config:
context.model_config = checkpoint.config
else: else:
checkpoint = None checkpoint = None
state_dict = load_model_weights(self._resume_dir) state_dict = load_model_weights(self._resume_dir)