refactor: 移除 device_ids 参数设计,统一通过 CUDA_VISIBLE_DEVICES 控制 GPU 分配;更新 README 训练示例
- setup.py: 移除 device_ids 参数,setup_parallel 直接用 rank 作为设备索引
- train_config.py: 移除 device_ids 字段
- trainer.py: 不再传递 device_ids
- train.py: ddp_wrap 用 get_rank() 直接取值,model.to() 不再显式指定 device
- README.md, README-zh-CN.md: 训练示例改为多行命令风格,去掉参数表格
This commit is contained in:
parent
283bcaf2ff
commit
b98c9cefdc
43
README.md
43
README.md
|
|
@ -68,41 +68,18 @@ pip install -e ".[dev]"
|
||||||
#### Train a Model
|
#### Train a Model
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 1000 \
|
||||||
|
--n_epoch 1 \
--nprocs 4
|
||||||
```
|
```
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
See the [Parameter Guide](assets/docs/params.md) for the full parameter reference.
|
||||||
|-----------|-------------|---------|
|
|
||||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
|
||||||
| `--param_path` | Model / checkpoint path | required |
|
|
||||||
| `--n_epoch` | Training epochs | 1 |
|
|
||||||
| `--batch_size` | Batch size | 1 |
|
|
||||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
|
||||||
| `--warmup_steps` | LR warmup steps | 1000 |
|
|
||||||
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
|
||||||
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
|
||||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
|
||||||
| `--random_seed` | Random seed | 3407 |
|
|
||||||
| `--num_workers` | DataLoader workers | 4 |
|
|
||||||
| `--window_size` | Max input sequence length | auto |
|
|
||||||
| `--stride` | Sequence stride | auto |
|
|
||||||
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
|
|
||||||
| `--dpo_beta` | DPO beta | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
|
||||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
|
|
||||||
| `--group_size` | GRPO group size | 4 |
|
|
||||||
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
|
|
||||||
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
|
||||||
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
|
||||||
| `--start_epoch` | Start epoch (for resume) | 0 |
|
|
||||||
| `--start_batch` | Start batch (for resume) | 0 |
|
|
||||||
| `--nprocs` | Number of GPUs | 1 |
|
|
||||||
| `--device_type` | Device type | cuda |
|
|
||||||
|
|
||||||
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
|
||||||
|
|
||||||
#### Generate Text
|
#### Generate Text
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -74,41 +74,18 @@ pip install -e ".[dev]"
|
||||||
#### 训练模型
|
#### 训练模型
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||||
|
--train_type seq \
|
||||||
|
--data_root_path /path/to/dataset \
|
||||||
|
--param_path /path/to/model \
|
||||||
|
--batch_size 4 \
|
||||||
|
--accumulation_steps 8 \
|
||||||
|
--max_lr 3e-4 \
|
||||||
|
--warmup_steps 1000 \
|
||||||
|
--n_epoch 1 \
--nprocs 4
|
||||||
```
|
```
|
||||||
|
|
||||||
| 参数 | 说明 | 默认值 |
|
完整参数列表见[参数说明](./params.md)。
|
||||||
|------|------|--------|
|
|
||||||
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 |
|
|
||||||
| `--data_root_path` | 数据集根目录 | 必填 |
|
|
||||||
| `--param_path` | 模型参数或断点路径 | 必填 |
|
|
||||||
| `--n_epoch` | 训练轮数 | 1 |
|
|
||||||
| `--batch_size` | 批次大小 | 1 |
|
|
||||||
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
|
||||||
| `--warmup_steps` | 预热步数 | 1000 |
|
|
||||||
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
|
||||||
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
|
|
||||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
|
||||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
|
||||||
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
|
|
||||||
| `--random_seed` | 随机种子 | 3407 |
|
|
||||||
| `--num_workers` | 数据加载线程数 | 4 |
|
|
||||||
| `--window_size` | 最大输入序列长度 | auto |
|
|
||||||
| `--stride` | 序列步长 | auto |
|
|
||||||
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
|
|
||||||
| `--dpo_beta` | DPO beta | 0.1 |
|
|
||||||
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
|
|
||||||
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
|
|
||||||
| `--group_size` | GRPO 组大小 | 4 |
|
|
||||||
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
|
|
||||||
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
|
||||||
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
|
||||||
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
|
|
||||||
| `--start_batch` | 起始批次(用于断点续训) | 0 |
|
|
||||||
| `--nprocs` | GPU 数量 | 1 |
|
|
||||||
| `--device_type` | 设备类型 | cuda |
|
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md#training-parameters)。
|
|
||||||
|
|
||||||
#### 文本生成
|
#### 文本生成
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
@ -74,9 +74,6 @@ class TrainConfig:
|
||||||
)
|
)
|
||||||
|
|
||||||
# others
|
# others
|
||||||
device_ids: Optional[List[int]] = field(
|
|
||||||
default=None, metadata={"help": "Device ids for distributed training."}
|
|
||||||
)
|
|
||||||
device_type: str = field(
|
device_type: str = field(
|
||||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
@ -34,7 +34,6 @@ def setup_parallel(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
|
||||||
):
|
):
|
||||||
|
|
||||||
if dist.is_available() and dist.is_initialized():
|
if dist.is_available() and dist.is_initialized():
|
||||||
|
|
@ -45,16 +44,10 @@ def setup_parallel(
|
||||||
yield None
|
yield None
|
||||||
return
|
return
|
||||||
|
|
||||||
if device_ids is None:
|
device_id = torch.device(device_type, rank)
|
||||||
device_ids = [i for i in range(world_size)]
|
|
||||||
|
|
||||||
effective_rank = rank % len(device_ids)
|
|
||||||
device_id = torch.device(device_type, device_ids[effective_rank])
|
|
||||||
rank = device_ids[effective_rank]
|
|
||||||
|
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = master_port
|
os.environ["MASTER_PORT"] = master_port
|
||||||
|
|
||||||
os.environ["LOCAL_RANK"] = str(rank)
|
os.environ["LOCAL_RANK"] = str(rank)
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
os.environ["WORLD_SIZE"] = str(world_size)
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
@ -104,7 +97,6 @@ def wrapper_spawn_func(
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
device_ids: List[int],
|
|
||||||
func: Callable,
|
func: Callable,
|
||||||
kwargs: dict,
|
kwargs: dict,
|
||||||
):
|
):
|
||||||
|
|
@ -116,7 +108,6 @@ def wrapper_spawn_func(
|
||||||
master_addr=master_addr,
|
master_addr=master_addr,
|
||||||
master_port=master_port,
|
master_port=master_port,
|
||||||
device_type=device_type,
|
device_type=device_type,
|
||||||
device_ids=device_ids,
|
|
||||||
):
|
):
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
||||||
|
|
@ -132,7 +123,6 @@ def spawn_parallel_fn(
|
||||||
master_addr: str = "localhost",
|
master_addr: str = "localhost",
|
||||||
master_port: str = "29500",
|
master_port: str = "29500",
|
||||||
device_type: str = "cuda",
|
device_type: str = "cuda",
|
||||||
device_ids: Optional[List[int]] = None,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# clear environment variables
|
# clear environment variables
|
||||||
|
|
@ -148,8 +138,9 @@ def spawn_parallel_fn(
|
||||||
del os.environ[key]
|
del os.environ[key]
|
||||||
|
|
||||||
if world_size == 1:
|
if world_size == 1:
|
||||||
device_ids = device_ids or [0]
|
device_id = torch.device(device_type, 0)
|
||||||
device_id = torch.device(device_type, device_ids[0])
|
os.environ["LOCAL_RANK"] = "0"
|
||||||
|
os.environ["WORLD_SIZE"] = "1"
|
||||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||||
|
|
||||||
func(**kwargs)
|
func(**kwargs)
|
||||||
|
|
@ -161,7 +152,6 @@ def spawn_parallel_fn(
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
device_type,
|
device_type,
|
||||||
device_ids,
|
|
||||||
func,
|
func,
|
||||||
kwargs,
|
kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -53,7 +53,6 @@ class Trainer:
|
||||||
master_addr=config.master_addr,
|
master_addr=config.master_addr,
|
||||||
master_port=config.master_port,
|
master_port=config.master_port,
|
||||||
device_type=config.device_type,
|
device_type=config.device_type,
|
||||||
device_ids=config.device_ids,
|
|
||||||
checkpoint=checkpoint,
|
checkpoint=checkpoint,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -155,7 +155,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
|
|
||||||
def ddp_wrap(model: nn.Module):
|
def ddp_wrap(model: nn.Module):
|
||||||
local_rank = get_rank()
|
local_rank = get_rank()
|
||||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
model = model.to(dtype=torch.bfloat16)
|
||||||
ddp_model = DDP(
|
ddp_model = DDP(
|
||||||
model,
|
model,
|
||||||
device_ids=[local_rank],
|
device_ids=[local_rank],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue