refactor: 移除 device_ids 参数设计,统一通过 CUDA_VISIBLE_DEVICES 控制 GPU 分配;更新 README 训练示例
- setup.py: 移除 device_ids 参数,setup_parallel 直接用 rank 作为设备索引
- train_config.py: 移除 device_ids 字段
- trainer.py: 不再传递 device_ids
- train.py: ddp_wrap 用 get_rank() 直接取值
- README.md, README-zh-CN.md: 训练示例改为多行命令风格,去掉参数表格
This commit is contained in:
parent
283bcaf2ff
commit
b98c9cefdc
43
README.md
43
README.md
|
|
@ -68,41 +68,18 @@ pip install -e ".[dev]"
|
|||
#### Train a Model
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1 \
--nprocs 4
|
||||
```
|
||||
|
||||
| Parameter | Description | Default |
|
||||
|-----------|-------------|---------|
|
||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||
| `--data_root_path` | Dataset root directory | required |
|
||||
| `--param_path` | Model / checkpoint path | required |
|
||||
| `--n_epoch` | Training epochs | 1 |
|
||||
| `--batch_size` | Batch size | 1 |
|
||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
||||
| `--warmup_steps` | LR warmup steps | 1000 |
|
||||
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
||||
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||
| `--random_seed` | Random seed | 3407 |
|
||||
| `--num_workers` | DataLoader workers | 4 |
|
||||
| `--window_size` | Max input sequence length | auto |
|
||||
| `--stride` | Sequence stride | auto |
|
||||
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
|
||||
| `--dpo_beta` | DPO beta | 0.1 |
|
||||
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
||||
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
|
||||
| `--group_size` | GRPO group size | 4 |
|
||||
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
|
||||
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
||||
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
||||
| `--start_epoch` | Start epoch (for resume) | 0 |
|
||||
| `--start_batch` | Start batch (for resume) | 0 |
|
||||
| `--nprocs` | Number of GPUs | 1 |
|
||||
| `--device_type` | Device type | cuda |
|
||||
|
||||
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
||||
Full reference at [Parameter Guide](assets/docs/params.md).
|
||||
|
||||
#### Generate Text
|
||||
|
||||
|
|
|
|||
|
|
@ -74,41 +74,18 @@ pip install -e ".[dev]"
|
|||
#### 训练模型
|
||||
|
||||
```bash
|
||||
python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \
|
||||
--train_type seq \
|
||||
--data_root_path /path/to/dataset \
|
||||
--param_path /path/to/model \
|
||||
--batch_size 4 \
|
||||
--accumulation_steps 8 \
|
||||
--max_lr 3e-4 \
|
||||
--warmup_steps 1000 \
|
||||
--n_epoch 1 \
--nprocs 4
|
||||
```
|
||||
|
||||
| 参数 | 说明 | 默认值 |
|
||||
|------|------|--------|
|
||||
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 |
|
||||
| `--data_root_path` | 数据集根目录 | 必填 |
|
||||
| `--param_path` | 模型参数或断点路径 | 必填 |
|
||||
| `--n_epoch` | 训练轮数 | 1 |
|
||||
| `--batch_size` | 批次大小 | 1 |
|
||||
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
||||
| `--warmup_steps` | 预热步数 | 1000 |
|
||||
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
||||
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
|
||||
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
|
||||
| `--random_seed` | 随机种子 | 3407 |
|
||||
| `--num_workers` | 数据加载线程数 | 4 |
|
||||
| `--window_size` | 最大输入序列长度 | auto |
|
||||
| `--stride` | 序列步长 | auto |
|
||||
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
|
||||
| `--dpo_beta` | DPO beta | 0.1 |
|
||||
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
|
||||
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
|
||||
| `--group_size` | GRPO 组大小 | 4 |
|
||||
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
|
||||
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
||||
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
||||
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
|
||||
| `--start_batch` | 起始批次(用于断点续训) | 0 |
|
||||
| `--nprocs` | GPU 数量 | 1 |
|
||||
| `--device_type` | 设备类型 | cuda |
|
||||
|
||||
完整参数列表见[参数说明](./params.md#training-parameters)。
|
||||
完整参数列表见[参数说明](./params.md)。
|
||||
|
||||
#### 文本生成
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
|
@ -74,9 +74,6 @@ class TrainConfig:
|
|||
)
|
||||
|
||||
# others
|
||||
device_ids: Optional[List[int]] = field(
|
||||
default=None, metadata={"help": "Device ids for distributed training."}
|
||||
)
|
||||
device_type: str = field(
|
||||
default="cuda", metadata={"help": "Device type for distributed training."}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -34,7 +34,6 @@ def setup_parallel(
|
|||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None,
|
||||
):
|
||||
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
|
|
@ -45,16 +44,10 @@ def setup_parallel(
|
|||
yield None
|
||||
return
|
||||
|
||||
if device_ids is None:
|
||||
device_ids = [i for i in range(world_size)]
|
||||
|
||||
effective_rank = rank % len(device_ids)
|
||||
device_id = torch.device(device_type, device_ids[effective_rank])
|
||||
rank = device_ids[effective_rank]
|
||||
device_id = torch.device(device_type, rank)
|
||||
|
||||
os.environ["MASTER_ADDR"] = master_addr
|
||||
os.environ["MASTER_PORT"] = master_port
|
||||
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
|
@ -104,7 +97,6 @@ def wrapper_spawn_func(
|
|||
master_addr: str,
|
||||
master_port: str,
|
||||
device_type: str,
|
||||
device_ids: List[int],
|
||||
func: Callable,
|
||||
kwargs: dict,
|
||||
):
|
||||
|
|
@ -116,7 +108,6 @@ def wrapper_spawn_func(
|
|||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
device_type=device_type,
|
||||
device_ids=device_ids,
|
||||
):
|
||||
func(**kwargs)
|
||||
|
||||
|
|
@ -132,7 +123,6 @@ def spawn_parallel_fn(
|
|||
master_addr: str = "localhost",
|
||||
master_port: str = "29500",
|
||||
device_type: str = "cuda",
|
||||
device_ids: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# clear environment variables
|
||||
|
|
@ -148,8 +138,9 @@ def spawn_parallel_fn(
|
|||
del os.environ[key]
|
||||
|
||||
if world_size == 1:
|
||||
device_ids = device_ids or [0]
|
||||
device_id = torch.device(device_type, device_ids[0])
|
||||
device_id = torch.device(device_type, 0)
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["LOCAL_DEVICE"] = str(device_id)
|
||||
|
||||
func(**kwargs)
|
||||
|
|
@ -161,7 +152,6 @@ def spawn_parallel_fn(
|
|||
master_addr,
|
||||
master_port,
|
||||
device_type,
|
||||
device_ids,
|
||||
func,
|
||||
kwargs,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -53,7 +53,6 @@ class Trainer:
|
|||
master_addr=config.master_addr,
|
||||
master_port=config.master_port,
|
||||
device_type=config.device_type,
|
||||
device_ids=config.device_ids,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -155,7 +155,7 @@ def parse_args() -> argparse.Namespace:
|
|||
|
||||
def ddp_wrap(model: nn.Module):
|
||||
local_rank = get_rank()
|
||||
model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16)
|
||||
model = model.to(dtype=torch.bfloat16)
|
||||
ddp_model = DDP(
|
||||
model,
|
||||
device_ids=[local_rank],
|
||||
|
|
|
|||
Loading…
Reference in New Issue