feat: GRPO CLI 接入 + on-policy,OpenAI API top_k 参数化,补充训练参数表
- train.py 新增 --train_type=grpo 及参数 (--grpo_clip_eps, --grpo_kl_coef, --group_size, --grpo_sync_interval, --start_epoch) - GRPOStrategy 统一 on-policy 模式,ratio = exp(logπ_θ - logπ_ref),PPO 裁剪目标,sync_interval 自动同步 ref_model - ChatCompletionRequest 新增 top_k 参数,不再硬编码 - 补充 README 完整训练参数表(含此前缺失的 max_grad_norm / adamw / window_size / stride 等)
This commit is contained in:
parent
34a511e36e
commit
bc7c82977e
22
README.md
22
README.md
|
|
@ -73,18 +73,34 @@ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset
|
||||||
|
|
||||||
| Parameter | Description | Default |
|
| Parameter | Description | Default |
|
||||||
|-----------|-------------|---------|
|
|-----------|-------------|---------|
|
||||||
| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required |
|
| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required |
|
||||||
| `--data_root_path` | Dataset root directory | required |
|
| `--data_root_path` | Dataset root directory | required |
|
||||||
| `--param_path` | Model / checkpoint path | required |
|
| `--param_path` | Model / checkpoint path | required |
|
||||||
| `--n_epoch` | Training epochs | 1 |
|
| `--n_epoch` | Training epochs | 1 |
|
||||||
| `--batch_size` | Batch size | 1 |
|
| `--batch_size` | Batch size | 1 |
|
||||||
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
| `--accumulation_steps` | Gradient accumulation steps | 1 |
|
||||||
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
|
||||||
| `--warmup_steps` | LR warmup steps | 1000 |
|
| `--warmup_steps` | LR warmup steps | 1000 |
|
||||||
|
| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 |
|
||||||
|
| `--max_grad_norm` | Max gradient norm for clipping | 1.0 |
|
||||||
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
|
| `--adamw_weight_decay` | AdamW weight decay | 0.01 |
|
||||||
|
| `--random_seed` | Random seed | 3407 |
|
||||||
|
| `--num_workers` | DataLoader workers | 4 |
|
||||||
|
| `--window_size` | Max input sequence length | auto |
|
||||||
|
| `--stride` | Sequence stride | auto |
|
||||||
|
| `--label_smoothing` | Label smoothing for cross entropy | 0.1 |
|
||||||
|
| `--dpo_beta` | DPO beta | 0.1 |
|
||||||
|
| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 |
|
||||||
|
| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 |
|
||||||
|
| `--group_size` | GRPO group size | 4 |
|
||||||
|
| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 |
|
||||||
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
| `--ckpt_interval` | Checkpoint interval (iters) | 5000 |
|
||||||
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
| `--ckpt_dir` | Checkpoint directory | checkpoint |
|
||||||
| `--num_workers` | DataLoader workers | 4 |
|
| `--start_epoch` | Start epoch (for resume) | 0 |
|
||||||
|
| `--start_batch` | Start batch (for resume) | 0 |
|
||||||
| `--nprocs` | Number of GPUs | 1 |
|
| `--nprocs` | Number of GPUs | 1 |
|
||||||
|
| `--device_type` | Device type | cuda |
|
||||||
|
|
||||||
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters).
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,18 +79,34 @@ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset
|
||||||
|
|
||||||
| 参数 | 说明 | 默认值 |
|
| 参数 | 说明 | 默认值 |
|
||||||
|------|------|--------|
|
|------|------|--------|
|
||||||
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 |
|
| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 |
|
||||||
| `--data_root_path` | 数据集根目录 | 必填 |
|
| `--data_root_path` | 数据集根目录 | 必填 |
|
||||||
| `--param_path` | 模型参数或断点路径 | 必填 |
|
| `--param_path` | 模型参数或断点路径 | 必填 |
|
||||||
| `--n_epoch` | 训练轮数 | 1 |
|
| `--n_epoch` | 训练轮数 | 1 |
|
||||||
| `--batch_size` | 批次大小 | 1 |
|
| `--batch_size` | 批次大小 | 1 |
|
||||||
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
| `--accumulation_steps` | 梯度累积步数 | 1 |
|
||||||
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
|
||||||
| `--warmup_steps` | 预热步数 | 1000 |
|
| `--warmup_steps` | 预热步数 | 1000 |
|
||||||
|
| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 |
|
||||||
|
| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 |
|
||||||
|
| `--adamw_beta1` | AdamW beta1 | 0.9 |
|
||||||
|
| `--adamw_beta2` | AdamW beta2 | 0.95 |
|
||||||
|
| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 |
|
||||||
|
| `--random_seed` | 随机种子 | 3407 |
|
||||||
|
| `--num_workers` | 数据加载线程数 | 4 |
|
||||||
|
| `--window_size` | 最大输入序列长度 | auto |
|
||||||
|
| `--stride` | 序列步长 | auto |
|
||||||
|
| `--label_smoothing` | 交叉熵标签平滑 | 0.1 |
|
||||||
|
| `--dpo_beta` | DPO beta | 0.1 |
|
||||||
|
| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 |
|
||||||
|
| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 |
|
||||||
|
| `--group_size` | GRPO 组大小 | 4 |
|
||||||
|
| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 |
|
||||||
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
| `--ckpt_interval` | 检查点间隔(迭代步) | 5000 |
|
||||||
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
| `--ckpt_dir` | 检查点保存目录 | checkpoint |
|
||||||
| `--num_workers` | 数据加载线程数 | 4 |
|
| `--start_epoch` | 起始轮次(用于断点续训) | 0 |
|
||||||
|
| `--start_batch` | 起始批次(用于断点续训) | 0 |
|
||||||
| `--nprocs` | GPU 数量 | 1 |
|
| `--nprocs` | GPU 数量 | 1 |
|
||||||
|
| `--device_type` | 设备类型 | cuda |
|
||||||
|
|
||||||
完整参数列表见[参数说明](./params.md#training-parameters)。
|
完整参数列表见[参数说明](./params.md#training-parameters)。
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,7 @@ class ChatCompletionRequest(BaseModel):
|
||||||
messages: List[ChatMessage]
|
messages: List[ChatMessage]
|
||||||
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0)
|
||||||
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0)
|
||||||
|
top_k: Optional[int] = Field(default=50, ge=1)
|
||||||
stream: Optional[bool] = False
|
stream: Optional[bool] = False
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
max_tokens: Optional[int] = Field(default=2048, ge=1)
|
||||||
|
|
@ -204,7 +205,7 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=50,
|
top_k=request.top_k,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def event_stream():
|
async def event_stream():
|
||||||
|
|
@ -256,7 +257,7 @@ async def chat_completion(request: ChatCompletionRequest):
|
||||||
max_tokens=request.max_tokens,
|
max_tokens=request.max_tokens,
|
||||||
temperature=request.temperature,
|
temperature=request.temperature,
|
||||||
top_p=request.top_p,
|
top_p=request.top_p,
|
||||||
top_k=50,
|
top_k=request.top_k,
|
||||||
)
|
)
|
||||||
async for token in agen:
|
async for token in agen:
|
||||||
chunks.append(token)
|
chunks.append(token)
|
||||||
|
|
|
||||||
|
|
@ -265,7 +265,9 @@ class DPOStrategy(BaseStrategy):
|
||||||
class GRPOStrategy(BaseStrategy):
|
class GRPOStrategy(BaseStrategy):
|
||||||
"""Group Relative Policy Optimization strategy.
|
"""Group Relative Policy Optimization strategy.
|
||||||
|
|
||||||
Implements GRPO with clipping and KL penalty.
|
On-policy GRPO following DeepSeek-R1: the policy model is updated while
|
||||||
|
a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref),
|
||||||
|
clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -276,6 +278,7 @@ class GRPOStrategy(BaseStrategy):
|
||||||
kl_coef: float = 0.01,
|
kl_coef: float = 0.01,
|
||||||
group_size: int = 4,
|
group_size: int = 4,
|
||||||
reduction: str = "mean",
|
reduction: str = "mean",
|
||||||
|
sync_interval: int = 200,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(model, device, **kwargs)
|
super().__init__(model, device, **kwargs)
|
||||||
|
|
@ -284,8 +287,19 @@ class GRPOStrategy(BaseStrategy):
|
||||||
self.kl_coef = kl_coef
|
self.kl_coef = kl_coef
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.reduction = reduction
|
self.reduction = reduction
|
||||||
|
self.sync_interval = sync_interval
|
||||||
|
self._step = 0
|
||||||
|
|
||||||
|
def sync_ref_model(self):
|
||||||
|
"""Copy current model weights to ref model."""
|
||||||
|
ref_state = self.model.state_dict()
|
||||||
|
self.ref_model.load_state_dict(ref_state)
|
||||||
|
|
||||||
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor:
|
||||||
|
self._step += 1
|
||||||
|
if self._step % self.sync_interval == 0:
|
||||||
|
self.sync_ref_model()
|
||||||
|
|
||||||
batch = move_to_device(batch, self.device)
|
batch = move_to_device(batch, self.device)
|
||||||
prompts = batch["prompts"]
|
prompts = batch["prompts"]
|
||||||
responses = batch["responses"]
|
responses = batch["responses"]
|
||||||
|
|
@ -297,7 +311,6 @@ class GRPOStrategy(BaseStrategy):
|
||||||
masks_flat = masks.view(-1, response_len)
|
masks_flat = masks.view(-1, response_len)
|
||||||
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1)
|
||||||
|
|
||||||
# Shape: (batch_size * group_size, seq_len + response_len)
|
|
||||||
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1)
|
||||||
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1)
|
||||||
|
|
||||||
|
|
@ -312,14 +325,13 @@ class GRPOStrategy(BaseStrategy):
|
||||||
)
|
)
|
||||||
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
log_probs_ref = log_probs_ref.view(batch_size, group_size)
|
||||||
|
|
||||||
# Compute advantages from rewards with normalization
|
|
||||||
eps = torch.finfo(log_probs_policy.dtype).eps
|
eps = torch.finfo(log_probs_policy.dtype).eps
|
||||||
mean = rewards.mean(dim=-1, keepdim=True)
|
mean = rewards.mean(dim=-1, keepdim=True)
|
||||||
std = rewards.std(dim=-1, keepdim=True)
|
std = rewards.std(dim=-1, keepdim=True)
|
||||||
advantages = (rewards - mean) / (std + eps)
|
advantages = (rewards - mean) / (std + eps)
|
||||||
|
|
||||||
# PPO-style clipped surrogate objective
|
ratio = torch.exp(log_probs_policy - log_probs_ref)
|
||||||
ratio = torch.exp(0) # Off-policy: policy_model = old_model
|
|
||||||
surr1 = ratio * advantages
|
surr1 = ratio * advantages
|
||||||
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--train_type",
|
"--train_type",
|
||||||
type=str,
|
type=str,
|
||||||
required=True,
|
required=True,
|
||||||
choices=["seq", "sft", "dpo"],
|
choices=["seq", "sft", "dpo", "grpo"],
|
||||||
help="Train type.",
|
help="Train type.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -42,9 +42,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
"--n_epoch", type=int, default=1, help="Number of epochs to train."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||||
"--batch_size", type=int, default=1, help="Batch size for training."
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--accumulation_steps",
|
"--accumulation_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|
@ -106,6 +104,17 @@ def parse_args() -> argparse.Namespace:
|
||||||
"--stride", type=int, default=None, help="the step size of the input sequence."
|
"--stride", type=int, default=None, help="the step size of the input sequence."
|
||||||
)
|
)
|
||||||
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.")
|
||||||
|
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--on_policy",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Enable on-policy GRPO mode.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient."
|
||||||
|
)
|
||||||
|
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
|
|
@ -125,6 +134,13 @@ def parse_args() -> argparse.Namespace:
|
||||||
default="checkpoint",
|
default="checkpoint",
|
||||||
help="Directory to save checkpoints.",
|
help="Directory to save checkpoints.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--grpo_sync_interval",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="GRPO ref model sync interval (steps).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
"--start_epoch", type=int, default=0, help="Start epoch for training."
|
||||||
)
|
)
|
||||||
|
|
@ -182,6 +198,10 @@ def train(
|
||||||
ckpt_interval: int,
|
ckpt_interval: int,
|
||||||
ckpt_dir: str,
|
ckpt_dir: str,
|
||||||
dpo_beta: float,
|
dpo_beta: float,
|
||||||
|
grpo_clip_eps: float,
|
||||||
|
grpo_kl_coef: float,
|
||||||
|
group_size: int,
|
||||||
|
grpo_sync_interval: int,
|
||||||
adamw_beta1: float,
|
adamw_beta1: float,
|
||||||
adamw_beta2: float,
|
adamw_beta2: float,
|
||||||
adamw_weight_decay: float,
|
adamw_weight_decay: float,
|
||||||
|
|
@ -195,7 +215,7 @@ def train(
|
||||||
nprocs: int,
|
nprocs: int,
|
||||||
device_type: str,
|
device_type: str,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
||||||
# Load config
|
# Load config
|
||||||
|
|
@ -216,7 +236,14 @@ def train(
|
||||||
state_dict = st.load_file(weights_path)
|
state_dict = st.load_file(weights_path)
|
||||||
model.load_state_dict(state_dict, strict=False)
|
model.load_state_dict(state_dict, strict=False)
|
||||||
|
|
||||||
strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing}
|
strategy_kwargs = {
|
||||||
|
"dpo_beta": dpo_beta,
|
||||||
|
"label_smoothing": label_smoothing,
|
||||||
|
"clip_eps": grpo_clip_eps,
|
||||||
|
"kl_coef": grpo_kl_coef,
|
||||||
|
"group_size": group_size,
|
||||||
|
"sync_interval": grpo_sync_interval,
|
||||||
|
}
|
||||||
|
|
||||||
dataset = DatasetFactory.load(
|
dataset = DatasetFactory.load(
|
||||||
train_type=train_type,
|
train_type=train_type,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue