feat: load_json/load_safetensors 支持 broadcast,跨节点分布式加载
- load_json/load_safetensors/load_state_dict 新增 broadcast 参数
- broadcast=True 时 rank-0 读取后 broadcast_object_list 分发到所有 rank
- load_state_dict 改为逐张量 broadcast,避免大模型 pickle 内存瓶颈
- 删除 _get_meta/_get_config wrapper,Checkpoint.load 直接调用 load_json
- 参数注解 str | Path 统一为 Union[str, Path]
This commit is contained in:
parent
c424dfc293
commit
6031020e37
|
|
@ -3,7 +3,7 @@ import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Tuple
|
from typing import Any, Dict, Union
|
||||||
|
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -16,29 +16,50 @@ _CONFIG_FILE = "config.json"
|
||||||
_WEIGHTS_FILE = "model.safetensors"
|
_WEIGHTS_FILE = "model.safetensors"
|
||||||
|
|
||||||
|
|
||||||
def save_safetensors(state_dict: dict, path: str | Path):
|
def save_safetensors(state_dict: dict, path: Union[str, Path]):
|
||||||
st.save_file(state_dict, str(path))
|
st.save_file(state_dict, str(path))
|
||||||
|
|
||||||
|
|
||||||
def load_safetensors(path: str | Path) -> dict:
|
def load_safetensors(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
return st.load_file(str(path))
|
return st.load_file(str(path))
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
state_dict = st.load_file(str(path))
|
||||||
|
else:
|
||||||
|
state_dict = {}
|
||||||
|
tmp = [state_dict]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
def save_json(data: dict, path: str | Path):
|
|
||||||
|
def save_json(data: dict, path: Union[str, Path]):
|
||||||
with open(str(path), "w") as f:
|
with open(str(path), "w") as f:
|
||||||
json.dump(data, f, indent=2)
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def load_json(path: str | Path) -> dict:
|
def load_json(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
if not broadcast or not dist.is_initialized():
|
||||||
with open(str(path), "r") as f:
|
with open(str(path), "r") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
|
rank = get_rank()
|
||||||
|
if rank == 0:
|
||||||
|
with open(str(path), "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
else:
|
||||||
|
data = {}
|
||||||
|
tmp = [data]
|
||||||
|
dist.broadcast_object_list(tmp, src=0)
|
||||||
|
return tmp[0]
|
||||||
|
|
||||||
def save_torch(obj: Any, path: str | Path):
|
|
||||||
|
def save_torch(obj: Any, path: Union[str, Path]):
|
||||||
torch.save(obj, str(path))
|
torch.save(obj, str(path))
|
||||||
|
|
||||||
|
|
||||||
def load_torch(path: str | Path, broadcast: bool = False) -> Any:
|
def load_torch(path: Union[str, Path], broadcast: bool = False) -> Any:
|
||||||
if not broadcast or not dist.is_initialized():
|
if not broadcast or not dist.is_initialized():
|
||||||
return torch.load(str(path), map_location="cpu", weights_only=False)
|
return torch.load(str(path), map_location="cpu", weights_only=False)
|
||||||
|
|
||||||
|
|
@ -76,17 +97,18 @@ def load_model_config(save_directory: str) -> dict:
|
||||||
|
|
||||||
|
|
||||||
def load_model_weights(save_directory: str) -> dict:
|
def load_model_weights(save_directory: str) -> dict:
|
||||||
return load_safetensors(Path(save_directory) / _WEIGHTS_FILE)
|
return load_state_dict(Path(save_directory) / _WEIGHTS_FILE)
|
||||||
|
|
||||||
|
|
||||||
def _load_state_dict(save_path: Path, broadcast: bool = False) -> dict:
|
def load_state_dict(path: Union[str, Path], broadcast: bool = False) -> dict:
|
||||||
|
path = Path(path)
|
||||||
if not broadcast or not dist.is_initialized():
|
if not broadcast or not dist.is_initialized():
|
||||||
return load_safetensors(save_path / _WEIGHTS_FILE)
|
return load_safetensors(path)
|
||||||
|
|
||||||
rank = get_rank()
|
rank = get_rank()
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
state_dict = load_safetensors(save_path / _WEIGHTS_FILE)
|
state_dict = load_safetensors(path)
|
||||||
specs: List[Tuple[str, List[int], str]] = [
|
specs = [
|
||||||
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
(k, list(state_dict[k].shape), str(state_dict[k].dtype).split(".")[-1])
|
||||||
for k in sorted(state_dict)
|
for k in sorted(state_dict)
|
||||||
]
|
]
|
||||||
|
|
@ -142,10 +164,9 @@ class Checkpoint:
|
||||||
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
def load(cls, save_dir: str, broadcast: bool = False) -> "Checkpoint":
|
||||||
save_path = Path(save_dir)
|
save_path = Path(save_dir)
|
||||||
|
|
||||||
meta = load_json(save_path / _META_FILE)
|
meta = load_json(save_path / _META_FILE, broadcast)
|
||||||
state_dict = _load_state_dict(save_path, broadcast=broadcast)
|
config = load_json(save_path / _CONFIG_FILE, broadcast)
|
||||||
config_path = save_path / _CONFIG_FILE
|
state_dict = load_state_dict(save_path / _WEIGHTS_FILE, broadcast=broadcast)
|
||||||
config = load_json(config_path)
|
|
||||||
|
|
||||||
extra = {}
|
extra = {}
|
||||||
for f in sorted(save_path.iterdir()):
|
for f in sorted(save_path.iterdir()):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue