feat : checkpoint 支持保存 config.json
- Checkpoint.save 写入独立的 config.json(模型架构参数) - Checkpoint.load 读取 config.json,恢复时覆盖 context.model_config - TrainContext 新增 model_config 字段,builder 从 resume_dir/config.json 加载 - BaseConfig.to_dict 支持 tuple 和嵌套 dataclass(如 LoRAConfig) - 删除 _get_meta/_get_config wrapper,直接使用 load_json
This commit is contained in:
parent
3a28e52e98
commit
c424dfc293
|
|
@ -13,12 +13,21 @@ class BaseConfig:
|
||||||
d[fld.name] = v
|
d[fld.name] = v
|
||||||
elif v is None:
|
elif v is None:
|
||||||
d[fld.name] = None
|
d[fld.name] = None
|
||||||
elif isinstance(v, (dict, list)):
|
elif isinstance(v, (dict, list, tuple)):
|
||||||
try:
|
try:
|
||||||
json.dumps(v)
|
val = list(v) if isinstance(v, tuple) else v
|
||||||
d[fld.name] = v
|
json.dumps(val)
|
||||||
|
d[fld.name] = val
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
elif isinstance(v, BaseConfig):
|
||||||
|
d[fld.name] = v.to_dict()
|
||||||
|
elif hasattr(v, "__dataclass_fields__"):
|
||||||
|
sub = {}
|
||||||
|
for f in fields(v):
|
||||||
|
a = getattr(v, f.name)
|
||||||
|
sub[f.name] = list(a) if isinstance(a, tuple) else a
|
||||||
|
d[fld.name] = sub
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -79,17 +79,6 @@ def load_model_weights(save_directory: str) -> dict:
|
||||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _get_meta(save_path: Path) -> dict:
|
|
||||||
meta = {}
|
|
||||||
if get_rank() == 0:
|
|
||||||
meta = load_json(save_path / _META_FILE)
|
|
||||||
if dist.is_initialized():
|
|
||||||
meta_list = [meta]
|
|
||||||
dist.broadcast_object_list(meta_list, src=0)
|
|
||||||
meta = meta_list[0]
|
|
||||||
return meta
|
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
||||||
if not broadcast or not dist.is_initialized():
|
if not broadcast or not dist.is_initialized():
|
||||||
return load_safetensors(save_path / _WEIGHTS_FILE)
|
return load_safetensors(save_path / _WEIGHTS_FILE)
|
||||||
|
|
@ -128,6 +117,7 @@ class Checkpoint:
|
||||||
iteration: int = 0
|
iteration: int = 0
|
||||||
extra: Dict[str, Any] = field(default_factory=dict)
|
extra: Dict[str, Any] = field(default_factory=dict)
|
||||||
meta: Dict[str, Any] = field(default_factory=dict)
|
meta: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
config: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def save(self, save_dir: str):
|
def save(self, save_dir: str):
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
@ -143,6 +133,7 @@ class Checkpoint:
|
||||||
**self.meta,
|
**self.meta,
|
||||||
}
|
}
|
||||||
save_json(meta, save_path / _META_FILE)
|
save_json(meta, save_path / _META_FILE)
|
||||||
|
save_json(self.config, save_path / _CONFIG_FILE)
|
||||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||||
for key, value in self.extra.items():
|
for key, value in self.extra.items():
|
||||||
save_torch(value, save_path / f"{key}.pt")
|
save_torch(value, save_path / f"{key}.pt")
|
||||||
|
|
@ -151,8 +142,10 @@ class Checkpoint:
|
||||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = _get_meta(save_path)
|
meta = load_json(save_path / _META_FILE)
|
||||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
||||||
|
config_path = save_path / _CONFIG_FILE
|
||||||
|
config = load_json(config_path)
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in sorted(save_path.iterdir()):
|
for f in sorted(save_path.iterdir()):
|
||||||
|
|
@ -164,4 +157,5 @@ class Checkpoint:
|
||||||
epoch=meta.get("epoch", 0),
|
epoch=meta.get("epoch", 0),
|
||||||
iteration=meta.get("iteration", 0),
|
iteration=meta.get("iteration", 0),
|
||||||
extra=extra,
|
extra=extra,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
epoch=context.epoch,
|
epoch=context.epoch,
|
||||||
iteration=context.iteration,
|
iteration=context.iteration,
|
||||||
extra=extra,
|
extra=extra,
|
||||||
meta=context.config.to_dict(),
|
config=context.model_config,
|
||||||
)
|
)
|
||||||
context.checkpoint.save(save_path)
|
context.checkpoint.save(save_path)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
from astrai.serialization import Checkpoint, load_model_weights
|
from astrai.serialization import Checkpoint, load_json, load_model_weights
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -24,6 +24,7 @@ class TrainContext:
|
||||||
scheduler: SchedulerProtocol = field(default=None)
|
scheduler: SchedulerProtocol = field(default=None)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
checkpoint: Checkpoint = field(default=None)
|
||||||
config: TrainConfig = field(default=None)
|
config: TrainConfig = field(default=None)
|
||||||
|
model_config: dict = field(default_factory=dict)
|
||||||
executor: BaseExecutor = field(default=None)
|
executor: BaseExecutor = field(default=None)
|
||||||
|
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
|
|
@ -62,11 +63,21 @@ class TrainContextBuilder:
|
||||||
model = cfg.model_fn()
|
model = cfg.model_fn()
|
||||||
model = model.to(device=device)
|
model = model.to(device=device)
|
||||||
|
|
||||||
|
model_config = {}
|
||||||
|
if self._resume_dir:
|
||||||
|
config_path = Path(self._resume_dir) / "config.json"
|
||||||
|
if config_path.exists():
|
||||||
|
model_config = load_json(config_path)
|
||||||
|
|
||||||
|
if not model_config and hasattr(model, "config"):
|
||||||
|
model_config = model.config.to_dict()
|
||||||
|
|
||||||
context = TrainContext(
|
context = TrainContext(
|
||||||
model=model,
|
model=model,
|
||||||
world_size=get_world_size(),
|
world_size=get_world_size(),
|
||||||
rank=get_rank(),
|
rank=get_rank(),
|
||||||
config=cfg,
|
config=cfg,
|
||||||
|
model_config=model_config,
|
||||||
executor=executor,
|
executor=executor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -75,6 +86,8 @@ class TrainContextBuilder:
|
||||||
if (resume_path / "meta.json").exists():
|
if (resume_path / "meta.json").exists():
|
||||||
checkpoint = Checkpoint.load(self._resume_dir)
|
checkpoint = Checkpoint.load(self._resume_dir)
|
||||||
state_dict = checkpoint.state_dict
|
state_dict = checkpoint.state_dict
|
||||||
|
if checkpoint.config:
|
||||||
|
context.model_config = checkpoint.config
|
||||||
else:
|
else:
|
||||||
checkpoint = None
|
checkpoint = None
|
||||||
state_dict = load_model_weights(self._resume_dir)
|
state_dict = load_model_weights(self._resume_dir)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue