fix : 修复存储层 rglob 死锁、DDP LOCAL_RANK 绑定

This commit is contained in:
yegroup001 2026-06-02 01:01:00 +08:00
parent 01ce1fb9e3
commit 746a1475b2
3 changed files with 33 additions and 8 deletions

View File

@ -18,6 +18,7 @@ Key properties:
""" """
import bisect import bisect
import glob
import json import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -113,13 +114,17 @@ def detect_format(load_path: str) -> str:
return "h5" return "h5"
raise ValueError(f"Unsupported file format: {suffix}") raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = [
Path(p)
for pattern in ("*.h5", "*.hdf5")
for p in glob.glob(str(root / "**" / pattern), recursive=True)
]
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin")) bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
if bin_files: if bin_files:
has_meta = (root / "meta.json").exists() or len( has_meta = (root / "meta.json").exists() or len(
list(root.rglob("meta.json")) [Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
) > 0 ) > 0
if has_meta: if has_meta:
return "bin" return "bin"
@ -250,7 +255,9 @@ class MmapStore(Store):
self._mmap_refs = [] self._mmap_refs = []
root = Path(path) root = Path(path)
all_raw: Dict[str, List[Tensor]] = {} all_raw: Dict[str, List[Tensor]] = {}
meta_paths = list(root.rglob("meta.json")) meta_paths = [
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
]
for meta_path in meta_paths: for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent)) raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items(): for key, tensors in raw.items():

View File

@ -2,6 +2,7 @@
import contextlib import contextlib
import logging import logging
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Tuple from typing import Optional, Tuple
@ -181,7 +182,7 @@ class DDPExecutor(BaseExecutor):
if not self.use_distributed: if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped") logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model return model
local_rank = get_rank() local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
model = DDP( model = DDP(
model, model,
device_ids=[local_rank], device_ids=[local_rank],

View File

@ -44,11 +44,12 @@ def setup_parallel(
yield None yield None
return return
device_id = torch.device(device_type, rank) local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
device_id = torch.device(device_type, local_rank)
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
@ -126,7 +127,23 @@ def spawn_parallel_fn(
start_method: str = "spawn", start_method: str = "spawn",
**kwargs, **kwargs,
): ):
# clear environment variables # Multi-node support: if RANK env var is set, init process group
# and run function directly (no local spawn).
if "RANK" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=os.environ.get("MASTER_ADDR", master_addr),
master_port=os.environ.get("MASTER_PORT", master_port),
device_type=device_type,
):
func(**kwargs)
return
# clear environment variables (single-node path)
for key in [ for key in [
"MASTER_ADDR", "MASTER_ADDR",
"MASTER_PORT", "MASTER_PORT",