diff --git a/astrai/config/train_config.py b/astrai/config/train_config.py index 91ea3c0..1421219 100644 --- a/astrai/config/train_config.py +++ b/astrai/config/train_config.py @@ -128,6 +128,10 @@ class TrainConfig(BaseConfig): default=1000, metadata={"help": "Number of optimizer steps between validation runs."}, ) + neftune_alpha: float = field( + default=0.0, + metadata={"help": "NEFTune noise alpha (0=disabled, typical: 5.0)."}, + ) executor_kwargs: dict = field( default_factory=dict, diff --git a/astrai/model/components/embedding.py b/astrai/model/components/embedding.py index 3f03796..f8f4551 100644 --- a/astrai/model/components/embedding.py +++ b/astrai/model/components/embedding.py @@ -1,3 +1,5 @@ +import math + import torch import torch.nn as nn import torch.nn.functional as F @@ -8,9 +10,14 @@ class Embedding(nn.Module): def __init__(self, vocab_size: int, embedding_dim: int): super().__init__() self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim))) + self.neftune_noise_alpha = 0.0 def reset_parameters(self): nn.init.normal_(self.weight, mean=0.0, std=0.02) def forward(self, x: Tensor) -> Tensor: - return F.embedding(x, self.weight) + out = F.embedding(x, self.weight) + if self.training and self.neftune_noise_alpha > 0.0: + eps = self.neftune_noise_alpha / math.sqrt(out.size(1)) + out = out + eps * torch.randn_like(out) + return out diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index 6af33a9..93d2ea5 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -63,6 +63,7 @@ class TrainContextBuilder: model = cfg.model_fn() model = model.to(device=device) + model.embed_tokens.neftune_noise_alpha = cfg.neftune_alpha model_config = {} if self._resume_dir: diff --git a/scripts/tools/train.py b/scripts/tools/train.py index d2c8eec..e845596 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -113,7 +113,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--label_smoothing", type=float, - default=0.05, + default=0.0, help="cross_entropy function label smoothing parameter", ) parser.add_argument( @@ -214,6 +214,12 @@ def parse_args() -> argparse.Namespace: choices=["spawn", "fork", "forkserver"], help="Multiprocessing start method.", ) + parser.add_argument( + "--neftune_alpha", + type=float, + default=0.0, + help="NEFTune noise alpha (0=disabled, typical: 5.0).", + ) args = parser.parse_args() @@ -293,6 +299,7 @@ def train( master_addr: str, master_port: str, start_method: str, + neftune_alpha: float, ): assert train_type in ["seq", "sft", "dpo", "grpo"] assert os.path.exists(param_path) @@ -385,6 +392,7 @@ def train( gradient_checkpointing_modules=grad_ckpt_modules, executor_kwargs=executor_kwargs, extra_kwargs=strategy_kwargs, + neftune_alpha=neftune_alpha, ) trainer = Trainer(train_config)