From bc7c82977ee6725e98f909aa9260540fe811289d Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 9 May 2026 12:22:33 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20GRPO=20CLI=20=E6=8E=A5=E5=85=A5=20+=20o?= =?UTF-8?q?n-policy=EF=BC=8COpenAI=20API=20top=5Fk=20=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E5=8C=96=EF=BC=8C=E8=A1=A5=E5=85=85=E8=AE=AD=E7=BB=83=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - train.py 新增 --train_type=grpo 及参数 (--grpo_clip_eps, --grpo_kl_coef, --group_size, --grpo_sync_interval, --start_epoch) - GRPOStrategy 统一 on-policy 模式,ratio = exp(logπ_θ - logπ_ref),PPO 裁剪目标,sync_interval 自动同步 ref_model - ChatCompletionRequest 新增 top_k 参数,不再硬编码 - 补充 README 完整训练参数表(含此前缺失的 max_grad_norm / adamw / window_size / stride 等) --- README.md | 22 ++++++++++++++++++--- assets/docs/README-zh-CN.md | 22 ++++++++++++++++++--- astrai/inference/server.py | 5 +++-- astrai/trainer/strategy.py | 22 ++++++++++++++++----- scripts/tools/train.py | 39 +++++++++++++++++++++++++++++++------ 5 files changed, 91 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 295b479..388a8ef 100644 --- a/README.md +++ b/README.md @@ -73,18 +73,34 @@ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset | Parameter | Description | Default | |-----------|-------------|---------| -| `--train_type` | Training type (`seq`, `sft`, `dpo`) | required | +| `--train_type` | Training type (`seq`, `sft`, `dpo`, `grpo`) | required | | `--data_root_path` | Dataset root directory | required | | `--param_path` | Model / checkpoint path | required | | `--n_epoch` | Training epochs | 1 | | `--batch_size` | Batch size | 1 | | `--accumulation_steps` | Gradient accumulation steps | 1 | -| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 | | `--warmup_steps` | LR warmup steps | 1000 | +| `--max_lr` | Peak learning rate (cosine decay) | 3e-4 | +| `--max_grad_norm` | Max gradient norm for clipping | 1.0 | +| `--adamw_beta1` | AdamW beta1 | 0.9 | +| `--adamw_beta2` | AdamW beta2 | 0.95 | +| `--adamw_weight_decay` | AdamW weight decay | 0.01 | +| `--random_seed` | Random seed | 3407 | +| `--num_workers` | DataLoader workers | 4 | +| `--window_size` | Max input sequence length | auto | +| `--stride` | Sequence stride | auto | +| `--label_smoothing` | Label smoothing for cross entropy | 0.1 | +| `--dpo_beta` | DPO beta | 0.1 | +| `--grpo_clip_eps` | GRPO clip epsilon | 0.2 | +| `--grpo_kl_coef` | GRPO KL penalty coefficient | 0.01 | +| `--group_size` | GRPO group size | 4 | +| `--grpo_sync_interval` | GRPO ref model sync interval (steps) | 200 | | `--ckpt_interval` | Checkpoint interval (iters) | 5000 | | `--ckpt_dir` | Checkpoint directory | checkpoint | -| `--num_workers` | DataLoader workers | 4 | +| `--start_epoch` | Start epoch (for resume) | 0 | +| `--start_batch` | Start batch (for resume) | 0 | | `--nprocs` | Number of GPUs | 1 | +| `--device_type` | Device type | cuda | Full reference at [Parameter Guide](./assets/docs/params.md#training-parameters). diff --git a/assets/docs/README-zh-CN.md b/assets/docs/README-zh-CN.md index c0256f4..d9298bf 100644 --- a/assets/docs/README-zh-CN.md +++ b/assets/docs/README-zh-CN.md @@ -79,18 +79,34 @@ python scripts/tools/train.py --train_type=seq --data_root_path=/path/to/dataset | 参数 | 说明 | 默认值 | |------|------|--------| -| `--train_type` | 训练类型(`seq`, `sft`, `dpo`) | 必填 | +| `--train_type` | 训练类型(`seq`, `sft`, `dpo`, `grpo`) | 必填 | | `--data_root_path` | 数据集根目录 | 必填 | | `--param_path` | 模型参数或断点路径 | 必填 | | `--n_epoch` | 训练轮数 | 1 | | `--batch_size` | 批次大小 | 1 | | `--accumulation_steps` | 梯度累积步数 | 1 | -| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 | | `--warmup_steps` | 预热步数 | 1000 | +| `--max_lr` | 峰值学习率(余弦衰减) | 3e-4 | +| `--max_grad_norm` | 梯度裁剪最大值 | 1.0 | +| `--adamw_beta1` | AdamW beta1 | 0.9 | +| `--adamw_beta2` | AdamW beta2 | 0.95 | +| `--adamw_weight_decay` | AdamW 权重衰减 | 0.01 | +| `--random_seed` | 随机种子 | 3407 | +| `--num_workers` | 数据加载线程数 | 4 | +| `--window_size` | 最大输入序列长度 | auto | +| `--stride` | 序列步长 | auto | +| `--label_smoothing` | 交叉熵标签平滑 | 0.1 | +| `--dpo_beta` | DPO beta | 0.1 | +| `--grpo_clip_eps` | GRPO 裁剪 epsilon | 0.2 | +| `--grpo_kl_coef` | GRPO KL 惩罚系数 | 0.01 | +| `--group_size` | GRPO 组大小 | 4 | +| `--grpo_sync_interval` | GRPO ref_model 同步间隔(步) | 200 | | `--ckpt_interval` | 检查点间隔(迭代步) | 5000 | | `--ckpt_dir` | 检查点保存目录 | checkpoint | -| `--num_workers` | 数据加载线程数 | 4 | +| `--start_epoch` | 起始轮次(用于断点续训) | 0 | +| `--start_batch` | 起始批次(用于断点续训) | 0 | | `--nprocs` | GPU 数量 | 1 | +| `--device_type` | 设备类型 | cuda | 完整参数列表见[参数说明](./params.md#training-parameters)。 diff --git a/astrai/inference/server.py b/astrai/inference/server.py index e1e7d37..7216eaa 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -51,6 +51,7 @@ class ChatCompletionRequest(BaseModel): messages: List[ChatMessage] temperature: Optional[float] = Field(default=1.0, ge=0.0, le=2.0) top_p: Optional[float] = Field(default=1.0, ge=0.0, le=1.0) + top_k: Optional[int] = Field(default=50, ge=1) stream: Optional[bool] = False stop: Optional[Union[str, List[str]]] = None max_tokens: Optional[int] = Field(default=2048, ge=1) @@ -204,7 +205,7 @@ async def chat_completion(request: ChatCompletionRequest): max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, - top_k=50, + top_k=request.top_k, ) async def event_stream(): @@ -256,7 +257,7 @@ async def chat_completion(request: ChatCompletionRequest): max_tokens=request.max_tokens, temperature=request.temperature, top_p=request.top_p, - top_k=50, + top_k=request.top_k, ) async for token in agen: chunks.append(token) diff --git a/astrai/trainer/strategy.py b/astrai/trainer/strategy.py index eb5a6e5..3789e11 100644 --- a/astrai/trainer/strategy.py +++ b/astrai/trainer/strategy.py @@ -265,7 +265,9 @@ class DPOStrategy(BaseStrategy): class GRPOStrategy(BaseStrategy): """Group Relative Policy Optimization strategy. - Implements GRPO with clipping and KL penalty. + On-policy GRPO following DeepSeek-R1: the policy model is updated while + a frozen ref_model stores the old-policy log-probs. ratio = exp(logπ_θ - logπ_ref), + clipped PPO objective. Call ``sync_ref_model()`` after each data-generation round. """ def __init__( @@ -276,6 +278,7 @@ class GRPOStrategy(BaseStrategy): kl_coef: float = 0.01, group_size: int = 4, reduction: str = "mean", + sync_interval: int = 200, **kwargs, ): super().__init__(model, device, **kwargs) @@ -284,8 +287,19 @@ class GRPOStrategy(BaseStrategy): self.kl_coef = kl_coef self.group_size = group_size self.reduction = reduction + self.sync_interval = sync_interval + self._step = 0 + + def sync_ref_model(self): + """Copy current model weights to ref model.""" + ref_state = self.model.state_dict() + self.ref_model.load_state_dict(ref_state) def compute_loss(self, batch: Dict[str, Tensor]) -> Tensor: + self._step += 1 + if self._step % self.sync_interval == 0: + self.sync_ref_model() + batch = move_to_device(batch, self.device) prompts = batch["prompts"] responses = batch["responses"] @@ -297,7 +311,6 @@ class GRPOStrategy(BaseStrategy): masks_flat = masks.view(-1, response_len) prompt_expanded = prompts.unsqueeze(1).repeat(1, group_size, 1).flatten(0, 1) - # Shape: (batch_size * group_size, seq_len + response_len) full_sequences = torch.cat([prompt_expanded, responses_flat], dim=-1) full_masks = torch.cat([torch.ones_like(prompt_expanded), masks_flat], dim=-1) @@ -312,14 +325,13 @@ class GRPOStrategy(BaseStrategy): ) log_probs_ref = log_probs_ref.view(batch_size, group_size) - # Compute advantages from rewards with normalization eps = torch.finfo(log_probs_policy.dtype).eps mean = rewards.mean(dim=-1, keepdim=True) std = rewards.std(dim=-1, keepdim=True) advantages = (rewards - mean) / (std + eps) - # PPO-style clipped surrogate objective - ratio = torch.exp(0) # Off-policy: policy_model = old_model + ratio = torch.exp(log_probs_policy - log_probs_ref) + surr1 = ratio * advantages surr2 = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages diff --git a/scripts/tools/train.py b/scripts/tools/train.py index eb43382..5268c02 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -23,7 +23,7 @@ def parse_args() -> argparse.Namespace: "--train_type", type=str, required=True, - choices=["seq", "sft", "dpo"], + choices=["seq", "sft", "dpo", "grpo"], help="Train type.", ) parser.add_argument( @@ -42,9 +42,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--n_epoch", type=int, default=1, help="Number of epochs to train." ) - parser.add_argument( - "--batch_size", type=int, default=1, help="Batch size for training." - ) + parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument( "--accumulation_steps", type=int, @@ -106,6 +104,17 @@ def parse_args() -> argparse.Namespace: "--stride", type=int, default=None, help="the step size of the input sequence." ) parser.add_argument("--dpo_beta", type=float, default=0.1, help="DPO beta value.") + parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") + parser.add_argument( + "--on_policy", + action="store_true", + default=False, + help="Enable on-policy GRPO mode.", + ) + parser.add_argument( + "--grpo_kl_coef", type=float, default=0.01, help="GRPO KL penalty coefficient." + ) + parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") parser.add_argument( "--label_smoothing", type=float, @@ -125,6 +134,13 @@ def parse_args() -> argparse.Namespace: default="checkpoint", help="Directory to save checkpoints.", ) + parser.add_argument("--group_size", type=int, default=4, help="GRPO group size.") + parser.add_argument( + "--grpo_sync_interval", + type=int, + default=200, + help="GRPO ref model sync interval (steps).", + ) parser.add_argument( "--start_epoch", type=int, default=0, help="Start epoch for training." ) @@ -182,6 +198,10 @@ def train( ckpt_interval: int, ckpt_dir: str, dpo_beta: float, + grpo_clip_eps: float, + grpo_kl_coef: float, + group_size: int, + grpo_sync_interval: int, adamw_beta1: float, adamw_beta2: float, adamw_weight_decay: float, @@ -195,7 +215,7 @@ def train( nprocs: int, device_type: str, ): - assert train_type in ["seq", "sft", "dpo"] + assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) # Load config @@ -216,7 +236,14 @@ def train( state_dict = st.load_file(weights_path) model.load_state_dict(state_dict, strict=False) - strategy_kwargs = {"dpo_beta": dpo_beta, "label_smoothing": label_smoothing} + strategy_kwargs = { + "dpo_beta": dpo_beta, + "label_smoothing": label_smoothing, + "clip_eps": grpo_clip_eps, + "kl_coef": grpo_kl_coef, + "group_size": group_size, + "sync_interval": grpo_sync_interval, + } dataset = DatasetFactory.load( train_type=train_type,