refactor : 并行启动 Strategy 模式重构,local_rank 解耦

- setup_parallel 接收 local_rank 参数,不再读环境变量推导
- TorchrunStrategy 从 env 读取 LOCAL_RANK,LocalStrategy 用 rank
- _detect_launcher() 分级检测替代内联 RANK 检查
- _run_single_rank 统一入口,消除 _run_single/_run_multi 重复
- 优雅退出:except BaseException 终止子进程并 re-join
- gradient_checkpointing_modules 判定提取到外部变量
This commit is contained in:
ViperEkura 2026-06-02 11:22:24 +08:00
parent d6899100ac
commit 9b416c1bbb
2 changed files with 129 additions and 67 deletions

View File

@ -1,4 +1,5 @@
import os import os
from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Callable from typing import Callable
@ -30,6 +31,7 @@ def get_rank() -> int:
def setup_parallel( def setup_parallel(
rank: int, rank: int,
world_size: int, world_size: int,
local_rank: int,
backend: str = "nccl", backend: str = "nccl",
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
@ -41,10 +43,13 @@ def setup_parallel(
return return
if world_size <= 1: if world_size <= 1:
device_id = torch.device(device_type, local_rank)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
yield None yield None
return return
local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
device_id = torch.device(device_type, local_rank) device_id = torch.device(device_type, local_rank)
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
@ -91,7 +96,7 @@ def only_on_rank(rank, sync=False):
return decorator return decorator
def wrapper_spawn_func( def _run_single_rank(
rank: int, rank: int,
world_size: int, world_size: int,
backend: str, backend: str,
@ -101,10 +106,10 @@ def wrapper_spawn_func(
func: Callable, func: Callable,
kwargs: dict, kwargs: dict,
): ):
try:
with setup_parallel( with setup_parallel(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
local_rank=rank,
backend=backend, backend=backend,
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
@ -112,11 +117,112 @@ def wrapper_spawn_func(
): ):
func(**kwargs) func(**kwargs)
except Exception as e:
print(f"Error in rank {rank}: {e}") class LaunchStrategy(ABC):
"""Strategy for launching a function in a distributed context."""
def __init__(
self,
world_size: int,
backend: str,
master_addr: str,
master_port: str,
device_type: str,
start_method: str,
):
self.world_size = world_size
self.backend = backend
self.master_addr = master_addr
self.master_port = master_port
self.device_type = device_type
self.start_method = start_method
@abstractmethod
def launch(self, func: Callable, **kwargs):
raise NotImplementedError
class TorchrunStrategy(LaunchStrategy):
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
def launch(self, func: Callable, **kwargs):
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ.get("LOCAL_RANK", rank))
with setup_parallel(
rank=rank,
world_size=world_size,
local_rank=local_rank,
backend=self.backend,
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
master_port=os.environ.get("MASTER_PORT", self.master_port),
device_type=self.device_type,
):
func(**kwargs)
class LocalStrategy(LaunchStrategy):
"""Local launcher — single-process or mp.start_processes."""
def _clear_env(self):
for key in (
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
):
os.environ.pop(key, None)
def launch(self, func: Callable, **kwargs):
self._clear_env()
args = (
self.world_size,
self.backend,
self.master_addr,
self.master_port,
self.device_type,
func,
kwargs,
)
if self.world_size == 1:
_run_single_rank(0, *args)
return
ctx = mp.start_processes(
_run_single_rank,
args=args,
nprocs=self.world_size,
start_method=self.start_method,
join=False,
)
try:
while not ctx.join():
pass
except BaseException:
for p in ctx.processes:
p.terminate()
ctx.join()
raise raise
def _detect_launcher() -> str:
"""Detect the distributed launcher from environment.
Returns one of: "torchelastic", "torchrun", "external", "local".
"""
if dist.is_torchelastic_launched():
return "torchelastic"
if "LOCAL_WORLD_SIZE" in os.environ:
return "torchrun"
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
return "external"
return "local"
def spawn_parallel_fn( def spawn_parallel_fn(
func: Callable, func: Callable,
world_size: int, world_size: int,
@ -127,57 +233,13 @@ def spawn_parallel_fn(
start_method: str = "spawn", start_method: str = "spawn",
**kwargs, **kwargs,
): ):
# Multi-node support: if RANK env var is set, init process group launcher = _detect_launcher()
# and run function directly (no local spawn). if launcher in ("torchelastic", "torchrun", "external"):
if "RANK" in os.environ: strategy = TorchrunStrategy(
rank = int(os.environ["RANK"]) world_size, backend, master_addr, master_port, device_type, start_method
world_size = int(os.environ["WORLD_SIZE"])
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=os.environ.get("MASTER_ADDR", master_addr),
master_port=os.environ.get("MASTER_PORT", master_port),
device_type=device_type,
):
func(**kwargs)
return
# clear environment variables (single-node path)
for key in [
"MASTER_ADDR",
"MASTER_PORT",
"RANK",
"WORLD_SIZE",
"LOCAL_RANK",
"LOCAL_DEVICE",
]:
if key in os.environ:
del os.environ[key]
if world_size == 1:
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
return
wrapper_spawn_func_args = (
world_size,
backend,
master_addr,
master_port,
device_type,
func,
kwargs,
) )
else:
mp.start_processes( strategy = LocalStrategy(
wrapper_spawn_func, world_size, backend, master_addr, master_port, device_type, start_method
args=wrapper_spawn_func_args,
nprocs=world_size,
start_method=start_method,
join=True,
) )
strategy.launch(func, **kwargs)

View File

@ -315,6 +315,8 @@ def train(
}, },
) )
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
train_config = TrainConfig( train_config = TrainConfig(
model_fn=model_fn, model_fn=model_fn,
strategy=train_type, strategy=train_type,
@ -332,9 +334,6 @@ def train(
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
gradient_checkpointing_modules=[DecoderBlock]
if gradient_checkpointing
else [],
nprocs=nprocs, nprocs=nprocs,
backend=backend, backend=backend,
master_addr=master_addr, master_addr=master_addr,
@ -342,6 +341,7 @@ def train(
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device_type=device_type, device_type=device_type,
start_method=start_method, start_method=start_method,
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs, executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs, extra_kwargs=strategy_kwargs,
) )