feat: load_json/load_safetensors 支持 broadcast,跨节点分布式加载

- load_json/load_safetensors 新增 broadcast 参数;_load_state_dict 重命名为公开的 load_state_dict(保留 broadcast 参数),并改为直接接收权重文件路径
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
This commit is contained in:
ViperEkura 2026-05-28 20:44:58 +08:00
parent c424dfc293
commit 6031020e37
1 changed file with 40 additions and 19 deletions

View File

@ -3,7 +3,7 @@ import json
import time import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple from typing import Any, Dict, Union
import safetensors.torch as st import safetensors.torch as st
import torch import torch
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
_WEIGHTS_FILE = "model.safetensors" _WEIGHTS_FILE = "model.safetensors"
def save_safetensors(state_dict: dict, path: str | Path): def save_safetensors(state_dict: dict, path: Union[str, Path]):
st.save_file(state_dict, str(path)) st.save_file(state_dict, str(path))
def load_safetensors(path: str | Path) -> dict: def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
return st.load_file(str(path)) return st.load_file(str(path))
rank = get_rank()
if rank == 0:
state_dict = st.load_file(str(path))
else:
state_dict = {}
tmp = [state_dict]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_json(data: dict, path: str | Path):
def save_json(data: dict, path: Union[str, Path]):
with open(str(path), "w") as f: with open(str(path), "w") as f:
json.dump(data, f, indent=2) json.dump(data, f, indent=2)
def load_json(path: str | Path) -> dict: def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
if not broadcast or not dist.is_initialized():
with open(str(path), "r") as f: with open(str(path), "r") as f:
return json.load(f) return json.load(f)
rank = get_rank()
if rank == 0:
with open(str(path), "r") as f:
data = json.load(f)
else:
data = {}
tmp = [data]
dist.broadcast_object_list(tmp, src=0)
return tmp[0]
def save_torch(obj: Any, path: str | Path):
def save_torch(obj: Any, path: Union[str, Path]):
torch.save(obj, str(path)) torch.save(obj, str(path))
def load_torch(path: str | Path, broadcast: bool = False) -> Any: def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return torch.load(str(path), map_location="cpu", weights_only=False) return torch.load(str(path), map_location="cpu", weights_only=False)
@ -76,17 +97,18 @@ def load_model_config(save_directory: str) -> dict:
def load_model_weights(save_directory: str) -> dict: def load_model_weights(save_directory: str) -> dict:
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE) return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict: def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
path = Path(path)
if not broadcast or not dist.is_initialized(): if not broadcast or not dist.is_initialized():
return load_safetensors(save_path / _WEIGHTS_FILE) return load_safetensors(path)
rank = get_rank() rank = get_rank()
if rank == 0: if rank == 0:
state_dict = load_safetensors(save_path / _WEIGHTS_FILE) state_dict = load_safetensors(path)
specs: List[Tuple[str, List[int], str]] = [ specs = [
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1]) (k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
for k in sorted(state_dict) for k in sorted(state_dict)
] ]
@ -142,10 +164,9 @@ class Checkpoint:
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint": def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
save_path = Path(save_dir) save_path = Path(save_dir)
meta = load_json(save_path / _META_FILE) meta = load_json(save_path / _META_FILE, broadcast)
state_dict = _load_state_dict(save_path, broadcast=broadcast) config = load_json(save_path / _CONFIG_FILE, broadcast)
config_path = save_path / _CONFIG_FILE state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
config = load_json(config_path)
extra = {} extra = {}
for f in sorted(save_path.iterdir()): for f in sorted(save_path.iterdir()):