refactor : 并行启动 Strategy 模式重构,local_rank 解耦

- setup_parallel 接收 local_rank 参数,不再读环境变量推导
- TorchrunStrategy 从 env 读取 LOCAL_RANK,LocalStrategy 用 rank
- _detect_launcher() 分级检测替代内联 RANK 检查
- _run_single_rank 统一入口,消除 _run_single/_run_multi 重复
- 优雅退出:except BaseException 终止子进程并 re-join
- gradient_checkpointing_modules 判定提取到外部变量
This commit is contained in:
ViperEkura 2026-06-02 11:22:24 +08:00
parent d6899100ac
commit 9b416c1bbb
2 changed files with 129 additions and 67 deletions

View File

@ -1,4 +1,5 @@
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from functools import wraps
from typing import Callable
@ -30,6 +31,7 @@ def get_rank() -> int:
def setup_parallel(
rank: int,
world_size: int,
local_rank: int,
backend: str = "nccl",
master_addr: str = "localhost",
master_port: str = "29500",
@ -41,10 +43,13 @@ def setup_parallel(
return
if world_size <= 1:
device_id = torch.device(device_type, local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
yield None
return
local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
device_id = torch.device(device_type, local_rank)
os.environ["MASTER_ADDR"] = master_addr
@ -91,7 +96,7 @@ def only_on_rank(rank, sync=False):
return decorator
def wrapper_spawn_func(
def _run_single_rank(
rank: int,
world_size: int,
backend: str,
@ -101,10 +106,10 @@ def wrapper_spawn_func(
func: Callable,
kwargs: dict,
):
try:
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=rank,
backend=backend,
master_addr=master_addr,
master_port=master_port,
@ -112,11 +117,112 @@ def wrapper_spawn_func(
):
func(**kwargs)
except Exception as e:
print(f"Error in rank {rank}: {e}")
class LaunchStrategy(ABC):
"""Strategy for launching a function in a distributed context."""
def __init__(
self,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
start_method: str,
):
self.world_size = world_size
self.backend = backend
self.master_addr = master_addr
self.master_port = master_port
self.device_type = device_type
self.start_method = start_method
@abstractmethod
def launch(self, func: Callable, **kwargs):
raise NotImplementedError
class TorchrunStrategy(LaunchStrategy):
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
def launch(self, func: Callable, **kwargs):
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=local_rank,
backend=self.backend,
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
master_port=os.environ.get("MASTER_PORT", self.master_port),
device_type=self.device_type,
):
func(**kwargs)
class LocalStrategy(LaunchStrategy):
"""Local launcher — single-process or mp.start_processes."""
def _clear_env(self):
for key in (
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
):
os.environ.pop(key, None)
def launch(self, func: Callable, **kwargs):
self._clear_env()
args = (
self.world_size,
self.backend,
self.master_addr,
self.master_port,
self.device_type,
func,
kwargs,
)
if self.world_size == 1:
_run_single_rank(0, *args)
return
ctx = mp.start_processes(
_run_single_rank,
args=args,
nprocs=self.world_size,
start_method=self.start_method,
join=False,
)
try:
while not ctx.join():
pass
except BaseException:
for p in ctx.processes:
p.terminate()
ctx.join()
raise
def _detect_launcher() -> str:
"""Detect the distributed launcher from environment.
Returns one of: "torchelastic", "torchrun", "external", "local".
"""
if dist.is_torchelastic_launched():
return "torchelastic"
if "LOCAL_WORLD_SIZE" in os.environ:
return "torchrun"
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
return "external"
return "local"
def spawn_parallel_fn(
func: Callable,
world_size: int,
@ -127,57 +233,13 @@ def spawn_parallel_fn(
start_method: str = "spawn",
**kwargs,
):
# Multi-node support: if RANK env var is set, init process group
# and run function directly (no local spawn).
if "RANK" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=os.environ.get("MASTER_ADDR", master_addr),
master_port=os.environ.get("MASTER_PORT", master_port),
device_type=device_type,
):
func(**kwargs)
return
# clear environment variables (single-node path)
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
func,
kwargs,
launcher = _detect_launcher()
if launcher in ("torchelastic", "torchrun", "external"):
strategy = TorchrunStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
)
mp.start_processes(
wrapper_spawn_func,
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
else:
strategy = LocalStrategy(
world_size, backend, master_addr, master_port, device_type, start_method
)
strategy.launch(func, **kwargs)

View File

@ -315,6 +315,8 @@ def train(
},
)
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
train_config = TrainConfig(
model_fn=model_fn,
strategy=train_type,
@ -332,9 +334,6 @@ def train(
random_seed=random_seed,
num_workers=num_workers,
pin_memory=pin_memory,
gradient_checkpointing_modules=[DecoderBlock]
if gradient_checkpointing
else [],
nprocs=nprocs,
backend=backend,
master_addr=master_addr,
@ -342,6 +341,7 @@ def train(
parallel_mode=parallel_mode,
device_type=device_type,
start_method=start_method,
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs,
)