refactor: 统一序列化层,消除分散的 I/O 路径
- Checkpoint 改为 @dataclass,内聚 save/load 方法
- 提取 save_safetensors/load_safetensors/save_json/load_json 共享工具
- 新增 save_model/load_model_config/load_model_weights 模块函数
- automodel 和 lora 统一委托到 serialization 模块
This commit is contained in:
parent
1d26aa2e93
commit
65ab69543b
|
|
@ -2,16 +2,15 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Self, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config.model_config import BaseModelConfig, ConfigFactory
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.serialization import load_model_config, load_model_weights, save_model
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -60,25 +59,22 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
|
||||
model_path = Path(path)
|
||||
|
||||
# Load config
|
||||
config_path = model_path / "config.json"
|
||||
if config_path.exists():
|
||||
with open(config_path, "r") as f:
|
||||
raw = json.load(f)
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
else:
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {config_path}")
|
||||
|
||||
raw = load_model_config(str(model_path))
|
||||
config = ConfigFactory.load(raw)
|
||||
model_type = config.model_type or "autoregressive_lm"
|
||||
|
||||
actual_cls = AutoModel.get_component_class(model_type)
|
||||
|
||||
with _disable_random_init(enable=disable_random_init):
|
||||
model = actual_cls(config)
|
||||
|
||||
# Load weights
|
||||
weights_path = model_path / "model.safetensors"
|
||||
if weights_path.exists():
|
||||
state_dict = st.load_file(str(weights_path))
|
||||
state_dict = load_model_weights(str(model_path))
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
return model
|
||||
|
|
@ -87,14 +83,11 @@ class AutoModel(BaseFactory["AutoModel"], nn.Module):
|
|||
self,
|
||||
save_directory: Union[str, Path],
|
||||
) -> None:
|
||||
save_path = Path(save_directory)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
self.config.to_file(str(save_path / "config.json"))
|
||||
|
||||
# Save weights
|
||||
st.save_file(self.state_dict(), str(save_path / "model.safetensors"))
|
||||
save_model(
|
||||
config=self.config.to_dict(),
|
||||
state_dict=self.state_dict(),
|
||||
save_directory=str(save_directory),
|
||||
)
|
||||
|
||||
def to(self, *args, **kwargs) -> Self:
|
||||
"""Move model to device/dtype."""
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.serialization import (
|
||||
load_json,
|
||||
load_safetensors,
|
||||
save_json,
|
||||
save_safetensors,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -128,16 +132,14 @@ def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
|||
|
||||
path = Path(save_dir)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
||||
with open(path / "adapter_config.json", "w") as f:
|
||||
json.dump(asdict(config), f, indent=2)
|
||||
save_safetensors(lora_sd, path / "adapter_model.safetensors")
|
||||
save_json(asdict(config), path / "adapter_config.json")
|
||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||
path = Path(load_dir)
|
||||
with open(path / "adapter_config.json") as f:
|
||||
raw = json.load(f)
|
||||
raw = load_json(path / "adapter_config.json")
|
||||
config = LoRAConfig(
|
||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||
)
|
||||
|
|
@ -157,7 +159,7 @@ def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
|||
target_modules=set(config.target_modules),
|
||||
)
|
||||
|
||||
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
||||
weights = load_safetensors(path / "adapter_model.safetensors")
|
||||
try:
|
||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||
except RuntimeError as e:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -9,75 +10,101 @@ import torch.distributed as dist
|
|||
|
||||
from astrai.parallel.setup import get_rank
|
||||
|
||||
_META_FILE = "meta.json"
|
||||
_WEIGHTS_FILE = "model.safetensors"
|
||||
_MODEL_CONFIG_FILE = "config.json"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path) -> None:
    """Write ``state_dict`` to ``path`` using ``safetensors.torch.save_file``.

    Args:
        state_dict: Mapping handed to ``save_file`` unchanged.
        path: Destination file; converted with ``str()`` before the call.
    """
    target = str(path)
    st.save_file(state_dict, target)
|
||||
|
||||
|
||||
def load_safetensors(path: str | Path) -> dict:
    """Return the result of ``safetensors.torch.load_file`` for ``path``.

    Args:
        path: Source file; converted with ``str()`` before the call.
    """
    source = str(path)
    loaded = st.load_file(source)
    return loaded
|
||||
|
||||
|
||||
def save_json(data: dict, path: str | Path) -> None:
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def load_json(path: str | Path) -> dict:
|
||||
with open(str(path), "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: str | Path) -> None:
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: str | Path) -> Any:
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
def __init__(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
epoch: int = 0,
|
||||
iteration: int = 0,
|
||||
extra: Optional[Dict[str, Any]] = None,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
self.state_dict = state_dict
|
||||
self.epoch = epoch
|
||||
self.iteration = iteration
|
||||
self.extra = extra or {}
|
||||
self.meta = meta or {}
|
||||
|
||||
def save(
|
||||
self,
|
||||
save_dir: str,
|
||||
) -> None:
|
||||
state_dict: Dict[str, Any] = field(default_factory=dict)
|
||||
epoch: int = 0
|
||||
iteration: int = 0
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
meta: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def save(self, save_dir: str) -> None:
|
||||
save_path = Path(save_dir)
|
||||
save_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
}
|
||||
meta.update(self.meta)
|
||||
with open(save_path / "meta.json", "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
if get_rank() != 0:
|
||||
return
|
||||
|
||||
st.save_file(self.state_dict, save_path / "state_dict.safetensors")
|
||||
if self.extra:
|
||||
for key, value in self.extra.items():
|
||||
torch.save(value, save_path / f"{key}.pt")
|
||||
meta = {
|
||||
"epoch": self.epoch,
|
||||
"iteration": self.iteration,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||
**self.meta,
|
||||
}
|
||||
save_json(meta, save_path / _META_FILE)
|
||||
save_safetensors(self.state_dict, save_path / _WEIGHTS_FILE)
|
||||
for key, value in self.extra.items():
|
||||
save_torch(value, save_path / f"{key}.pt")
|
||||
|
||||
@classmethod
|
||||
def load(
|
||||
cls,
|
||||
save_dir: str,
|
||||
) -> "Checkpoint":
|
||||
|
||||
rank = get_rank()
|
||||
def load(cls, save_dir: str) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = {}
|
||||
if rank == 0:
|
||||
with open(Path(save_dir) / "meta.json", "r") as f:
|
||||
meta = json.load(f)
|
||||
if get_rank() == 0:
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
|
||||
if dist.is_initialized():
|
||||
meta_list = [meta]
|
||||
dist.broadcast_object_list(meta_list, src=0)
|
||||
meta = meta_list[0]
|
||||
|
||||
state_dict = st.load_file(save_path / "state_dict.safetensors")
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
|
||||
extra = {}
|
||||
for f in save_path.iterdir():
|
||||
if f.suffix == ".pt" and f.stem not in ("meta",):
|
||||
extra[f.stem] = torch.load(f, map_location="cpu", weights_only=False)
|
||||
if f.suffix == ".pt":
|
||||
extra[f.stem] = load_torch(f)
|
||||
|
||||
return cls(
|
||||
state_dict=state_dict,
|
||||
epoch=meta["epoch"],
|
||||
iteration=meta["iteration"],
|
||||
extra=extra or None,
|
||||
epoch=meta.get("epoch", 0),
|
||||
iteration=meta.get("iteration", 0),
|
||||
extra=extra,
|
||||
)
|
||||
|
||||
|
||||
def save_model(config: dict, state_dict: dict, save_directory: str) -> None:
    """Write a model's config and weights into ``save_directory``.

    The directory (including missing parents) is created if needed, then the
    config is stored as JSON and the state dict as a safetensors file, using
    the module-level file-name constants.

    Args:
        config: Model configuration mapping, written via ``save_json``.
        state_dict: Model weights, written via ``save_safetensors``.
        save_directory: Target directory.
    """
    out_dir = Path(save_directory)
    out_dir.mkdir(parents=True, exist_ok=True)

    config_file = out_dir / _MODEL_CONFIG_FILE
    weights_file = out_dir / _WEIGHTS_FILE
    save_json(config, config_file)
    save_safetensors(state_dict, weights_file)
|
||||
|
||||
|
||||
def load_model_config(save_directory: str) -> dict:
    """Load the model config JSON stored inside ``save_directory``.

    Args:
        save_directory: Directory containing the config file named by
            ``_MODEL_CONFIG_FILE``.

    Returns:
        The parsed config as returned by ``load_json``.
    """
    config_file = Path(save_directory) / _MODEL_CONFIG_FILE
    return load_json(config_file)
|
||||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
    """Load the safetensors weights stored inside ``save_directory``.

    Args:
        save_directory: Directory containing the weights file named by
            ``_WEIGHTS_FILE``.

    Returns:
        The state dict as returned by ``load_safetensors``.
    """
    weights_file = Path(save_directory) / _WEIGHTS_FILE
    return load_safetensors(weights_file)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
|
@ -36,7 +37,6 @@ def test_single_process():
|
|||
|
||||
|
||||
def test_checkpoint_with_extra():
|
||||
"""Verify extra keys are saved as individual .pt files and loaded back."""
|
||||
model = torch.nn.Linear(10, 5)
|
||||
optimizer = AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer.step()
|
||||
|
|
@ -52,8 +52,6 @@ def test_checkpoint_with_extra():
|
|||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint.save(tmpdir)
|
||||
|
||||
import os
|
||||
|
||||
assert os.path.exists(os.path.join(tmpdir, "optimizer.pt"))
|
||||
assert os.path.exists(os.path.join(tmpdir, "scheduler.pt"))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue