refactor: 并行启动 Strategy 模式重构,local_rank 解耦

- setup_parallel 接收 local_rank 参数,不再读环境变量推导
- TorchrunStrategy 从 env 读取 LOCAL_RANK,LocalStrategy 用 rank
- _detect_launcher() 分级检测替代内联 RANK 检查
- _run_single_rank 统一入口,消除 _run_single/_run_multi 重复
- 优雅退出:except BaseException 终止子进程并 re-join
- gradient_checkpointing_modules 判定提取到外部变量
This commit is contained in:
parent
d6899100ac
commit
9b416c1bbb
|
|
@ -1,4 +1,5 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
|
@ -30,6 +31,7 @@ def get_rank() -> int:
|
|||
def setup_parallel(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
local_rank: int,
|
||||
backend: str = "nccl",
|
||||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
|
|
@ -41,10 +43,13 @@ def setup_parallel(
|
|||
return
|
||||
|
||||
if world_size <= 1:
|
||||
device_id = torch.device(device_type, local_rank)
|
||||
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
yield None
|
||||
return
|
||||
|
||||
local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
|
||||
device_id = torch.device(device_type, local_rank)
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
|
|
@ -91,7 +96,7 @@ def only_on_rank(rank, sync=False):
|
|||
return decorator
|
||||
|
||||
|
||||
def wrapper_spawn_func(
|
||||
def _run_single_rank(
|
||||
rank: int,
|
||||
world_size: int,
|
||||
backend: str,
|
||||
|
|
@ -101,20 +106,121 @@ def wrapper_spawn_func(
|
|||
func: Callable,
|
||||
kwargs: dict,
|
||||
):
|
||||
try:
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
local_rank=rank,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
||||
class LaunchStrategy(ABC):
    """Strategy for launching a function in a distributed context.

    A strategy stores the launch configuration; concrete subclasses decide
    how that configuration is turned into running processes by implementing
    :meth:`launch`.
    """

    def __init__(
        self,
        world_size: int,
        backend: str,
        master_addr: str,
        master_port: str,
        device_type: str,
        start_method: str,
    ) -> None:
        """Store the launch configuration on the instance.

        Args:
            world_size: Number of participating processes.
            backend: Name of the communication backend.
            master_addr: Rendezvous address.
            master_port: Rendezvous port, kept as a string.
            device_type: Device type string used to build the device.
            start_method: Multiprocessing start method for local spawning.
        """
        self.world_size = world_size
        self.backend = backend
        self.master_addr = master_addr
        self.master_port = master_port
        self.device_type = device_type
        self.start_method = start_method

    @abstractmethod
    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` under this strategy.

        Args:
            func: Callable to execute.
            **kwargs: Keyword arguments forwarded to ``func``.

        Raises:
            NotImplementedError: If a subclass delegates to this base
                implementation instead of providing its own.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class TorchrunStrategy(LaunchStrategy):
    """External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""

    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` in the current process as one rank.

        No processes are spawned here. ``RANK`` and ``WORLD_SIZE`` are read
        from the environment, so the ``world_size`` stored on the instance
        is not used by this method.

        Args:
            func: Callable to execute inside the ``setup_parallel`` context.
            **kwargs: Keyword arguments forwarded to ``func``.

        Raises:
            KeyError: If ``RANK`` or ``WORLD_SIZE`` is missing from the
                environment.
            ValueError: If ``RANK``, ``WORLD_SIZE`` or ``LOCAL_RANK`` is not
                an integer string.
        """
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        # Fall back to the global rank when the launcher does not export
        # LOCAL_RANK.
        local_rank = int(os.environ.get("LOCAL_RANK", rank))

        with setup_parallel(
            rank=rank,
            world_size=world_size,
            local_rank=local_rank,
            backend=self.backend,
            # Values exported by the launcher win over the configured
            # defaults.
            master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
            master_port=os.environ.get("MASTER_PORT", self.master_port),
            device_type=self.device_type,
        ):
            func(**kwargs)
|
||||
|
||||
class LocalStrategy(LaunchStrategy):
    """Local launcher — single-process or mp.start_processes."""

    # Seconds a terminated child gets to exit before it is killed.
    _TERMINATE_GRACE_S = 30

    def _clear_env(self):
        """Remove rendezvous/rank variables from ``os.environ`` if present."""
        for key in (
            "MASTER_ADDR",
            "MASTER_PORT",
            "RANK",
            "WORLD_SIZE",
            "LOCAL_RANK",
            "LOCAL_DEVICE",
        ):
            os.environ.pop(key, None)

    def launch(self, func: Callable, **kwargs):
        """Run ``func(**kwargs)`` on ``world_size`` local ranks.

        With ``world_size == 1`` the function runs in the current process as
        rank 0. Otherwise one process per rank is started through
        ``mp.start_processes`` and this call blocks until all of them finish.

        Args:
            func: Callable executed by every rank.
            **kwargs: Keyword arguments forwarded to ``func``; they are
                passed to the workers as a single dict.

        Raises:
            BaseException: Whatever interrupted the wait (a worker failure
                reported by ``ctx.join()``, ``KeyboardInterrupt``, ...) is
                re-raised after the workers have been stopped.
        """
        self._clear_env()

        # The rank is supplied separately as the leading positional argument
        # of _run_single_rank.
        args = (
            self.world_size,
            self.backend,
            self.master_addr,
            self.master_port,
            self.device_type,
            func,
            kwargs,
        )

        if self.world_size == 1:
            _run_single_rank(0, *args)
            return

        ctx = mp.start_processes(
            _run_single_rank,
            args=args,
            nprocs=self.world_size,
            start_method=self.start_method,
            join=False,
        )
        try:
            # join() returns False while some workers are still running.
            while not ctx.join():
                pass
        except BaseException:
            for p in ctx.processes:
                if p.is_alive():
                    p.terminate()
            # Reap the children directly instead of calling ctx.join() again:
            # ctx.join() raises when a worker exited abnormally, which would
            # replace the exception being propagated here.
            for p in ctx.processes:
                p.join(timeout=self._TERMINATE_GRACE_S)
                if p.is_alive():
                    # The child did not exit after terminate(); force it.
                    p.kill()
                    p.join()
            raise
|
||||
|
||||
|
||||
def _detect_launcher() -> str:
|
||||
"""Detect the distributed launcher from environment.
|
||||
|
||||
Returns one of: "torchelastic", "torchrun", "external", "local".
|
||||
"""
|
||||
if dist.is_torchelastic_launched():
|
||||
return "torchelastic"
|
||||
if "LOCAL_WORLD_SIZE" in os.environ:
|
||||
return "torchrun"
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
return "external"
|
||||
return "local"
|
||||
|
||||
|
||||
def spawn_parallel_fn(
|
||||
|
|
@ -127,57 +233,13 @@ def spawn_parallel_fn(
|
|||
start_method: str = "spawn",
|
||||
**kwargs,
|
||||
):
|
||||
# Multi-node support: if RANK env var is set, init process group
|
||||
# and run function directly (no local spawn).
|
||||
if "RANK" in os.environ:
|
||||
rank = int(os.environ["RANK"])
|
||||
world_size = int(os.environ["WORLD_SIZE"])
|
||||
with setup_parallel(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
master_addr=os.environ.get("MASTER_ADDR", master_addr),
|
||||
master_port=os.environ.get("MASTER_PORT", master_port),
|
||||
device_type=device_type,
|
||||
):
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
# clear environment variables (single-node path)
|
||||
for key in [
|
||||
"MASTER_ADDR",
|
||||
"MASTER_PORT",
|
||||
"RANK",
|
||||
"WORLD_SIZE",
|
||||
"LOCAL_RANK",
|
||||
"LOCAL_DEVICE",
|
||||
]:
|
||||
if key in os.environ:
|
||||
del os.environ[key]
|
||||
|
||||
if world_size == 1:
|
||||
device_id = torch.device(device_type, 0)
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
func(**kwargs)
|
||||
return
|
||||
|
||||
wrapper_spawn_func_args = (
|
||||
world_size,
|
||||
backend,
|
||||
master_addr,
|
||||
master_port,
|
||||
device_type,
|
||||
func,
|
||||
kwargs,
|
||||
)
|
||||
|
||||
mp.start_processes(
|
||||
wrapper_spawn_func,
|
||||
args=wrapper_spawn_func_args,
|
||||
nprocs=world_size,
|
||||
start_method=start_method,
|
||||
join=True,
|
||||
)
|
||||
launcher = _detect_launcher()
|
||||
if launcher in ("torchelastic", "torchrun", "external"):
|
||||
strategy = TorchrunStrategy(
|
||||
world_size, backend, master_addr, master_port, device_type, start_method
|
||||
)
|
||||
else:
|
||||
strategy = LocalStrategy(
|
||||
world_size, backend, master_addr, master_port, device_type, start_method
|
||||
)
|
||||
strategy.launch(func, **kwargs)
|
||||
|
|
|
|||
|
|
@ -315,6 +315,8 @@ def train(
|
|||
},
|
||||
)
|
||||
|
||||
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
|
||||
|
||||
train_config = TrainConfig(
|
||||
model_fn=model_fn,
|
||||
strategy=train_type,
|
||||
|
|
@ -332,9 +334,6 @@ def train(
|
|||
random_seed=random_seed,
|
||||
num_workers=num_workers,
|
||||
pin_memory=pin_memory,
|
||||
gradient_checkpointing_modules=[DecoderBlock]
|
||||
if gradient_checkpointing
|
||||
else [],
|
||||
nprocs=nprocs,
|
||||
backend=backend,
|
||||
master_addr=master_addr,
|
||||
|
|
@ -342,6 +341,7 @@ def train(
|
|||
parallel_mode=parallel_mode,
|
||||
device_type=device_type,
|
||||
start_method=start_method,
|
||||
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||
executor_kwargs=executor_kwargs,
|
||||
extra_kwargs=strategy_kwargs,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue