feat : load_json/load_safetensors 支持 broadcast,跨节点分布式加载

- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
This commit is contained in:
ViperEkura 2026-05-28 20:44:58 +08:00
parent c424dfc293
commit 6031020e37
1 changed files with 40 additions and 19 deletions

View File

@ -3,7 +3,7 @@ import json
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, Union
import safetensors.torch as st
import torch
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: str | Path):
def save_safetensors(state_dict: dict, path: Union[str, Path]):
st.save_file(state_dict, str(path))
def load_safetensors(path: str | Path) -> dict:
return st.load_file(str(path))
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: str | Path):
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f:
json.dump(data, f, indent=2)
def load_json(path: str | Path) -> dict:
with open(str(path), "r") as f:
return json.load(f)
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f:
return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: str | Path):
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path))
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False)
@ -76,17 +97,18 @@ def load_model_config(save_directory: str) -> dict:
def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE)
return load_safetensors(path)
rank = get_rank()
if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
specs: List[Tuple[str, List[int], str]] = [
state_dict = load_safetensors(path)
specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict)
]
@ -142,10 +164,9 @@ class Checkpoint:
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE)
state_dict = _load_state_dict(save_path, broadcast=broadcast)
config_path = save_path / _CONFIG_FILE
config = load_json(config_path)
meta = load_json(save_path / _META_FILE, broadcast)
config = load_json(save_path / _CONFIG_FILE, broadcast)
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
extra = {}
for f in sorted(save_path.iterdir()):