diff --git a/astrai/config/base.py b/astrai/config/base.py index 1d34295..b67507c 100644 --- a/astrai/config/base.py +++ b/astrai/config/base.py @@ -13,12 +13,21 @@ class BaseConfig: d[fld.name] = v elif v is None: d[fld.name] = None - elif isinstance(v, (dict, list)): + elif isinstance(v, (dict, list, tuple)): try: - json.dumps(v) - d[fld.name] = v + val = list(v) if isinstance(v, tuple) else v + json.dumps(val) + d[fld.name] = val except (TypeError, ValueError): pass + elif isinstance(v, BaseConfig): + d[fld.name] = v.to_dict() + elif hasattr(v, "__dataclass_fields__"): + sub = {} + for f in fields(v): + a = getattr(v, f.name) + sub[f.name] = list(a) if isinstance(a, tuple) else a + d[fld.name] = sub return d @classmethod diff --git a/astrai/serialization.py b/astrai/serialization.py index 2243d28..721780d 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -79,17 +79,6 @@ def load_model_weights(save_directory: str) -> dict: return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) -def _get_meta(save_path: Path) -> dict: - meta = {} - if get_rank() == 0: - meta = load_json(save_path / _META_FILE) - if dist.is_initialized(): - meta_list = [meta] - dist.broadcast_object_list(meta_list, src=0) - meta = meta_list[0] - return meta - - def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: if not broadcast or not dist.is_initialized(): return load_safetensors(save_path / _WEIGHTS_FILE) @@ -128,6 +117,7 @@ class Checkpoint: iteration: int = 0 extra: Dict[str, Any] = field(default_factory=dict) meta: Dict[str, Any] = field(default_factory=dict) + config: Dict[str, Any] = field(default_factory=dict) def save(self, save_dir: str): save_path = Path(save_dir) @@ -143,6 +133,7 @@ class Checkpoint: **self.meta, } save_json(meta, save_path / _META_FILE) + save_json(self.config, save_path / _CONFIG_FILE) save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) for key, value in self.extra.items(): save_torch(value, save_path / f"{key}.pt") @@ -151,8 +142,10 @@ class Checkpoint: def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": save_path = Path(save_dir) - meta = _get_meta(save_path) + meta = load_json(save_path / _META_FILE) state_dict = _load_state_dict(save_path, broadcast=broadcast) + config_path = save_path / _CONFIG_FILE + config = load_json(config_path) extra = {} for f in sorted(save_path.iterdir()): @@ -164,4 +157,5 @@ class Checkpoint: epoch=meta.get("epoch", 0), iteration=meta.get("iteration", 0), extra=extra, + config=config, ) diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index 6ab65b5..31f2260 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -160,7 +160,7 @@ class CheckpointCallback(TrainCallback): epoch=context.epoch, iteration=context.iteration, extra=extra, - meta=context.config.to_dict(), + config=context.model_config, ) context.checkpoint.save(save_path) diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 9fe0ae4..879830b 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -11,7 +11,7 @@ from astrai.model.components.lora import inject_lora from astrai.parallel.executor import BaseExecutor, ExecutorFactory from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.protocols import OptimizerProtocol, SchedulerProtocol -from astrai.serialization import Checkpoint, load_model_weights +from astrai.serialization import Checkpoint, load_json, load_model_weights from astrai.trainer.strategy import BaseStrategy, StrategyFactory @@ -24,6 +24,7 @@ class TrainContext: scheduler: SchedulerProtocol = field(default=None) checkpoint: Checkpoint = field(default=None) config: TrainConfig = field(default=None) + model_config: dict = field(default_factory=dict) executor: BaseExecutor = field(default=None) epoch: int = field(default=0) @@ -62,11 +63,21 @@ class TrainContextBuilder: model = cfg.model_fn() model = model.to(device=device) + model_config = {} + if self._resume_dir: + config_path = Path(self._resume_dir) / "config.json" + if config_path.exists(): + model_config = load_json(config_path) + + if not model_config and hasattr(model, "config"): + model_config = model.config.to_dict() + context = TrainContext( model=model, world_size=get_world_size(), rank=get_rank(), config=cfg, + model_config=model_config, executor=executor, ) @@ -75,6 +86,8 @@ class TrainContextBuilder: if (resume_path / "meta.json").exists(): checkpoint = Checkpoint.load(self._resume_dir) state_dict = checkpoint.state_dict + if checkpoint.config: + context.model_config = checkpoint.config else: checkpoint = None state_dict = load_model_weights(self._resume_dir)