246 lines
6.2 KiB
Python
246 lines
6.2 KiB
Python
import os
|
|
from abc import ABC, abstractmethod
|
|
from contextlib import contextmanager
|
|
from functools import wraps
|
|
from typing import Callable
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.multiprocessing as mp
|
|
|
|
|
|
def get_current_device():
    """Return the device string stored in the ``LOCAL_DEVICE`` env variable.

    ``setup_parallel`` in this module writes that variable as
    ``str(torch.device(device_type, local_rank))``.

    Raises:
        KeyError: If ``LOCAL_DEVICE`` is not present in the environment.
    """
    local_device = os.environ["LOCAL_DEVICE"]
    return local_device
|
|
|
|
|
|
def get_world_size() -> int:
    """Return the process-group world size, or 1 when none is active.

    The fallback of 1 is used when ``torch.distributed`` is unavailable or
    the default process group has not been initialized.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size()
|
|
|
|
|
|
def get_rank() -> int:
    """Return this process's rank, or 0 when no process group is active.

    The fallback of 0 is used when ``torch.distributed`` is unavailable or
    the default process group has not been initialized.
    """
    group_active = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if group_active else 0
|
|
|
|
|
|
@contextmanager
def setup_parallel(
    rank: int,
    world_size: int,
    local_rank: int,
    backend: str = "nccl",
    master_addr: str = "localhost",
    master_port: str = "29500",
    device_type: str = "cuda",
):
    """Context manager that prepares (and tears down) a process group.

    Three cases are handled, in this order:

    1. A process group is already initialized: yield ``dist.group.WORLD``
       and touch nothing (no env vars written, no teardown on exit).
    2. ``world_size <= 1``: export ``LOCAL_RANK``, ``WORLD_SIZE`` (always
       ``"1"``) and ``LOCAL_DEVICE``, then yield ``None`` without creating
       a process group.
    3. Otherwise: export ``MASTER_ADDR``, ``MASTER_PORT``, ``LOCAL_RANK``,
       ``WORLD_SIZE`` and ``LOCAL_DEVICE``, initialize the process group,
       select the local accelerator device, yield ``dist.group.WORLD`` and
       destroy the group on exit.

    Args:
        rank: Global rank passed to ``dist.init_process_group``.
        world_size: Total number of participating processes.
        local_rank: Index used to build the local ``torch.device``.
        backend: Backend name passed to ``dist.init_process_group``.
        master_addr: Value exported as ``MASTER_ADDR``.
        master_port: Value exported as ``MASTER_PORT``.
        device_type: Device type used to build the local ``torch.device``.

    Yields:
        ``dist.group.WORLD`` when a process group is in use, else ``None``.
    """
    if dist.is_available() and dist.is_initialized():
        # Someone else owns the group; hand it out and leave it alone.
        yield dist.group.WORLD
        return

    device = torch.device(device_type, local_rank)

    if world_size <= 1:
        # Single-process run: only publish the local settings.
        os.environ.update(
            {
                "LOCAL_RANK": str(local_rank),
                "WORLD_SIZE": "1",
                "LOCAL_DEVICE": str(device),
            }
        )
        yield None
        return

    os.environ.update(
        {
            "MASTER_ADDR": master_addr,
            "MASTER_PORT": master_port,
            "LOCAL_RANK": str(local_rank),
            "WORLD_SIZE": str(world_size),
            "LOCAL_DEVICE": str(device),
        }
    )

    dist.init_process_group(
        rank=rank,
        world_size=world_size,
        backend=backend,
        device_id=device,
    )

    try:
        # Pick the accelerator module matching the backend, if one is usable.
        accelerator = None
        if backend == "nccl" and torch.cuda.is_available():
            accelerator = torch.cuda
        elif backend == "ccl" and hasattr(torch, "xpu") and torch.xpu.is_available():
            accelerator = torch.xpu

        if accelerator is not None:
            accelerator.set_device(device)

        yield dist.group.WORLD
    finally:
        # Tear the group down even if the body raised.
        if dist.is_initialized():
            dist.destroy_process_group()
|
|
|
|
|
|
def only_on_rank(rank, sync=False):
    """Build a decorator that runs the wrapped function on one rank only.

    Args:
        rank: The rank (as reported by ``get_rank()``) allowed to execute
            the wrapped function.
        sync: When true and a process group is initialized, every rank
            calls ``dist.barrier()`` after the (possible) invocation.

    Returns:
        A decorator. The decorated callable returns the wrapped function's
        result on the selected rank and ``None`` on every other rank.
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs) if get_rank() == rank else None

            barrier_needed = sync and dist.is_available() and dist.is_initialized()
            if barrier_needed:
                dist.barrier()

            return result

        return wrapper

    return decorator
|
|
|
|
|
|
def _run_single_rank(
    rank: int,
    world_size: int,
    backend: str,
    master_addr: str,
    master_port: str,
    device_type: str,
    func: Callable,
    kwargs: dict,
):
    """Run ``func(**kwargs)`` for one rank inside ``setup_parallel``.

    ``rank`` is used both as the global rank and as the local rank. The
    return value of ``func`` is discarded.
    """
    parallel_ctx = setup_parallel(
        rank=rank,
        world_size=world_size,
        local_rank=rank,
        backend=backend,
        master_addr=master_addr,
        master_port=master_port,
        device_type=device_type,
    )
    with parallel_ctx:
        func(**kwargs)
