diff --git a/astrai/model/automodel.py b/astrai/model/automodel.py index ad5db1a..d48785a 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -2,16 +2,15 @@ AutoModel base class for model loading and saving. """ -import json from contextlib import contextmanager from pathlib import Path from typing import Self, Union -import safetensors.torch as st import torch.nn as nn from astrai.config.model_config import BaseModelConfig, ConfigFactory from astrai.factory import BaseFactory +from astrai.serialization import load_model_config, load_model_weights, save_model @contextmanager @@ -60,25 +59,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): model_path = Path(path) - # Load config config_path = model_path / "config.json" - if config_path.exists(): - with open(config_path, "r") as f: - raw = json.load(f) - config = ConfigFactory.load(raw) - model_type = config.model_type or "autoregressive_lm" - else: + if not config_path.exists(): raise FileNotFoundError(f"Config file not found: {config_path}") + raw = load_model_config(str(model_path)) + config = ConfigFactory.load(raw) + model_type = config.model_type or "autoregressive_lm" + actual_cls = AutoModel.get_component_class(model_type) with _disable_random_init(enable=disable_random_init): model = actual_cls(config) - # Load weights weights_path = model_path / "model.safetensors" if weights_path.exists(): - state_dict = st.load_file(str(weights_path)) + state_dict = load_model_weights(str(model_path)) model.load_state_dict(state_dict, strict=strict) return model @@ -87,14 +83,11 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module): self, save_directory: Union[str, Path], ) -> None: - save_path = Path(save_directory) - save_path.mkdir(parents=True, exist_ok=True) - - # Save config - self.config.to_file(str(save_path / "config.json")) - - # Save weights - st.save_file(self.state_dict(), str(save_path / "model.safetensors")) + save_model( + config=self.config.to_dict(), + state_dict=self.state_dict(), + save_directory=str(save_directory), + ) def to(self, *args, **kwargs) -> Self: """Move model to device/dtype.""" diff --git a/astrai/model/components/lora.py b/astrai/model/components/lora.py index 197bd43..1850663 100644 --- a/astrai/model/components/lora.py +++ b/astrai/model/components/lora.py @@ -1,15 +1,19 @@ -import json import logging from dataclasses import asdict, dataclass from pathlib import Path from typing import Optional, Set -import safetensors.torch as st import torch import torch.nn as nn import torch.nn.functional as F from astrai.model.components.linear import Linear +from astrai.serialization import ( + load_json, + load_safetensors, + save_json, + save_safetensors, +) logger = logging.getLogger(__name__) @@ -128,16 +132,14 @@ def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig): path = Path(save_dir) path.mkdir(parents=True, exist_ok=True) - st.save_file(lora_sd, str(path / "adapter_model.safetensors")) - with open(path / "adapter_config.json", "w") as f: - json.dump(asdict(config), f, indent=2) + save_safetensors(lora_sd, path / "adapter_model.safetensors") + save_json(asdict(config), path / "adapter_config.json") logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd)) def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig: path = Path(load_dir) - with open(path / "adapter_config.json") as f: - raw = json.load(f) + raw = load_json(path / "adapter_config.json") config = LoRAConfig( r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"]) ) @@ -157,7 +159,7 @@ def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig: target_modules=set(config.target_modules), ) - weights = st.load_file(str(path / "adapter_model.safetensors")) + weights = load_safetensors(path / "adapter_model.safetensors") try: missing, unexpected = model.load_state_dict(weights, strict=False) except RuntimeError as e: diff --git a/astrai/serialization.py b/astrai/serialization.py index 857d23c..1990ae5 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -1,7 +1,8 @@ import json import time +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict import safetensors.torch as st import torch @@ -9,75 +10,101 @@ import torch.distributed as dist from astrai.parallel.setup import get_rank +_META_FILE = "meta.json" +_WEIGHTS_FILE = "model.safetensors" +_MODEL_CONFIG_FILE = "config.json" + +def save_safetensors(state_dict: dict, path: str | Path) -> None: + st.save_file(state_dict, str(path)) + + +def load_safetensors(path: str | Path) -> dict: + return st.load_file(str(path)) + + +def save_json(data: dict, path: str | Path) -> None: + with open(str(path), "w") as f: + json.dump(data, f, indent=2) + + +def load_json(path: str | Path) -> dict: + with open(str(path), "r") as f: + return json.load(f) + + +def save_torch(obj: Any, path: str | Path) -> None: + torch.save(obj, str(path)) + + +def load_torch(path: str | Path) -> Any: + return torch.load(str(path), map_location="cpu", weights_only=False) + + +@dataclass class Checkpoint: - def __init__( - self, - state_dict: Dict[str, Any], - epoch: int = 0, - iteration: int = 0, - extra: Optional[Dict[str, Any]] = None, - meta: Optional[Dict[str, Any]] = None, - ): - self.state_dict = state_dict - self.epoch = epoch - self.iteration = iteration - self.extra = extra or {} - self.meta = meta or {} - - def save( - self, - save_dir: str, - ) -> None: + state_dict: Dict[str, Any] = field(default_factory=dict) + epoch: int = 0 + iteration: int = 0 + extra: Dict[str, Any] = field(default_factory=dict) + meta: Dict[str, Any] = field(default_factory=dict) + def save(self, save_dir: str) -> None: save_path = Path(save_dir) save_path.mkdir(parents=True, exist_ok=True) - rank = get_rank() - if rank == 0: - meta = { - "epoch": self.epoch, - "iteration": self.iteration, - "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), - } - meta.update(self.meta) - with open(save_path / "meta.json", "w") as f: - json.dump(meta, f, indent=2) + if get_rank() != 0: + return - st.save_file(self.state_dict, save_path / "state_dict.safetensors") - if self.extra: - for key, value in self.extra.items(): - torch.save(value, save_path / f"{key}.pt") + meta = { + "epoch": self.epoch, + "iteration": self.iteration, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"), + **self.meta, + } + save_json(meta, save_path / _META_FILE) + save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE) + for key, value in self.extra.items(): + save_torch(value, save_path / f"{key}.pt") @classmethod - def load( - cls, - save_dir: str, - ) -> "Checkpoint": - - rank = get_rank() + def load(cls, save_dir: str) -> "Checkpoint": save_path = Path(save_dir) meta = {} - if rank == 0: - with open(Path(save_dir) / "meta.json", "r") as f: - meta = json.load(f) + if get_rank() == 0: + meta = load_json(save_path / _META_FILE) if dist.is_initialized(): meta_list = [meta] dist.broadcast_object_list(meta_list, src=0) meta = meta_list[0] - state_dict = st.load_file(save_path / "state_dict.safetensors") + state_dict = load_safetensors(save_path / _WEIGHTS_FILE) extra = {} for f in save_path.iterdir(): - if f.suffix == ".pt" and f.stem not in ("meta",): - extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False) + if f.suffix == ".pt": + extra[f.stem] = load_torch(f) return cls( state_dict=state_dict, - epoch=meta["epoch"], - iteration=meta["iteration"], - extra=extra or None, + epoch=meta.get("epoch", 0), + iteration=meta.get("iteration", 0), + extra=extra, ) + + +def save_model(config: dict, state_dict: dict, save_directory: str) -> None: + save_path = Path(save_directory) + save_path.mkdir(parents=True, exist_ok=True) + save_json(config, save_path / _MODEL_CONFIG_FILE) + save_safetensors(state_dict, save_path / _WEIGHTS_FILE) + + +def load_model_config(save_directory: str) -> dict: + return load_json(Path(save_directory) / _MODEL_CONFIG_FILE) + + +def load_model_weights(save_directory: str) -> dict: + return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index f3556db..bfac737 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -1,3 +1,4 @@ +import os import tempfile import torch @@ -36,7 +37,6 @@ def test_single_process(): def test_checkpoint_with_extra(): - """Verify extra keys are saved as individual .pt files and loaded back.""" model = torch.nn.Linear(10, 5) optimizer = AdamW(model.parameters(), lr=1e-3) optimizer.step() @@ -52,8 +52,6 @@ def test_checkpoint_with_extra(): with tempfile.TemporaryDirectory() as tmpdir: checkpoint.save(tmpdir) - import os - assert os.path.exists(os.path.join(tmpdir, "optimizer.pt")) assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))