refactor: 移除 device_ids 参数设计,统一通过 CUDA_VISIBLE_DEVICES 控制 GPU 分配;更新 README 训练示例

- setup.py: 移除 device_ids 参数,setup_parallel 直接用 rank 作为设备索引
- train_config.py: 移除 device_ids 字段
- trainer.py: 不再传递 device_ids
- train.py: ddp_wrap 用 get_rank() 直接取值
- README.md, README-zh-CN.md: 训练示例改为多行命令风格,去掉参数表格
This commit is contained in:
ViperEkura 2026-05-09 14:55:43 +08:00
parent 283bcaf2ff
commit b98c9cefdc
6 changed files with 27 additions and 87 deletions

View File

@ -68,41 +68,18 @@ pip install -e ".[dev]"
#### Train a Model #### Train a Model
```bash ```bash
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
``` ```
| Parameter | Description | Default | Full reference at [Parameter Guide](assets/docs/params.md).
|-----------|-------------|---------|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
| `--data_root_path` | Dataset root directory | required |
| `--param_path` | Model / checkpoint path | required |
| `--n_epoch` | Training epochs | 1 |
| `--batch_size` | Batch size | 1 |
| `--accumulation_steps` | Gradient accumulation steps | 1 |
| `--warmup_steps` | LR warmup steps | 1000 |
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
| `--random_seed` | Random seed | 3407 |
| `--num_workers` | DataLoader workers | 4 |
| `--window_size` | Max input sequence length | auto |
| `--stride` | Sequence stride | auto |
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
| `--group_size` | GRPO group size | 4 |
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
| `--ckpt_dir` | Checkpoint directory | checkpoint |
| `--start_epoch` | Start epoch (for resume) | 0 |
| `--start_batch` | Start batch (for resume) | 0 |
| `--nprocs` | Number of GPUs | 1 |
| `--device_type` | Device type | cuda |
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
#### Generate Text #### Generate Text

View File

@ -74,41 +74,18 @@ pip install -e ".[dev]"
#### 训练模型 #### 训练模型
```bash ```bash
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
--train_type seq \
--data_root_path /path/to/dataset \
--param_path /path/to/model \
--batch_size 4 \
--accumulation_steps 8 \
--max_lr 3e-4 \
--warmup_steps 1000 \
--n_epoch 1
``` ```
| 参数 | 说明 | 默认值 | 完整参数列表见[参数说明](./params.md)。
|------|------|--------|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 |
| `--data_root_path` | 数据集根目录 | 必填 |
| `--param_path` | 模型参数或断点路径 | 必填 |
| `--n_epoch` | 训练轮数 | 1 |
| `--batch_size` | 批次大小 | 1 |
| `--accumulation_steps` | 梯度累积步数 | 1 |
| `--warmup_steps` | 预热步数 | 1000 |
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
| `--adamw_beta1` | AdamW beta1 | 0.9 |
| `--adamw_beta2` | AdamW beta2 | 0.95 |
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
| `--random_seed` | 随机种子 | 3407 |
| `--num_workers` | 数据加载线程数 | 4 |
| `--window_size` | 最大输入序列长度 | auto |
| `--stride` | 序列步长 | auto |
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
| `--dpo_beta` | DPO beta | 0.1 |
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
| `--group_size` | GRPO 组大小 | 4 |
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
| `--start_batch` | 起始批次(用于断点续训) | 0 |
| `--nprocs` | GPU 数量 | 1 |
| `--device_type` | 设备类型 | cuda |
完整参数列表见[参数说明](./params.md#training-parameters)。
#### 文本生成 #### 文本生成

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, List, Optional from typing import Callable, Optional
import torch.nn as nn import torch.nn as nn
from torch.optim import Optimizer from torch.optim import Optimizer
@ -74,9 +74,6 @@ class TrainConfig:
) )
# others # others
device_ids: Optional[List[int]] = field(
default=None, metadata={"help": "Device ids for distributed training."}
)
device_type: str = field( device_type: str = field(
default="cuda", metadata={"help": "Device type for distributed training."} default="cuda", metadata={"help": "Device type for distributed training."}
) )

View File

@ -1,7 +1,7 @@
import os import os
from contextlib import contextmanager from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Callable, List, Optional from typing import Callable
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -34,7 +34,6 @@ def setup_parallel(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
): ):
if dist.is_available() and dist.is_initialized(): if dist.is_available() and dist.is_initialized():
@ -45,16 +44,10 @@ def setup_parallel(
yield None yield None
return return
if device_ids is None: device_id = torch.device(device_type, rank)
device_ids = [i for i in range(world_size)]
effective_rank = rank % len(device_ids)
device_id = torch.device(device_type, device_ids[effective_rank])
rank = device_ids[effective_rank]
os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port os.environ["MASTER_PORT"] = master_port
os.environ["LOCAL_RANK"] = str(rank) os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size) os.environ["WORLD_SIZE"] = str(world_size)
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
@ -104,7 +97,6 @@ def wrapper_spawn_func(
master_addr: str, master_addr: str,
master_port: str, master_port: str,
device_type: str, device_type: str,
device_ids: List[int],
func: Callable, func: Callable,
kwargs: dict, kwargs: dict,
): ):
@ -116,7 +108,6 @@ def wrapper_spawn_func(
master_addr=master_addr, master_addr=master_addr,
master_port=master_port, master_port=master_port,
device_type=device_type, device_type=device_type,
device_ids=device_ids,
): ):
func(**kwargs) func(**kwargs)
@ -132,7 +123,6 @@ def spawn_parallel_fn(
master_addr: str = "localhost", master_addr: str = "localhost",
master_port: str = "29500", master_port: str = "29500",
device_type: str = "cuda", device_type: str = "cuda",
device_ids: Optional[List[int]] = None,
**kwargs, **kwargs,
): ):
# clear environment variables # clear environment variables
@ -148,8 +138,9 @@ def spawn_parallel_fn(
del os.environ[key] del os.environ[key]
if world_size == 1: if world_size == 1:
device_ids = device_ids or [0] device_id = torch.device(device_type, 0)
device_id = torch.device(device_type, device_ids[0]) os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_DEVICE"] = str(device_id) os.environ["LOCAL_DEVICE"] = str(device_id)
func(**kwargs) func(**kwargs)
@ -161,7 +152,6 @@ def spawn_parallel_fn(
master_addr, master_addr,
master_port, master_port,
device_type, device_type,
device_ids,
func, func,
kwargs, kwargs,
) )

View File

@ -53,7 +53,6 @@ class Trainer:
master_addr=config.master_addr, master_addr=config.master_addr,
master_port=config.master_port, master_port=config.master_port,
device_type=config.device_type, device_type=config.device_type,
device_ids=config.device_ids,
checkpoint=checkpoint, checkpoint=checkpoint,
) )

View File

@ -155,7 +155,7 @@ def parse_args() -> argparse.Namespace:
def ddp_wrap(model: nn.Module): def ddp_wrap(model: nn.Module):
local_rank = get_rank() local_rank = get_rank()
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) model = model.to(dtype=torch.bfloat16)
ddp_model = DDP( ddp_model = DDP(
model, model,
device_ids=[local_rank], device_ids=[local_rank],