refactor: 统一序列化层,消除分散的 I/O 路径

- Checkpoint 改为 @dataclass,内聚 save/load 方法
- 提取 save_safetensors/load_safetensors/save_json/load_json/save_torch/load_torch 共享工具
- 新增 save_model/load_model_config/load_model_weights 模块函数
- automodel 和 lora 统一委托到 serialization 模块
This commit is contained in:
ViperEkura 2026-05-26 16:29:22 +08:00
parent 1d26aa2e93
commit 65ab69543b
4 changed files with 98 additions and 78 deletions

View File

@ -2,16 +2,15 @@
AutoModel base class for model loading and saving. AutoModel base class for model loading and saving.
""" """
import json
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Self, Union from typing import Self, Union
import safetensors.torch as st
import torch.nn as nn import torch.nn as nn
from astrai.config.model_config import BaseModelConfig, ConfigFactory from astrai.config.model_config import BaseModelConfig, ConfigFactory
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.serialization import load_model_config, load_model_weights, save_model
@contextmanager @contextmanager
@ -60,25 +59,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
model_path = Path(path) model_path = Path(path)
# Load config
config_path = model_path / "config.json" config_path = model_path / "config.json"
if config_path.exists(): if not config_path.exists():
with open(config_path, "r") as f:
raw = json.load(f)
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
else:
raise FileNotFoundError(f"Config file not found: {config_path}") raise FileNotFoundError(f"Config file not found: {config_path}")
raw = load_model_config(str(model_path))
config = ConfigFactory.load(raw)
model_type = config.model_type or "autoregressive_lm"
actual_cls = AutoModel.get_component_class(model_type) actual_cls = AutoModel.get_component_class(model_type)
with _disable_random_init(enable=disable_random_init): with _disable_random_init(enable=disable_random_init):
model = actual_cls(config) model = actual_cls(config)
# Load weights
weights_path = model_path / "model.safetensors" weights_path = model_path / "model.safetensors"
if weights_path.exists(): if weights_path.exists():
state_dict = st.load_file(str(weights_path)) state_dict = load_model_weights(str(model_path))
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=strict)
return model return model
@ -87,14 +83,11 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
self, self,
save_directory: Union[str, Path], save_directory: Union[str, Path],
) -> None: ) -> None:
save_path = Path(save_directory) save_model(
save_path.mkdir(parents=True, exist_ok=True) config=self.config.to_dict(),
state_dict=self.state_dict(),
# Save config save_directory=str(save_directory),
self.config.to_file(str(save_path / "config.json")) )
# Save weights
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
def to(self, *args, **kwargs) -> Self: def to(self, *args, **kwargs) -> Self:
"""Move model to device/dtype.""" """Move model to device/dtype."""

View File

@ -1,15 +1,19 @@
import json
import logging import logging
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from pathlib import Path from pathlib import Path
from typing import Optional, Set from typing import Optional, Set
import safetensors.torch as st
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from astrai.model.components.linear import Linear from astrai.model.components.linear import Linear
from astrai.serialization import (
load_json,
load_safetensors,
save_json,
save_safetensors,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -128,16 +132,14 @@ def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
path = Path(save_dir) path = Path(save_dir)
path.mkdir(parents=True, exist_ok=True) path.mkdir(parents=True, exist_ok=True)
st.save_file(lora_sd, str(path / "adapter_model.safetensors")) save_safetensors(lora_sd, path / "adapter_model.safetensors")
with open(path / "adapter_config.json", "w") as f: save_json(asdict(config), path / "adapter_config.json")
json.dump(asdict(config), f, indent=2)
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd)) logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig: def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
path = Path(load_dir) path = Path(load_dir)
with open(path / "adapter_config.json") as f: raw = load_json(path / "adapter_config.json")
raw = json.load(f)
config = LoRAConfig( config = LoRAConfig(
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"]) r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
) )
@ -157,7 +159,7 @@ def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
target_modules=set(config.target_modules), target_modules=set(config.target_modules),
) )
weights = st.load_file(str(path / "adapter_model.safetensors")) weights = load_safetensors(path / "adapter_model.safetensors")
try: try:
missing, unexpected = model.load_state_dict(weights, strict=False) missing, unexpected = model.load_state_dict(weights, strict=False)
except RuntimeError as e: except RuntimeError as e:

View File

@ -1,7 +1,8 @@
import json import json
import time import time
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -9,75 +10,101 @@ import torch.distributed as dist
from astrai.parallel.setup import get_rank from astrai.parallel.setup import get_rank
# Fixed file names used inside a saved directory.
# JSON metadata written by Checkpoint.save and read back by Checkpoint.load.
_META_FILE = "meta.json"
# Safetensors weights file, shared by Checkpoint and save_model/load_model_weights.
_WEIGHTS_FILE = "model.safetensors"
# Model config JSON written by save_model and read by load_model_config.
_MODEL_CONFIG_FILE = "config.json"
def save_safetensors(state_dict: dict, path: str | Path) -> None:
    """Write ``state_dict`` to ``path`` using ``safetensors.torch.save_file``.

    Args:
        state_dict: Mapping passed unchanged to ``save_file``.
        path: Destination file; converted to ``str`` before the call.
    """
    destination = str(path)
    st.save_file(state_dict, destination)
def load_safetensors(path: str | Path) -> dict:
    """Read the safetensors file at ``path`` via ``safetensors.torch.load_file``.

    Args:
        path: Source file; converted to ``str`` before the call.

    Returns:
        Whatever ``load_file`` returns for that file.
    """
    source = str(path)
    loaded = st.load_file(source)
    return loaded
def save_json(data: dict, path: str | Path) -> None:
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: str | Path) -> dict:
with open(str(path), "r") as f:
return json.load(f)
def save_torch(obj: Any, path: str | Path) -> None:
torch.save(obj, str(path))
def load_torch(path: str | Path) -> Any:
return torch.load(str(path), map_location="cpu", weights_only=False)
@dataclass
class Checkpoint: class Checkpoint:
def __init__( state_dict: Dict[str, Any] = field(default_factory=dict)
self, epoch: int = 0
state_dict: Dict[str, Any], iteration: int = 0
epoch: int = 0, extra: Dict[str, Any] = field(default_factory=dict)
iteration: int = 0, meta: Dict[str, Any] = field(default_factory=dict)
extra: Optional[Dict[str, Any]] = None,
meta: Optional[Dict[str, Any]] = None,
):
self.state_dict = state_dict
self.epoch = epoch
self.iteration = iteration
self.extra = extra or {}
self.meta = meta or {}
def save(
self,
save_dir: str,
) -> None:
def save(self, save_dir: str) -> None:
save_path = Path(save_dir) save_path = Path(save_dir)
save_path.mkdir(parents=True, exist_ok=True) save_path.mkdir(parents=True, exist_ok=True)
rank = get_rank() if get_rank() != 0:
if rank == 0: return
meta = {
"epoch": self.epoch,
"iteration": self.iteration,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
}
meta.update(self.meta)
with open(save_path / "meta.json", "w") as f:
json.dump(meta, f, indent=2)
st.save_file(self.state_dict, save_path / "state_dict.safetensors") meta = {
if self.extra: "epoch": self.epoch,
for key, value in self.extra.items(): "iteration": self.iteration,
torch.save(value, save_path / f"{key}.pt") "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
**self.meta,
}
save_json(meta, save_path / _META_FILE)
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
for key, value in self.extra.items():
save_torch(value, save_path / f"{key}.pt")
@classmethod @classmethod
def load( def load(cls, save_dir: str) -> "Checkpoint":
cls,
save_dir: str,
) -> "Checkpoint":
rank = get_rank()
save_path = Path(save_dir) save_path = Path(save_dir)
meta = {} meta = {}
if rank == 0: if get_rank() == 0:
with open(Path(save_dir) / "meta.json", "r") as f: meta = load_json(save_path / _META_FILE)
meta = json.load(f)
if dist.is_initialized(): if dist.is_initialized():
meta_list = [meta] meta_list = [meta]
dist.broadcast_object_list(meta_list, src=0) dist.broadcast_object_list(meta_list, src=0)
meta = meta_list[0] meta = meta_list[0]
state_dict = st.load_file(save_path / "state_dict.safetensors") state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
extra = {} extra = {}
for f in save_path.iterdir(): for f in save_path.iterdir():
if f.suffix == ".pt" and f.stem not in ("meta",): if f.suffix == ".pt":
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False) extra[f.stem] = load_torch(f)
return cls( return cls(
state_dict=state_dict, state_dict=state_dict,
epoch=meta["epoch"], epoch=meta.get("epoch", 0),
iteration=meta["iteration"], iteration=meta.get("iteration", 0),
extra=extra or None, extra=extra,
) )
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
    """Write a model's config and weights into ``save_directory``.

    The directory (including missing parents) is created first. The config is
    then written as JSON, followed by the weights as a safetensors file.

    Args:
        config: Mapping written to the config JSON file.
        state_dict: Mapping written to the safetensors weights file.
        save_directory: Target directory for both files.
    """
    target_dir = Path(save_directory)
    target_dir.mkdir(parents=True, exist_ok=True)

    config_file = target_dir / _MODEL_CONFIG_FILE
    weights_file = target_dir / _WEIGHTS_FILE
    save_json(config, config_file)
    save_safetensors(state_dict, weights_file)
def load_model_config(save_directory: str) -> dict:
    """Read the model config JSON stored in ``save_directory``.

    Args:
        save_directory: Directory containing the config file.

    Returns:
        The decoded contents of the config file, as returned by ``load_json``.
    """
    config_file = Path(save_directory) / _MODEL_CONFIG_FILE
    return load_json(config_file)
def load_model_weights(save_directory: str) -> dict:
    """Read the safetensors weights file stored in ``save_directory``.

    Args:
        save_directory: Directory containing the weights file.

    Returns:
        The contents of the weights file, as returned by ``load_safetensors``.
    """
    weights_file = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_file)

View File

@ -1,3 +1,4 @@
import os
import tempfile import tempfile
import torch import torch
@ -36,7 +37,6 @@ def test_single_process():
def test_checkpoint_with_extra(): def test_checkpoint_with_extra():
"""Verify extra keys are saved as individual .pt files and loaded back."""
model = torch.nn.Linear(10, 5) model = torch.nn.Linear(10, 5)
optimizer = AdamW(model.parameters(), lr=1e-3) optimizer = AdamW(model.parameters(), lr=1e-3)
optimizer.step() optimizer.step()
@ -52,8 +52,6 @@ def test_checkpoint_with_extra():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
checkpoint.save(tmpdir) checkpoint.save(tmpdir)
import os
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt")) assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt")) assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))