refactor: 并行启动 Strategy 模式重构,local_rank 解耦
- setup_parallel 接收 local_rank 参数,不再读环境变量推导
- TorchrunStrategy 从 env 读取 LOCAL_RANK,LocalStrategy 直接以 rank 作为 local_rank
- _detect_launcher() 分级检测替代内联 RANK 检查
- _run_single_rank 统一入口,消除 _run_single/_run_multi 重复
- 优雅退出:except BaseException 终止子进程并 re-join
- gradient_checkpointing_modules 判定提取到外部变量
This commit is contained in:
parent
d6899100ac
commit
9b416c1bbb
|
|
@ -1,4 +1,5 @@
|
||||||
import os
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
|
@ -30,6 +31,7 @@ def get_rank() -> int:
|
||||||
def setup_parallel(
|
def setup_parallel(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
|
local_rank: int,
|
||||||
backend: str = "nccl",
|
backend: str = "nccl",
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
|
|
@ -41,10 +43,13 @@ def setup_parallel(
|
||||||
return
|
return
|
||||||
|
|
||||||
if world_size <= 1:
|
if world_size <= 1:
|
||||||
|
device_id = torch.device(device_type, local_rank)
|
||||||
|
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
|
|
||||||
device_id = torch.device(device_type, local_rank)
|
device_id = torch.device(device_type, local_rank)
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
|
|
@ -91,7 +96,7 @@ def only_on_rank(rank, sync=False):
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def wrapper_spawn_func(
|
def _run_single_rank(
|
||||||
rank: int,
|
rank: int,
|
||||||
world_size: int,
|
world_size: int,
|
||||||
backend: str,
|
backend: str,
|
||||||
|
|
@ -101,20 +106,121 @@ def wrapper_spawn_func(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
try:
|
with setup_parallel(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
local_rank=rank,
|
||||||
|
backend=backend,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
|
device_type=device_type,
|
||||||
|
):
|
||||||
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LaunchStrategy(ABC):
|
||||||
|
"""Strategy for launching a function in a distributed context."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
world_size: int,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
|
device_type: str,
|
||||||
|
start_method: str,
|
||||||
|
):
|
||||||
|
self.world_size = world_size
|
||||||
|
self.backend = backend
|
||||||
|
self.master_addr = master_addr
|
||||||
|
self.master_port = master_port
|
||||||
|
self.device_type = device_type
|
||||||
|
self.start_method = start_method
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TorchrunStrategy(LaunchStrategy):
|
||||||
|
"""External orchestrator (torchrun, SLURM, K8s) — env vars pre-set."""
|
||||||
|
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
local_rank = int(os.environ.get("LOCAL_RANK", rank))
|
||||||
with setup_parallel(
|
with setup_parallel(
|
||||||
rank=rank,
|
rank=rank,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
backend=backend,
|
local_rank=local_rank,
|
||||||
master_addr=master_addr,
|
backend=self.backend,
|
||||||
master_port=master_port,
|
master_addr=os.environ.get("MASTER_ADDR", self.master_addr),
|
||||||
device_type=device_type,
|
master_port=os.environ.get("MASTER_PORT", self.master_port),
|
||||||
|
device_type=self.device_type,
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error in rank {rank}: {e}")
|
class LocalStrategy(LaunchStrategy):
|
||||||
raise
|
"""Local launcher — single-process or mp.start_processes."""
|
||||||
|
|
||||||
|
def _clear_env(self):
|
||||||
|
for key in (
|
||||||
|
"MASTER_ADDR",
|
||||||
|
"MASTER_PORT",
|
||||||
|
"RANK",
|
||||||
|
"WORLD_SIZE",
|
||||||
|
"LOCAL_RANK",
|
||||||
|
"LOCAL_DEVICE",
|
||||||
|
):
|
||||||
|
os.environ.pop(key, None)
|
||||||
|
|
||||||
|
def launch(self, func: Callable, **kwargs):
|
||||||
|
self._clear_env()
|
||||||
|
|
||||||
|
args = (
|
||||||
|
self.world_size,
|
||||||
|
self.backend,
|
||||||
|
self.master_addr,
|
||||||
|
self.master_port,
|
||||||
|
self.device_type,
|
||||||
|
func,
|
||||||
|
kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.world_size == 1:
|
||||||
|
_run_single_rank(0, *args)
|
||||||
|
return
|
||||||
|
|
||||||
|
ctx = mp.start_processes(
|
||||||
|
_run_single_rank,
|
||||||
|
args=args,
|
||||||
|
nprocs=self.world_size,
|
||||||
|
start_method=self.start_method,
|
||||||
|
join=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
while not ctx.join():
|
||||||
|
pass
|
||||||
|
except BaseException:
|
||||||
|
for p in ctx.processes:
|
||||||
|
p.terminate()
|
||||||
|
ctx.join()
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_launcher() -> str:
|
||||||
|
"""Detect the distributed launcher from environment.
|
||||||
|
|
||||||
|
Returns one of: "torchelastic", "torchrun", "external", "local".
|
||||||
|
"""
|
||||||
|
if dist.is_torchelastic_launched():
|
||||||
|
return "torchelastic"
|
||||||
|
if "LOCAL_WORLD_SIZE" in os.environ:
|
||||||
|
return "torchrun"
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
return "external"
|
||||||
|
return "local"
|
||||||
|
|
||||||
|
|
||||||
def spawn_parallel_fn(
|
def spawn_parallel_fn(
|
||||||
|
|
@ -127,57 +233,13 @@ def spawn_parallel_fn(
|
||||||
start_method: str = "spawn",
|
start_method: str = "spawn",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# Multi-node support: if RANK env var is set, init process group
|
launcher = _detect_launcher()
|
||||||
# and run function directly (no local spawn).
|
if launcher in ("torchelastic", "torchrun", "external"):
|
||||||
if "RANK" in os.environ:
|
strategy = TorchrunStrategy(
|
||||||
rank = int(os.environ["RANK"])
|
world_size, backend, master_addr, master_port, device_type, start_method
|
||||||
world_size = int(os.environ["WORLD_SIZE"])
|
)
|
||||||
with setup_parallel(
|
else:
|
||||||
rank=rank,
|
strategy = LocalStrategy(
|
||||||
world_size=world_size,
|
world_size, backend, master_addr, master_port, device_type, start_method
|
||||||
backend=backend,
|
)
|
||||||
master_addr=os.environ.get("MASTER_ADDR", master_addr),
|
strategy.launch(func, **kwargs)
|
||||||
master_port=os.environ.get("MASTER_PORT", master_port),
|
|
||||||
device_type=device_type,
|
|
||||||
):
|
|
||||||
func(**kwargs)
|
|
||||||
return
|
|
||||||
|
|
||||||
# clear environment variables (single-node path)
|
|
||||||
for key in [
|
|
||||||
"MASTER_ADDR",
|
|
||||||
"MASTER_PORT",
|
|
||||||
"RANK",
|
|
||||||
"WORLD_SIZE",
|
|
||||||
"LOCAL_RANK",
|
|
||||||
"LOCAL_DEVICE",
|
|
||||||
]:
|
|
||||||
if key in os.environ:
|
|
||||||
del os.environ[key]
|
|
||||||
|
|
||||||
if world_size == 1:
|
|
||||||
device_id = torch.device(device_type, 0)
|
|
||||||
os.environ["LOCAL_RANK"] = "0"
|
|
||||||
os.environ["WORLD_SIZE"] = "1"
|
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
|
||||||
|
|
||||||
func(**kwargs)
|
|
||||||
return
|
|
||||||
|
|
||||||
wrapper_spawn_func_args = (
|
|
||||||
world_size,
|
|
||||||
backend,
|
|
||||||
master_addr,
|
|
||||||
master_port,
|
|
||||||
device_type,
|
|
||||||
func,
|
|
||||||
kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
mp.start_processes(
|
|
||||||
wrapper_spawn_func,
|
|
||||||
args=wrapper_spawn_func_args,
|
|
||||||
nprocs=world_size,
|
|
||||||
start_method=start_method,
|
|
||||||
join=True,
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -315,6 +315,8 @@ def train(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
grad_ckpt_modules = [DecoderBlock] if gradient_checkpointing else []
|
||||||
|
|
||||||
train_config = TrainConfig(
|
train_config = TrainConfig(
|
||||||
model_fn=model_fn,
|
model_fn=model_fn,
|
||||||
strategy=train_type,
|
strategy=train_type,
|
||||||
|
|
@ -332,9 +334,6 @@ def train(
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
gradient_checkpointing_modules=[DecoderBlock]
|
|
||||||
if gradient_checkpointing
|
|
||||||
else [],
|
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
backend=backend,
|
backend=backend,
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
|
|
@ -342,6 +341,7 @@ def train(
|
||||||
parallel_mode=parallel_mode,
|
parallel_mode=parallel_mode,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
|
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||||
executor_kwargs=executor_kwargs,
|
executor_kwargs=executor_kwargs,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue