From 445378667f5a9551599bfba143846b2d3a9c0075 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Thu, 11 Jun 2026 15:31:08 +0800 Subject: [PATCH] =?UTF-8?q?feat=20:=20NEFTune=20=E5=99=AA=E5=A3=B0?= =?UTF-8?q?=E6=B3=A8=E5=85=A5=20+=20label=5Fsmoothing=20=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E5=80=BC=E4=BF=AE=E6=AD=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Embedding.forward 训练时注入 randn 噪声,缩放系数 neftune_noise_alpha / sqrt(seq_len) - TrainConfig.neftune_alpha 通过 config 传递(默认 0=关闭) - TrainContextBuilder 将 config.neftune_alpha 写入 embed_tokens - --neftune_alpha CLI 参数(典型值 5.0) - label_smoothing 默认值 0.05 -> 0.0 --- astrai/config/train_config.py | 4 ++++ astrai/model/components/embedding.py | 9 ++++++++- astrai/trainer/train_context.py | 1 + scripts/tools/train.py | 10 +++++++++- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 91ea3c0..1421219 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -128,6 +128,10 @@ class TrainConfig(BaseConfig): default=1000, metadata={"help": "Number of optimizer steps between validation runs."}, ) + neftune_alpha: float = field( + default=0.0, + metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."}, + ) executor_kwargs: dict = field( default_factory=dict, diff --git a/astrai/model/components/embedding.py b/astrai/model/components/embedding.py index 3f03796..f8f4551 100644 --- a/astrai/model/components/embedding.py +++ b/astrai/model/components/embedding.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,9 +10,14 @@ class Embedding(nn.Module): def __init__(self, vocab_size: int, embedding_dim: int): super().__init__() self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) + self.neftune_noise_alpha = 0.0 def reset_parameters(self): nn.init.normal_(self.weight, mean=0.0, std=0.02) def forward(self, x: Tensor) -> Tensor: - return F.embedding(x, self.weight) + out = F.embedding(x, self.weight) + if self.training and self.neftune_noise_alpha > 0.0: + eps = self.neftune_noise_alpha / math.sqrt(out.size(1)) + out = out + eps * torch.randn_like(out) + return out diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 6af33a9..93d2ea5 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -63,6 +63,7 @@ class TrainContextBuilder: model = cfg.model_fn() model = model.to(device=device) + model.embed_tokens.neftune_noise_alpha = cfg.neftune_alpha model_config = {} if self._resume_dir: diff --git a/scripts/tools/train.py b/scripts/tools/train.py index d2c8eec..e845596 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -113,7 +113,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--label_smoothing", type=float, - default=0.05, + default=0.0, help="cross_entropy function label smoothing parameter", ) parser.add_argument( @@ -214,6 +214,12 @@ def parse_args() -> argparse.Namespace: choices=["spawn", "fork", "forkserver"], help="Multiprocessing start method.", ) + parser.add_argument( + "--neftune_alpha", + type=float, + default=0.0, + help="NEFTune noise alpha (0=disabled, typical: 5.0).", + ) args = parser.parse_args() @@ -293,6 +299,7 @@ def train( master_addr: str, master_port: str, start_method: str, + neftune_alpha: float, ): assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) @@ -385,6 +392,7 @@ def train( gradient_checkpointing_modules=grad_ckpt_modules, executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs, + neftune_alpha=neftune_alpha, ) trainer = Trainer(train_config)