diff --git a/README.md b/README.md index 388a8ef..988adca 100644 --- a/README.md +++ b/README.md @@ -68,41 +68,18 @@ pip install -e ".[dev]" #### Train a Model ```bash -python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model +CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \ + --train_type seq \ + --data_root_path /path/to/dataset \ + --param_path /path/to/model \ + --batch_size 4 \ + --accumulation_steps 8 \ + --max_lr 3e-4 \ + --warmup_steps 1000 \ + --n_epoch 1 ``` -| Parameter | Description | Default | -|-----------|-------------|---------| -| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required | -| `--data_root_path` | Dataset root directory | required | -| `--param_path` | Model / checkpoint path | required | -| `--n_epoch` | Training epochs | 1 | -| `--batch_size` | Batch size | 1 | -| `--accumulation_steps` | Gradient accumulation steps | 1 | -| `--warmup_steps` | LR warmup steps | 1000 | -| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 | -| `--max_grad_norm` | Max gradient norm for clipping | 1.0 | -| `--adamw_beta1` | AdamW beta1 | 0.9 | -| `--adamw_beta2` | AdamW beta2 | 0.95 | -| `--adamw_weight_decay` | AdamW weight decay | 0.01 | -| `--random_seed` | Random seed | 3407 | -| `--num_workers` | DataLoader workers | 4 | -| `--window_size` | Max input sequence length | auto | -| `--stride` | Sequence stride | auto | -| `--label_smoothing` | Label smoothing for cross entropy | 0.1 | -| `--dpo_beta` | DPO beta | 0.1 | -| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 | -| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | -| `--group_size` | GRPO group size | 4 | -| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 | -| `--ckpt_interval` | Checkpoint interval (iters) | 5000 | -| `--ckpt_dir` | Checkpoint directory | checkpoint | -| `--start_epoch` | Start epoch (for resume) | 0 | -| `--start_batch` | Start batch (for resume) | 0 | -| `--nprocs` | Number of GPUs | 1 | -| `--device_type` | Device type | cuda | - -Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters). +Full reference at [Parameter Guide](assets/docs/params.md). #### Generate Text diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index d9298bf..f2683b2 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -74,41 +74,18 @@ pip install -e ".[dev]" #### 训练模型 ```bash -python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset --param_path=/path/to/model +CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/tools/train.py \ + --train_type seq \ + --data_root_path /path/to/dataset \ + --param_path /path/to/model \ + --batch_size 4 \ + --accumulation_steps 8 \ + --max_lr 3e-4 \ + --warmup_steps 1000 \ + --n_epoch 1 ``` -| 参数 | 说明 | 默认值 | -|------|------|--------| -| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 | -| `--data_root_path` | 数据集根目录 | 必填 | -| `--param_path` | 模型参数或断点路径 | 必填 | -| `--n_epoch` | 训练轮数 | 1 | -| `--batch_size` | 批次大小 | 1 | -| `--accumulation_steps` | 梯度累积步数 | 1 | -| `--warmup_steps` | 预热步数 | 1000 | -| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 | -| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 | -| `--adamw_beta1` | AdamW beta1 | 0.9 | -| `--adamw_beta2` | AdamW beta2 | 0.95 | -| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 | -| `--random_seed` | 随机种子 | 3407 | -| `--num_workers` | 数据加载线程数 | 4 | -| `--window_size` | 最大输入序列长度 | auto | -| `--stride` | 序列步长 | auto | -| `--label_smoothing` | 交叉熵标签平滑 | 0.1 | -| `--dpo_beta` | DPO beta | 0.1 | -| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 | -| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 | -| `--group_size` | GRPO 组大小 | 4 | -| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 | -| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 | -| `--ckpt_dir` | 检查点保存目录 | checkpoint | -| `--start_epoch` | 起始轮次(用于断点续训) | 0 | -| `--start_batch` | 起始批次(用于断点续训) | 0 | -| `--nprocs` | GPU 数量 | 1 | -| `--device_type` | 设备类型 | cuda | - -完整参数列表见[参数说明](./params.md#training-parameters)。 +完整参数列表见[参数说明](./params.md)。 #### 文本生成 diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 8cb9dc3..a41a23a 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Callable, List, Optional +from typing import Callable, Optional import torch.nn as nn from torch.optim import Optimizer @@ -74,9 +74,6 @@ class TrainConfig: ) # others - device_ids: Optional[List[int]] = field( - default=None, metadata={"help": "Device ids for distributed training."} - ) device_type: str = field( default="cuda", metadata={"help": "Device type for distributed training."} ) diff --git a/astrai/parallel/setup.py b/astrai/parallel/setup.py index b00a3ed..3128a67 100644 --- a/astrai/parallel/setup.py +++ b/astrai/parallel/setup.py @@ -1,7 +1,7 @@ import os from contextlib import contextmanager from functools import wraps -from typing import Callable, List, Optional +from typing import Callable import torch import torch.distributed as dist @@ -34,7 +34,6 @@ def setup_parallel( master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", - device_ids: Optional[List[int]] = None, ): if dist.is_available() and dist.is_initialized(): @@ -45,16 +44,10 @@ def setup_parallel( yield None return - if device_ids is None: - device_ids = [i for i in range(world_size)] - - effective_rank = rank % len(device_ids) - device_id = torch.device(device_type, device_ids[effective_rank]) - rank = device_ids[effective_rank] + device_id = torch.device(device_type, rank) os.environ["MASTER_ADDR"] = master_addr os.environ["MASTER_PORT"] = master_port - os.environ["LOCAL_RANK"] = str(rank) os.environ["WORLD_SIZE"] = str(world_size) os.environ["LOCAL_DEVICE"] = str(device_id) @@ -104,7 +97,6 @@ def wrapper_spawn_func( master_addr: str, master_port: str, device_type: str, - device_ids: List[int], func: Callable, kwargs: dict, ): @@ -116,7 +108,6 @@ def wrapper_spawn_func( master_addr=master_addr, master_port=master_port, device_type=device_type, - device_ids=device_ids, ): func(**kwargs) @@ -132,7 +123,6 @@ def spawn_parallel_fn( master_addr: str = "localhost", master_port: str = "29500", device_type: str = "cuda", - device_ids: Optional[List[int]] = None, **kwargs, ): # clear environment variables @@ -148,8 +138,9 @@ def spawn_parallel_fn( del os.environ[key] if world_size == 1: - device_ids = device_ids or [0] - device_id = torch.device(device_type, device_ids[0]) + device_id = torch.device(device_type, 0) + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = "1" os.environ["LOCAL_DEVICE"] = str(device_id) func(**kwargs) @@ -161,7 +152,6 @@ def spawn_parallel_fn( master_addr, master_port, device_type, - device_ids, func, kwargs, ) diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 8b688f3..b138e21 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -53,7 +53,6 @@ class Trainer: master_addr=config.master_addr, master_port=config.master_port, device_type=config.device_type, - device_ids=config.device_ids, checkpoint=checkpoint, ) diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 07b9452..bdc4067 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -155,7 +155,7 @@ def parse_args() -> argparse.Namespace: def ddp_wrap(model: nn.Module): local_rank = get_rank() - model = model.to(device=f"cuda:{local_rank}", dtype=torch.bfloat16) + model = model.to(dtype=torch.bfloat16) ddp_model = DDP( model, device_ids=[local_rank],