feat: NEFTune 噪声注入 + label_smoothing 默认值修正

- Embedding.forward 训练时注入 randn 噪声,缩放系数 neftune_noise_alpha / sqrt(seq_len)
- TrainConfig.neftune_alpha 通过 config 传递(默认 0=关闭)
- TrainContextBuilder 将 config.neftune_alpha 写入 embed_tokens
- --neftune_alpha CLI 参数(典型值 5.0)
- label_smoothing 默认值 0.05 -> 0.0
This commit is contained in:
ViperEkura 2026-06-11 15:31:08 +08:00
parent 6ae1828449
commit 445378667f
4 changed files with 22 additions and 2 deletions

View File

@ -128,6 +128,10 @@ class TrainConfig(BaseConfig):
default=1000,
metadata={"help": "Number of optimizer steps between validation runs."},
)
neftune_alpha: float = field(
default=0.0,
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
)
executor_kwargs: dict = field(
default_factory=dict,

View File

@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -8,9 +10,14 @@ class Embedding(nn.Module):
def __init__(self, vocab_size: int, embedding_dim: int):
    """Allocate the embedding table and start with noise injection disabled.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each embedding vector.
    """
    super().__init__()
    table_shape = (vocab_size, embedding_dim)
    # torch.empty leaves the values uninitialized; this constructor does not
    # call reset_parameters itself.
    self.weight = nn.Parameter(torch.empty(table_shape))
    # forward only adds noise when this is > 0.0, so 0.0 means "off".
    self.neftune_noise_alpha = 0.0
def reset_parameters(self):
    """Re-draw every embedding weight in place from N(mean=0.0, std=0.02)."""
    # In-place sampling on a parameter must not be recorded by autograd.
    with torch.no_grad():
        self.weight.normal_(mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
    """Look up token embeddings, adding NEFTune noise while training.

    Noise is only injected when the module is in training mode and
    ``neftune_noise_alpha`` is positive; otherwise the plain embedding
    lookup is returned unchanged.

    Args:
        x: Integer token ids; the last axis is treated as the sequence axis.

    Returns:
        Embeddings of shape ``x.shape + (embedding_dim,)``.
    """
    out = F.embedding(x, self.weight)
    if self.training and self.neftune_noise_alpha > 0.0:
        # NEFTune draws Uniform(-1, 1) noise scaled by alpha / sqrt(L * d),
        # where L is the sequence length and d the embedding width.
        seq_len = out.size(-2) if out.dim() >= 2 else 1
        # max(..., 1) keeps an empty sequence from dividing by zero.
        scale_dims = max(seq_len * out.size(-1), 1)
        magnitude = self.neftune_noise_alpha / math.sqrt(scale_dims)
        out = out + torch.empty_like(out).uniform_(-magnitude, magnitude)
    return out

View File

@ -63,6 +63,7 @@ class TrainContextBuilder:
model = cfg.model_fn()
model = model.to(device=device)
model.embed_tokens.neftune_noise_alpha = cfg.neftune_alpha
model_config = {}
if self._resume_dir:

View File

@ -113,7 +113,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--label_smoothing",
type=float,
default=0.05,
default=0.0,
help="cross_entropy function label smoothing parameter",
)
parser.add_argument(
@ -214,6 +214,12 @@ def parse_args() -> argparse.Namespace:
choices=["spawn", "fork", "forkserver"],
help="Multiprocessing start method.",
)
parser.add_argument(
"--neftune_alpha",
type=float,
default=0.0,
help="NEFTune noise alpha (0=disabled, typical: 5.0).",
)
args = parser.parse_args()
@ -293,6 +299,7 @@ def train(
master_addr: str,
master_port: str,
start_method: str,
neftune_alpha: float,
):
assert train_type in ["seq", "sft", "dpo", "grpo"]
assert os.path.exists(param_path)
@ -385,6 +392,7 @@ def train(
gradient_checkpointing_modules=grad_ckpt_modules,
executor_kwargs=executor_kwargs,
extra_kwargs=strategy_kwargs,
neftune_alpha=neftune_alpha,
)
trainer = Trainer(train_config)