|
|
|
|
|
|
class LaunchStrategy(ABC):
    """Abstract base for the ways a function can be launched in parallel.

    The constructor only records the launch configuration; subclasses
    implement :meth:`launch` to actually run the function.
    """

    def __init__(
        self,
        world_size: int,
        backend: str,
        master_addr: str,
        master_port: str,
        device_type: str,
        start_method: str,
    ):
        """Store the launch configuration on the instance unchanged."""
        # Process topology and start method.
        self.world_size = world_size
        self.start_method = start_method

        # Communication backend and device type.
        self.backend = backend
        self.device_type = device_type

        # Rendezvous endpoint.
        self.master_addr = master_addr
        self.master_port = master_port

    @abstractmethod
    def launch(self, func: Callable, **kwargs):
        """Run ``func`` with ``kwargs``; must be overridden by subclasses."""
        raise NotImplementedError
|
|
|
|
|
|
class TorchrunStrategy(LaunchStrategy):
    """External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""

    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` in this process using launcher-provided env.

        ``RANK`` and ``WORLD_SIZE`` must be set in the environment.
        ``LOCAL_RANK`` falls back to the global rank, and ``MASTER_ADDR`` /
        ``MASTER_PORT`` fall back to the values given to the constructor.

        Raises:
            KeyError: If ``RANK`` or ``WORLD_SIZE`` is missing.
        """
        env = os.environ
        global_rank = int(env["RANK"])
        total_ranks = int(env["WORLD_SIZE"])
        local = int(env.get("LOCAL_RANK", global_rank))

        parallel_ctx = setup_parallel(
            rank=global_rank,
            world_size=total_ranks,
            local_rank=local,
            backend=self.backend,
            master_addr=env.get("MASTER_ADDR", self.master_addr),
            master_port=env.get("MASTER_PORT", self.master_port),
            device_type=self.device_type,
        )
        with parallel_ctx:
            func(**kwargs)
|
|
|
|
|
|
class LocalStrategy(LaunchStrategy):
    """Local launcher — single-process or mp.start_processes."""

    def _clear_env(self):
        """Remove distributed env variables left over from an earlier run."""
        for key in (
            "MASTER_ADDR",
            "MASTER_PORT",
            "RANK",
            "WORLD_SIZE",
            "LOCAL_RANK",
            "LOCAL_DEVICE",
        ):
            os.environ.pop(key, None)

    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` on ``self.world_size`` local ranks.

        With ``world_size == 1`` the function runs in the current process.
        Otherwise one worker per rank is started with
        ``mp.start_processes`` and this call blocks until all of them have
        finished.

        Raises:
            BaseException: Whatever interrupted the wait (a worker failure
                reported by ``ctx.join()``, ``KeyboardInterrupt``, ...) is
                re-raised unchanged after the workers have been stopped.
        """
        self._clear_env()

        # Positional arguments for _run_single_rank after the rank itself.
        args = (
            self.world_size,
            self.backend,
            self.master_addr,
            self.master_port,
            self.device_type,
            func,
            kwargs,
        )

        if self.world_size == 1:
            _run_single_rank(0, *args)
            return

        ctx = mp.start_processes(
            _run_single_rank,
            args=args,
            nprocs=self.world_size,
            start_method=self.start_method,
            join=False,
        )
        try:
            # ctx.join() returns True only once every worker has exited.
            while not ctx.join():
                pass
        except BaseException:
            for proc in ctx.processes:
                if proc.is_alive():
                    proc.terminate()
            # Reap workers with plain Process.join() instead of ctx.join():
            # ctx.join() reports any non-zero worker exit as a failure, and
            # the workers terminated above exit non-zero, so calling it here
            # would raise a new exception and hide the one being handled.
            for proc in ctx.processes:
                proc.join()
            raise
|
|
|
|
|
|
def _detect_launcher() -> str:
    """Detect the distributed launcher from environment.

    Checks are made in priority order: torchelastic (as reported by
    ``dist.is_torchelastic_launched()``), then ``LOCAL_WORLD_SIZE`` in the
    environment, then both ``RANK`` and ``WORLD_SIZE`` in the environment.

    Returns one of: "torchelastic", "torchrun", "external", "local".
    """
    if dist.is_torchelastic_launched():
        return "torchelastic"

    env = os.environ
    if "LOCAL_WORLD_SIZE" in env:
        return "torchrun"

    has_rank_info = all(key in env for key in ("RANK", "WORLD_SIZE"))
    return "external" if has_rank_info else "local"
|
|
|
|
|
|
def spawn_parallel_fn(
    func: Callable,
    world_size: int,
    backend: str = "nccl",
    master_addr: str = "localhost",
    master_port: str = "29500",
    device_type: str = "cuda",
    start_method: str = "spawn",
    **kwargs,
):
    """Run ``func(**kwargs)`` in parallel using the detected launcher.

    When ``_detect_launcher()`` reports an external launcher
    ("torchelastic", "torchrun" or "external") a ``TorchrunStrategy`` is
    used; otherwise a ``LocalStrategy`` is used. Both receive the same
    configuration arguments.

    Args:
        func: Callable to execute; invoked with ``**kwargs``.
        world_size: Number of ranks requested.
        backend: Distributed backend name.
        master_addr: Rendezvous address.
        master_port: Rendezvous port.
        device_type: Device type used for the per-rank device.
        start_method: Start method forwarded to the strategy.
        **kwargs: Keyword arguments forwarded to ``func``.
    """
    externally_managed = _detect_launcher() in ("torchelastic", "torchrun", "external")
    strategy_cls = TorchrunStrategy if externally_managed else LocalStrategy

    strategy = strategy_cls(
        world_size,
        backend,
        master_addr,
        master_port,
        device_type,
        start_method,
    )
    strategy.launch(func, **kwargs)
|