diff --git a/astrai/dataset/storage.py b/astrai/dataset/storage.py index cf9a7ed..72c4667 100644 --- a/astrai/dataset/storage.py +++ b/astrai/dataset/storage.py @@ -18,6 +18,7 @@ Key properties: """ import bisect +import glob import json import os from abc import ABC, abstractmethod @@ -113,13 +114,17 @@ def detect_format(load_path: str) -> str: return "h5" raise ValueError(f"Unsupported file format: {suffix}") - h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) + h5_files = [ + Path(p) + for pattern in ("*.h5", "*.hdf5") + for p in glob.glob(str(root / "**" / pattern), recursive=True) + ] if h5_files: return "h5" - bin_files = list(root.rglob("*.bin")) + bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)] if bin_files: has_meta = (root / "meta.json").exists() or len( - list(root.rglob("meta.json")) + [Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)] ) > 0 if has_meta: return "bin" @@ -250,7 +255,9 @@ class MmapStore(Store): self._mmap_refs = [] root = Path(path) all_raw: Dict[str, List[Tensor]] = {} - meta_paths = list(root.rglob("meta.json")) + meta_paths = [ + Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True) + ] for meta_path in meta_paths: raw = load_bin(str(meta_path.parent)) for key, tensors in raw.items(): diff --git a/astrai/parallel/executor.py b/astrai/parallel/executor.py index ce2d935..4987823 100644 --- a/astrai/parallel/executor.py +++ b/astrai/parallel/executor.py @@ -2,6 +2,7 @@ import contextlib import logging +import os from contextlib import contextmanager from typing import Optional, Tuple @@ -181,7 +182,7 @@ class DDPExecutor(BaseExecutor): if not self.use_distributed: logger.warning("DDP backend selected but world_size=1, model not wrapped") return model - local_rank = get_rank() + local_rank = int(os.environ.get("LOCAL_RANK", get_rank())) model = DDP( model, device_ids=[local_rank], diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index b879347..3debe97 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -44,11 +44,12 @@ def setup_parallel( yield None return - device_id = torch.device(device_type, rank) + local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank + device_id = torch.device(device_type, local_rank) os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port - os.environ["LOCAL_RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_DEVICE"] = str(device_id) @@ -126,7 +127,23 @@ def spawn_parallel_fn( start_method: str = "spawn", **kwargs, ): - # clear environment variables + # Multi-node support: if RANK env var is set, init process group + # and run function directly (no local spawn). + if "RANK" in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + with setup_parallel( + rank=rank, + world_size=world_size, + backend=backend, + master_addr=os.environ.get("MASTER_ADDR", master_addr), + master_port=os.environ.get("MASTER_PORT", master_port), + device_type=device_type, + ): + func(**kwargs) + return + + # clear environment variables (single-node path) for key in [ "MASTER_ADDR", "MASTER_PORT", diff --git a/scripts/tools/train.py b/scripts/tools/train.py index e5be30e..376a709 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,6 +8,7 @@ import torch.optim as optim from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.dataset import DatasetFactory from astrai.model import AutoRegressiveLM +from astrai.model.components.decoder_block import DecoderBlock from astrai.trainer import SchedulerFactory, Trainer @@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace: default=0.05, help="cross_entropy function label smoothing parameter", ) + parser.add_argument( + "--gradient_checkpointing", + action=argparse.BooleanOptionalAction, + default=False, + help="Enable activation checkpointing for DecoderBlock modules.", + ) parser.add_argument( "--ckpt_interval", @@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace: "--start_batch", type=int, default=0, help="Start batch for training." ) + parser.add_argument( + "--master_addr", + type=str, + default="localhost", + help="Master node address for distributed training.", + ) + parser.add_argument( + "--master_port", + type=str, + default="29500", + help="Master node port for distributed training.", + ) + parser.add_argument( + "--backend", + type=str, + default="nccl", + help="Distributed training backend.", + ) parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument( "--parallel_mode", @@ -222,11 +247,15 @@ def train( random_seed: int, num_workers: int, pin_memory: bool, + gradient_checkpointing: bool, window_size: int, stride: int, nprocs: int, parallel_mode: str, device_type: str, + backend: str, + master_addr: str, + master_port: str, start_method: str, ): assert train_type in ["seq", "sft", "dpo", "grpo"] @@ -303,7 +332,13 @@ def train( random_seed=random_seed, num_workers=num_workers, pin_memory=pin_memory, + gradient_checkpointing_modules=[DecoderBlock] + if gradient_checkpointing + else [], nprocs=nprocs, + backend=backend, + master_addr=master_addr, + master_port=master_port, parallel_mode=parallel_mode, device_type=device_type, start_method=start_method,