feat : NEFTune 噪声注入 + label_smoothing 默认值修正
- Embedding.forward 训练时注入 randn 噪声,缩放系数 neftune_noise_alpha / sqrt(seq_len) - TrainConfig.neftune_alpha 通过 config 传递(默认 0=关闭) - TrainContextBuilder 将 config.neftune_alpha 写入 embed_tokens - --neftune_alpha CLI 参数(典型值 5.0) - label_smoothing 默认值 0.05 -> 0.0
This commit is contained in:
parent
6ae1828449
commit
445378667f
|
|
@ -128,6 +128,10 @@ class TrainConfig(BaseConfig):
|
||||||
default=1000,
|
default=1000,
|
||||||
metadata={"help": "Number of optimizer steps between validation runs."},
|
metadata={"help": "Number of optimizer steps between validation runs."},
|
||||||
)
|
)
|
||||||
|
neftune_alpha: float = field(
|
||||||
|
default=0.0,
|
||||||
|
metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."},
|
||||||
|
)
|
||||||
|
|
||||||
executor_kwargs: dict = field(
|
executor_kwargs: dict = field(
|
||||||
default_factory=dict,
|
default_factory=dict,
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
import math
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
@ -8,9 +10,14 @@ class Embedding(nn.Module):
|
||||||
def __init__(self, vocab_size: int, embedding_dim: int):
|
def __init__(self, vocab_size: int, embedding_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||||
|
self.neftune_noise_alpha = 0.0
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
||||||
|
|
||||||
def forward(self, x: Tensor) -> Tensor:
|
def forward(self, x: Tensor) -> Tensor:
|
||||||
return F.embedding(x, self.weight)
|
out = F.embedding(x, self.weight)
|
||||||
|
if self.training and self.neftune_noise_alpha > 0.0:
|
||||||
|
eps = self.neftune_noise_alpha / math.sqrt(out.size(1))
|
||||||
|
out = out + eps * torch.randn_like(out)
|
||||||
|
return out
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,7 @@ class TrainContextBuilder:
|
||||||
|
|
||||||
model = cfg.model_fn()
|
model = cfg.model_fn()
|
||||||
model = model.to(device=device)
|
model = model.to(device=device)
|
||||||
|
model.embed_tokens.neftune_noise_alpha = cfg.neftune_alpha
|
||||||
|
|
||||||
model_config = {}
|
model_config = {}
|
||||||
if self._resume_dir:
|
if self._resume_dir:
|
||||||
|
|
|
||||||
|
|
@ -113,7 +113,7 @@ def parse_args() -> argparse.Namespace:
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--label_smoothing",
|
"--label_smoothing",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.05,
|
default=0.0,
|
||||||
help="cross_entropy function label smoothing parameter",
|
help="cross_entropy function label smoothing parameter",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|
@ -214,6 +214,12 @@ def parse_args() -> argparse.Namespace:
|
||||||
choices=["spawn", "fork", "forkserver"],
|
choices=["spawn", "fork", "forkserver"],
|
||||||
help="Multiprocessing start method.",
|
help="Multiprocessing start method.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--neftune_alpha",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="NEFTune noise alpha (0=disabled, typical: 5.0).",
|
||||||
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
@ -293,6 +299,7 @@ def train(
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: str,
|
master_port: str,
|
||||||
start_method: str,
|
start_method: str,
|
||||||
|
neftune_alpha: float,
|
||||||
):
|
):
|
||||||
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
assert train_type in ["seq", "sft", "dpo", "grpo"]
|
||||||
assert os.path.exists(param_path)
|
assert os.path.exists(param_path)
|
||||||
|
|
@ -385,6 +392,7 @@ def train(
|
||||||
gradient_checkpointing_modules=grad_ckpt_modules,
|
gradient_checkpointing_modules=grad_ckpt_modules,
|
||||||
executor_kwargs=executor_kwargs,
|
executor_kwargs=executor_kwargs,
|
||||||
extra_kwargs=strategy_kwargs,
|
extra_kwargs=strategy_kwargs,
|
||||||
|
neftune_alpha=neftune_alpha,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = Trainer(train_config)
|
trainer = Trainer(train_config)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue