From 746a1475b231f85e11a46678fd09b84c71211cda Mon Sep 17 00:00:00 2001 From: yegroup001 Date: Tue, 2 Jun 2026 01:01:00 +0800 Subject: [PATCH] =?UTF-8?q?fix=20:=20=E4=BF=AE=E5=A4=8D=E5=AD=98=E5=82=A8?= =?UTF-8?q?=E5=B1=82=20rglob=20=E6=AD=BB=E9=94=81=E3=80=81DDP=20LOCAL=5FRA?= =?UTF-8?q?NK=20=E7=BB=91=E5=AE=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/dataset/storage.py | 15 +++++++++++---- astrai/parallel/executor.py | 3 ++- astrai/parallel/setup.py | 23 ++++++++++++++++++++--- 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index cf9a7ed..72c4667 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -18,6 +18,7 @@ Key properties: """ import bisect +import glob import json import os from abc import ABC, abstractmethod @@ -113,13 +114,17 @@ def detect_format(load_path: str) -> str: return "h5" raise ValueError(f"Unsupported file format: {suffix}") - h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) + h5_files = [ + Path(p) + for pattern in ("*.h5", "*.hdf5") + for p in glob.glob(str(root / "**" / pattern), recursive=True) + ] if h5_files: return "h5" - bin_files = list(root.rglob("*.bin")) + bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)] if bin_files: has_meta = (root / "meta.json").exists() or len( - list(root.rglob("meta.json")) + [Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)] ) > 0 if has_meta: return "bin" @@ -250,7 +255,9 @@ class MmapStore(Store): self._mmap_refs = [] root = Path(path) all_raw: Dict[str, List[Tensor]] = {} - meta_paths = list(root.rglob("meta.json")) + meta_paths = [ + Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True) + ] for meta_path in meta_paths: raw = load_bin(str(meta_path.parent)) for key, tensors in raw.items(): diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index ce2d935..4987823 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -2,6 +2,7 @@ import contextlib import logging +import os from contextlib import contextmanager from typing import Optional, Tuple @@ -181,7 +182,7 @@ class DDPExecutor(BaseExecutor): if not self.use_distributed: logger.warning("DDP backend selected but world_size=1, model not wrapped") return model - local_rank = get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", get_rank())) model = DDP( model, device_ids=[local_rank], diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index b879347..3debe97 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -44,11 +44,12 @@ def setup_parallel( yield None return - device_id = torch.device(device_type, rank) + local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank + device_id = torch.device(device_type, local_rank) os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port - os.environ["LOCAL_RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_DEVICE"] = str(device_id) @@ -126,7 +127,23 @@ def spawn_parallel_fn( start_method: str = "spawn", **kwargs, ): - # clear environment variables + # Multi-node support: if RANK env var is set, init process group + # and run function directly (no local spawn). + if "RANK" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + with setup_parallel( + rank=rank, + world_size=world_size, + backend=backend, + master_addr=os.environ.get("MASTER_ADDR", master_addr), + master_port=os.environ.get("MASTER_PORT", master_port), + device_type=device_type, + ): + func(**kwargs) + return + + # clear environment variables (single-node path) for key in [ "MASTER_ADDR", "MASTER_PORT",