Merge pull request #17 from yegroup001/main

增加多机DDP (Add multi-node DDP support)
This commit is contained in:
ViperEkura 2026-06-02 10:29:07 +08:00 committed by GitHub
commit d6899100ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 68 additions and 8 deletions

View File

@ -18,6 +18,7 @@ Key properties:
""" """
import bisect import bisect
import glob
import json import json
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -113,13 +114,17 @@ def detect_format(load_path: str) -> str:
return "h5" return "h5"
raise ValueError(f"Unsupported file format: {suffix}") raise ValueError(f"Unsupported file format: {suffix}")
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5")) h5_files = [
Path(p)
for pattern in ("*.h5", "*.hdf5")
for p in glob.glob(str(root / "**" / pattern), recursive=True)
]
if h5_files: if h5_files:
return "h5" return "h5"
bin_files = list(root.rglob("*.bin")) bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
if bin_files: if bin_files:
has_meta = (root / "meta.json").exists() or len( has_meta = (root / "meta.json").exists() or len(
list(root.rglob("meta.json")) [Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
) > 0 ) > 0
if has_meta: if has_meta:
return "bin" return "bin"
@ -250,7 +255,9 @@ class MmapStore(Store):
self._mmap_refs = [] self._mmap_refs = []
root = Path(path) root = Path(path)
all_raw: Dict[str, List[Tensor]] = {} all_raw: Dict[str, List[Tensor]] = {}
meta_paths = list(root.rglob("meta.json")) meta_paths = [
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
]
for meta_path in meta_paths: for meta_path in meta_paths:
raw = load_bin(str(meta_path.parent)) raw = load_bin(str(meta_path.parent))
for key, tensors in raw.items(): for key, tensors in raw.items():

View File

@ -2,6 +2,7 @@
import contextlib import contextlib
import logging import logging
import os
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Tuple from typing import Optional, Tuple
@ -181,7 +182,7 @@ class DDPExecutor(BaseExecutor):
if not self.use_distributed: if not self.use_distributed:
logger.warning("DDP backend selected but world_size=1, model not wrapped") logger.warning("DDP backend selected but world_size=1, model not wrapped")
return model return model
local_rank = get_rank() local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
model = DDP( model = DDP(
model, model,
device_ids=[local_rank], device_ids=[local_rank],

View File

@ -44,11 +44,12 @@ def setup_parallel(
yield None yield None
return return
device_id = torch.device(device_type, rank) local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
device_id = torch.device(device_type, local_rank)
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
@ -126,7 +127,23 @@ def spawn_parallel_fn(
start_method: str = "spawn", start_method: str = "spawn",
**kwargs, **kwargs,
): ):
# clear environment variables # Multi-node support: if RANK env var is set, init process group
# and run function directly (no local spawn).
if "RANK" in os.environ:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
with setup_parallel(
rank=rank,
world_size=world_size,
backend=backend,
master_addr=os.environ.get("MASTER_ADDR", master_addr),
master_port=os.environ.get("MASTER_PORT", master_port),
device_type=device_type,
):
func(**kwargs)
return
# clear environment variables (single-node path)
for key in [ for key in [
"MASTER_ADDR", "MASTER_ADDR",
"MASTER_PORT", "MASTER_PORT",

View File

@ -8,6 +8,7 @@ import torch.optim as optim
from astrai.config import AutoRegressiveLMConfig, TrainConfig from astrai.config import AutoRegressiveLMConfig, TrainConfig
from astrai.dataset import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.model import AutoRegressiveLM from astrai.model import AutoRegressiveLM
from astrai.model.components.decoder_block import DecoderBlock
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer
@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace:
default=0.05, default=0.05,
help="cross_entropy function label smoothing parameter", help="cross_entropy function label smoothing parameter",
) )
parser.add_argument(
"--gradient_checkpointing",
action=argparse.BooleanOptionalAction,
default=False,
help="Enable activation checkpointing for DecoderBlock modules.",
)
parser.add_argument( parser.add_argument(
"--ckpt_interval", "--ckpt_interval",
@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace:
"--start_batch", type=int, default=0, help="Start batch for training." "--start_batch", type=int, default=0, help="Start batch for training."
) )
parser.add_argument(
"--master_addr",
type=str,
default="localhost",
help="Master node address for distributed training.",
)
parser.add_argument(
"--master_port",
type=str,
default="29500",
help="Master node port for distributed training.",
)
parser.add_argument(
"--backend",
type=str,
default="nccl",
help="Distributed training backend.",
)
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.") parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
parser.add_argument( parser.add_argument(
"--parallel_mode", "--parallel_mode",
@ -222,11 +247,15 @@ def train(
random_seed: int, random_seed: int,
num_workers: int, num_workers: int,
pin_memory: bool, pin_memory: bool,
gradient_checkpointing: bool,
window_size: int, window_size: int,
stride: int, stride: int,
nprocs: int, nprocs: int,
parallel_mode: str, parallel_mode: str,
device_type: str, device_type: str,
backend: str,
master_addr: str,
master_port: str,
start_method: str, start_method: str,
): ):
assert train_type in ["seq", "sft", "dpo", "grpo"] assert train_type in ["seq", "sft", "dpo", "grpo"]
@ -303,7 +332,13 @@ def train(
random_seed=random_seed, random_seed=random_seed,
num_workers=num_workers, num_workers=num_workers,
pin_memory=pin_memory, pin_memory=pin_memory,
gradient_checkpointing_modules=[DecoderBlock]
if gradient_checkpointing
else [],
nprocs=nprocs, nprocs=nprocs,
backend=backend,
master_addr=master_addr,
master_port=master_port,
parallel_mode=parallel_mode, parallel_mode=parallel_mode,
device_type=device_type, device_type=device_type,
start_method=start_method, start_method=start_method,