refactor: 移除 device_ids 参数设计,统一通过 CUDA_VISIBLE_DEVICES 控制 GPU 分配;更新 README 训练示例

- setup.py: 移除 device_ids 参数,setup_parallel 直接用 rank 作为设备索引
- train_config.py: 移除 device_ids 字段
- trainer.py: 不再传递 device_ids
- train.py: ddp_wrap 用 get_rank() 直接取值
- README.md, README-zh-CN.md: 训练示例改为多行命令风格,去掉参数表格
This commit is contained in:
ViperEkura 2026-05-09 14:55:43 +08:00
parent 283bcaf2ff
commit b98c9cefdc
6 changed files with 27 additions and 87 deletions

View File

@ -68,41 +68,18 @@ pip install -e ".[dev]"
#### Train a Model
```bash
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
```
| Parameter | Description | Default |
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--window_size` | Max input sequence length | auto |
| `--stride` | Sequence stride | auto |
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
| `--group_size` | GRPO group size | 4 |
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--start_epoch` | Start epoch (for resume) | 0 |
| `--start_batch` | Start batch (for resume) | 0 |
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
Full reference at [Parameter Guide](assets/docs/params.md).
#### Generate Text

View File

@ -74,41 +74,18 @@ pip install -e ".[dev]"
#### 训练模型
```bash
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
```
| 参数 | 说明 | 默认值 |
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo` | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
| `--random_seed` | 随机种子 | 3407 |
| `--num_workers` | 数据加载线程数 | 4 |
| `--window_size` | 最大输入序列长度 | auto |
| `--stride` | 序列步长 | auto |
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
| `--group_size` | GRPO 组大小 | 4 |
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
| `--start_batch` | 起始批次(用于断点续训) | 0 |
| `--nprocs` | GPU 数量 | 1 |
| `--device_type` | 设备类型 | cuda |
完整参数列表见[参数说明](./params.md#training-parameters)。
完整参数列表见[参数说明](./params.md)。
#### 文本生成

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Callable, List, Optional
from typing import Callable, Optional
import torch.nn as nn
from torch.optim import Optimizer
@ -74,9 +74,6 @@ class TrainConfig:
)
# others
device_ids: Optional[List[int]] = field(
default=None, metadata={"help": "Device ids for distributed training."}
)
device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."}
)

View File

@ -1,7 +1,7 @@
import os
from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional
from typing import Callable
import torch
import torch.distributed as dist
@ -34,7 +34,6 @@ def setup_parallel(
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
):
if dist.is_available() and dist.is_initialized():
@ -45,16 +44,10 @@ def setup_parallel(
yield None
return
if device_ids is None:
device_ids = [i for i in range(world_size)]
effective_rank = rank % len(device_ids)
device_id = torch.device(device_type, device_ids[effective_rank])
rank = device_ids[effective_rank]
device_id = torch.device(device_type, rank)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id)
@ -104,7 +97,6 @@ def wrapper_spawn_func(
master_addr: str,
master_port: str,
device_type: str,
device_ids: List[int],
func: Callable,
kwargs: dict,
):
@ -116,7 +108,6 @@ def wrapper_spawn_func(
master_addr=master_addr,
master_port=master_port,
device_type=device_type,
device_ids=device_ids,
):
func(**kwargs)
@ -132,7 +123,6 @@ def spawn_parallel_fn(
master_addr: str = "localhost",
master_port: str = "29500",
device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
**kwargs,
):
# clear environment variables
@ -148,8 +138,9 @@ def spawn_parallel_fn(
del os.environ[key]
if world_size == 1:
device_ids = device_ids or [0]
device_id = torch.device(device_type, device_ids[0])
device_id = torch.device(device_type, 0)
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs)
@ -161,7 +152,6 @@ def spawn_parallel_fn(
master_addr,
master_port,
device_type,
device_ids,
func,
kwargs,
)

View File

@ -53,7 +53,6 @@ class Trainer:
master_addr=config.master_addr,
master_port=config.master_port,
device_type=config.device_type,
device_ids=config.device_ids,
checkpoint=checkpoint,
)

View File

@ -155,7 +155,7 @@ def parse_args() -> argparse.Namespace:
def ddp_wrap(model: nn.Module):
local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)
ddp_model = DDP(
model,
device_ids=[local_rank],