commit
d6899100ac
|
|
@ -18,6 +18,7 @@ Key properties:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import bisect
|
import bisect
|
||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -113,13 +114,17 @@ def detect_format(load_path: str) -> str:
|
||||||
return "h5"
|
return "h5"
|
||||||
raise ValueError(f"Unsupported file format: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
h5_files = list(root.rglob("*.h5")) + list(root.rglob("*.hdf5"))
|
h5_files = [
|
||||||
|
Path(p)
|
||||||
|
for pattern in ("*.h5", "*.hdf5")
|
||||||
|
for p in glob.glob(str(root / "**" / pattern), recursive=True)
|
||||||
|
]
|
||||||
if h5_files:
|
if h5_files:
|
||||||
return "h5"
|
return "h5"
|
||||||
bin_files = list(root.rglob("*.bin"))
|
bin_files = [Path(p) for p in glob.glob(str(root / "**" / "*.bin"), recursive=True)]
|
||||||
if bin_files:
|
if bin_files:
|
||||||
has_meta = (root / "meta.json").exists() or len(
|
has_meta = (root / "meta.json").exists() or len(
|
||||||
list(root.rglob("meta.json"))
|
[Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)]
|
||||||
) > 0
|
) > 0
|
||||||
if has_meta:
|
if has_meta:
|
||||||
return "bin"
|
return "bin"
|
||||||
|
|
@ -250,7 +255,9 @@ class MmapStore(Store):
|
||||||
self._mmap_refs = []
|
self._mmap_refs = []
|
||||||
root = Path(path)
|
root = Path(path)
|
||||||
all_raw: Dict[str, List[Tensor]] = {}
|
all_raw: Dict[str, List[Tensor]] = {}
|
||||||
meta_paths = list(root.rglob("meta.json"))
|
meta_paths = [
|
||||||
|
Path(p) for p in glob.glob(str(root / "**" / "meta.json"), recursive=True)
|
||||||
|
]
|
||||||
for meta_path in meta_paths:
|
for meta_path in meta_paths:
|
||||||
raw = load_bin(str(meta_path.parent))
|
raw = load_bin(str(meta_path.parent))
|
||||||
for key, tensors in raw.items():
|
for key, tensors in raw.items():
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -181,7 +182,7 @@ class DDPExecutor(BaseExecutor):
|
||||||
if not self.use_distributed:
|
if not self.use_distributed:
|
||||||
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
logger.warning("DDP backend selected but world_size=1, model not wrapped")
|
||||||
return model
|
return model
|
||||||
local_rank = get_rank()
|
local_rank = int(os.environ.get("LOCAL_RANK", get_rank()))
|
||||||
model = DDP(
|
model = DDP(
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
|
|
|
||||||
|
|
@ -44,11 +44,12 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
device_id = torch.device(device_type, rank)
|
local_rank = int(os.environ["LOCAL_RANK"]) if "LOCAL_RANK" in os.environ else rank
|
||||||
|
device_id = torch.device(device_type, local_rank)
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
os.environ["LOCAL_RANK"] = str(rank)
|
os.environ["LOCAL_RANK"] = str(local_rank)
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
|
|
@ -126,7 +127,23 @@ def spawn_parallel_fn(
|
||||||
start_method: str = "spawn",
|
start_method: str = "spawn",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# Multi-node support: if RANK env var is set, init process group
|
||||||
|
# and run function directly (no local spawn).
|
||||||
|
if "RANK" in os.environ:
|
||||||
|
rank = int(os.environ["RANK"])
|
||||||
|
world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
with setup_parallel(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
backend=backend,
|
||||||
|
master_addr=os.environ.get("MASTER_ADDR", master_addr),
|
||||||
|
master_port=os.environ.get("MASTER_PORT", master_port),
|
||||||
|
device_type=device_type,
|
||||||
|
):
|
||||||
|
func(**kwargs)
|
||||||
|
return
|
||||||
|
|
||||||
|
# clear environment variables (single-node path)
|
||||||
for key in [
|
for key in [
|
||||||
"MASTER_ADDR",
|
"MASTER_ADDR",
|
||||||
"MASTER_PORT",
|
"MASTER_PORT",
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ import torch.optim as optim
|
||||||
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
from astrai.config import AutoRegressiveLMConfig, TrainConfig
|
||||||
from astrai.dataset import DatasetFactory
|
from astrai.dataset import DatasetFactory
|
||||||
from astrai.model import AutoRegressiveLM
|
from astrai.model import AutoRegressiveLM
|
||||||
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.trainer import SchedulerFactory, Trainer
|
from astrai.trainer import SchedulerFactory, Trainer
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -115,6 +116,12 @@ def parse_args() -> argparse.Namespace:
|
||||||
default=0.05,
|
default=0.05,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--gradient_checkpointing",
|
||||||
|
action=argparse.BooleanOptionalAction,
|
||||||
|
default=False,
|
||||||
|
help="Enable activation checkpointing for DecoderBlock modules.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ckpt_interval",
|
"--ckpt_interval",
|
||||||
|
|
@ -141,6 +148,24 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--start_batch", type=int, default=0, help="Start batch for training."
|
"--start_batch", type=int, default=0, help="Start batch for training."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_addr",
|
||||||
|
type=str,
|
||||||
|
default="localhost",
|
||||||
|
help="Master node address for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--master_port",
|
||||||
|
type=str,
|
||||||
|
default="29500",
|
||||||
|
help="Master node port for distributed training.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="nccl",
|
||||||
|
help="Distributed training backend.",
|
||||||
|
)
|
||||||
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
parser.add_argument("--nprocs", type=int, default=1, help="Number of GPUs to use.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--parallel_mode",
|
"--parallel_mode",
|
||||||
|
|
@ -222,11 +247,15 @@ def train(
|
||||||
random_seed: int,
|
random_seed: int,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
|
gradient_checkpointing: bool,
|
||||||
window_size: int,
|
window_size: int,
|
||||||
stride: int,
|
stride: int,
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
parallel_mode: str,
|
parallel_mode: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
|
backend: str,
|
||||||
|
master_addr: str,
|
||||||
|
master_port: str,
|
||||||
start_method: str,
|
start_method: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
|
|
@ -303,7 +332,13 @@ def train(
|
||||||
random_seed=random_seed,
|
random_seed=random_seed,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
pin_memory=pin_memory,
|
pin_memory=pin_memory,
|
||||||
|
gradient_checkpointing_modules=[DecoderBlock]
|
||||||
|
if gradient_checkpointing
|
||||||
|
else [],
|
||||||
nprocs=nprocs,
|
nprocs=nprocs,
|
||||||
|
backend=backend,
|
||||||
|
master_addr=master_addr,
|
||||||
|
master_port=master_port,
|
||||||
parallel_mode=parallel_mode,
|
parallel_mode=parallel_mode,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
start_method=start_method,
|
start_method=start_method,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue