From 6031020e37d76b75c4379e3fc00d0816e45c8259 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 28 May 2026 20:44:58 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20load=5Fjson/load=5Fsafetensors=20?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20broadcast=EF=BC=8C=E8=B7=A8=E8=8A=82?= =?UTF-8?q?=E7=82=B9=E5=88=86=E5=B8=83=E5=BC=8F=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - load_json/load_safetensors/load_state_dict 新增 broadcast 参数 - broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank - load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈 - 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json - 参数注解 str | Path 统一为 Union[str, Path] --- astrai/serialization.py | 59 ++++++++++++++++++++++++++++------------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/astrai/serialization.py b/astrai/serialization.py index 721780d..e73b51b 100644 --- a/astrai/serialization.py +++ b/astrai/serialization.py @@ -3,7 +3,7 @@ import json import time from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, Union import safetensors.torch as st import torch @@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json" _WEIGHTS_FILE = "model.safetensors" -def save_safetensors(state_dict: dict, path: str | Path): +def save_safetensors(state_dict: dict, path: Union[str, Path]): st.save_file(state_dict, str(path)) -def load_safetensors(path: str | Path) -> dict: - return st.load_file(str(path)) +def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict: + if not broadcast or not dist.is_initialized(): + return st.load_file(str(path)) + + rank = get_rank() + if rank == 0: + state_dict = st.load_file(str(path)) + else: + state_dict = {} + tmp = [state_dict] + dist.broadcast_object_list(tmp, src=0) + return tmp[0] -def save_json(data: dict, path: str | Path): +def save_json(data: dict, path: Union[str, Path]): with open(str(path), "w") as f: json.dump(data, f, indent=2) -def load_json(path: str | Path) -> dict: - with open(str(path), "r") as f: - return json.load(f) +def load_json(path: Union[str, Path], broadcast: bool = False) -> dict: + if not broadcast or not dist.is_initialized(): + with open(str(path), "r") as f: + return json.load(f) + + rank = get_rank() + if rank == 0: + with open(str(path), "r") as f: + data = json.load(f) + else: + data = {} + tmp = [data] + dist.broadcast_object_list(tmp, src=0) + return tmp[0] -def save_torch(obj: Any, path: str | Path): +def save_torch(obj: Any, path: Union[str, Path]): torch.save(obj, str(path)) -def load_torch(path: str | Path, broadcast: bool = False) -> Any: +def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any: if not broadcast or not dist.is_initialized(): return torch.load(str(path), map_location="cpu", weights_only=False) @@ -76,17 +97,18 @@ def load_model_config(save_directory: str) -> dict: def load_model_weights(save_directory: str) -> dict: - return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) + return load_state_dict(Path(save_directory) / _WEIGHTS_FILE) -def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: +def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict: + path = Path(path) if not broadcast or not dist.is_initialized(): - return load_safetensors(save_path / _WEIGHTS_FILE) + return load_safetensors(path) rank = get_rank() if rank == 0: - state_dict = load_safetensors(save_path / _WEIGHTS_FILE) - specs: List[Tuple[str, List[int], str]] = [ + state_dict = load_safetensors(path) + specs = [ (k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1]) for k in sorted(state_dict) ] @@ -142,10 +164,9 @@ class Checkpoint: def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": save_path = Path(save_dir) - meta = load_json(save_path / _META_FILE) - state_dict = _load_state_dict(save_path, broadcast=broadcast) - config_path = save_path / _CONFIG_FILE - config = load_json(config_path) + meta = load_json(save_path / _META_FILE, broadcast) + config = load_json(save_path / _CONFIG_FILE, broadcast) + state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast) extra = {} for f in sorted(save_path.iterdir()):