feat: load_json/load_safetensors 支持 broadcast,跨节点分布式加载
- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
This commit is contained in:
parent
c424dfc293
commit
6031020e37
|
|
@ -3,7 +3,7 @@ import json
|
|||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
|
|
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
|
|||
_WEIGHTS_FILE = "model.safetensors"
|
||||
|
||||
|
||||
def save_safetensors(state_dict: dict, path: str | Path):
|
||||
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||
st.save_file(state_dict, str(path))
|
||||
|
||||
|
||||
def load_safetensors(path: str | Path) -> dict:
|
||||
return st.load_file(str(path))
|
||||
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return st.load_file(str(path))
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
state_dict = st.load_file(str(path))
|
||||
else:
|
||||
state_dict = {}
|
||||
tmp = [state_dict]
|
||||
dist.broadcast_object_list(tmp, src=0)
|
||||
return tmp[0]
|
||||
|
||||
|
||||
def save_json(data: dict, path: str | Path):
|
||||
def save_json(data: dict, path: Union[str, Path]):
|
||||
with open(str(path), "w") as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
|
||||
def load_json(path: str | Path) -> dict:
|
||||
with open(str(path), "r") as f:
|
||||
return json.load(f)
|
||||
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
with open(str(path), "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
with open(str(path), "r") as f:
|
||||
data = json.load(f)
|
||||
else:
|
||||
data = {}
|
||||
tmp = [data]
|
||||
dist.broadcast_object_list(tmp, src=0)
|
||||
return tmp[0]
|
||||
|
||||
|
||||
def save_torch(obj: Any, path: str | Path):
|
||||
def save_torch(obj: Any, path: Union[str, Path]):
|
||||
torch.save(obj, str(path))
|
||||
|
||||
|
||||
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
||||
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||
|
||||
|
|
@ -76,17 +97,18 @@ def load_model_config(save_directory: str) -> dict:
|
|||
|
||||
|
||||
def load_model_weights(save_directory: str) -> dict:
|
||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
||||
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||
|
||||
|
||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
||||
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||
path = Path(path)
|
||||
if not broadcast or not dist.is_initialized():
|
||||
return load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
return load_safetensors(path)
|
||||
|
||||
rank = get_rank()
|
||||
if rank == 0:
|
||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
||||
specs: List[Tuple[str, List[int], str]] = [
|
||||
state_dict = load_safetensors(path)
|
||||
specs = [
|
||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||
for k in sorted(state_dict)
|
||||
]
|
||||
|
|
@ -142,10 +164,9 @@ class Checkpoint:
|
|||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||
save_path = Path(save_dir)
|
||||
|
||||
meta = load_json(save_path / _META_FILE)
|
||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
||||
config_path = save_path / _CONFIG_FILE
|
||||
config = load_json(config_path)
|
||||
meta = load_json(save_path / _META_FILE, broadcast)
|
||||
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||
|
||||
extra = {}
|
||||
for f in sorted(save_path.iterdir()):
|
||||
|
|
|
|||
Loading…
Reference in New Issue