diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index 3debe97..df792f2 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -1,4 +1,5 @@ import os +from abc import ABC, abstractmethod from contextlib import contextmanager from functools import wraps from typing import Callable @@ -30,6 +31,7 @@ def get_rank() -> int: def setup_parallel( rank: int, world_size: int, + local_rank: int, backend: str = "nccl", master_addr: str = "localhost", master_port: str = "29500", @@ -41,10 +43,13 @@ def setup_parallel( return if world_size <= 1: + device_id = torch.device(device_type, local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = "1" + os.environ["LOCAL_DEVICE"] = str(device_id) yield None return - local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank device_id = torch.device(device_type, local_rank) os.environ["MASTER_ADDR"] = master_addr @@ -91,7 +96,7 @@ def only_on_rank(rank, sync=False): return decorator -def wrapper_spawn_func( +def _run_single_rank( rank: int, world_size: int, backend: str, @@ -101,20 +106,121 @@ def wrapper_spawn_func( func: Callable, kwargs: dict, ): - try: + with setup_parallel( + rank=rank, + world_size=world_size, + local_rank=rank, + backend=backend, + master_addr=master_addr, + master_port=master_port, + device_type=device_type, + ): + func(**kwargs) + + +class LaunchStrategy(ABC): + """Strategy for launching a function in a distributed context.""" + + def __init__( + self, + world_size: int, + backend: str, + master_addr: str, + master_port: str, + device_type: str, + start_method: str, + ): + self.world_size = world_size + self.backend = backend + self.master_addr = master_addr + self.master_port = master_port + self.device_type = device_type + self.start_method = start_method + + @abstractmethod + def launch(self, func: Callable, **kwargs): + raise NotImplementedError + + +class TorchrunStrategy(LaunchStrategy): + """External orchestrator (torchrun, SLURM, K8s) — env vars pre-set.""" + + def launch(self, func: Callable, **kwargs): + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", rank)) with setup_parallel( rank=rank, world_size=world_size, - backend=backend, - master_addr=master_addr, - master_port=master_port, - device_type=device_type, + local_rank=local_rank, + backend=self.backend, + master_addr=os.environ.get("MASTER_ADDR", self.master_addr), + master_port=os.environ.get("MASTER_PORT", self.master_port), + device_type=self.device_type, ): func(**kwargs) - except Exception as e: - print(f"Error in rank {rank}: {e}") - raise + +class LocalStrategy(LaunchStrategy): + """Local launcher — single-process or mp.start_processes.""" + + def _clear_env(self): + for key in ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_DEVICE", + ): + os.environ.pop(key, None) + + def launch(self, func: Callable, **kwargs): + self._clear_env() + + args = ( + self.world_size, + self.backend, + self.master_addr, + self.master_port, + self.device_type, + func, + kwargs, + ) + + if self.world_size == 1: + _run_single_rank(0, *args) + return + + ctx = mp.start_processes( + _run_single_rank, + args=args, + nprocs=self.world_size, + start_method=self.start_method, + join=False, + ) + try: + while not ctx.join(): + pass + except BaseException: + for p in ctx.processes: + p.terminate() + ctx.join() + raise + + +def _detect_launcher() -> str: + """Detect the distributed launcher from environment. + + Returns one of: "torchelastic", "torchrun", "external", "local". + """ + if dist.is_torchelastic_launched(): + return "torchelastic" + if "LOCAL_WORLD_SIZE" in os.environ: + return "torchrun" + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + return "external" + return "local" def spawn_parallel_fn( @@ -127,57 +233,13 @@ def spawn_parallel_fn( start_method: str = "spawn", **kwargs, ): - # Multi-node support: if RANK env var is set, init process group - # and run function directly (no local spawn). - if "RANK" in os.environ: - rank = int(os.environ["RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - with setup_parallel( - rank=rank, - world_size=world_size, - backend=backend, - master_addr=os.environ.get("MASTER_ADDR", master_addr), - master_port=os.environ.get("MASTER_PORT", master_port), - device_type=device_type, - ): - func(**kwargs) - return - - # clear environment variables (single-node path) - for key in [ - "MASTER_ADDR", - "MASTER_PORT", - "RANK", - "WORLD_SIZE", - "LOCAL_RANK", - "LOCAL_DEVICE", - ]: - if key in os.environ: - del os.environ[key] - - if world_size == 1: - device_id = torch.device(device_type, 0) - os.environ["LOCAL_RANK"] = "0" - os.environ["WORLD_SIZE"] = "1" - os.environ["LOCAL_DEVICE"] = str(device_id) - - func(**kwargs) - return - - wrapper_spawn_func_args = ( - world_size, - backend, - master_addr, - master_port, - device_type, - func, - kwargs, - ) - - mp.start_processes( - wrapper_spawn_func, - args=wrapper_spawn_func_args, - nprocs=world_size, - start_method=start_method, - join=True, - ) + launcher = _detect_launcher() + if launcher in ("torchelastic", "torchrun", "external"): + strategy = TorchrunStrategy( + world_size, backend, master_addr, master_port, device_type, start_method + ) + else: + strategy = LocalStrategy( + world_size, backend, master_addr, master_port, device_type, start_method + ) + strategy.launch(func, **kwargs) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 376a709..5a47c0e 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -315,6 +315,8 @@ def train( }, ) + grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else [] + train_config = TrainConfig( model_fn=model_fn, strategy=train_type, @@ -332,9 +334,6 @@ def train( random_seed=random_seed, num_workers=num_workers, pin_memory=pin_memory, - gradient_checkpointing_modules=[DecoderBlock] - if gradient_checkpointing - else [], nprocs=nprocs, backend=backend, master_addr=master_addr, @@ -342,6 +341,7 @@ def train( parallel_mode=parallel_mode, device_type=device_type, start_method=start_method, + gradient_checkpointing_modules=grad_ckpt_modules, executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs, )