AstrAI/astrai/model/components/embedding.py

24 lines
716 B
Python

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
class Embedding(nn.Module):
    """Token-embedding lookup table with optional NEFTune noise injection.

    The module owns a single ``(vocab_size, embedding_dim)`` parameter named
    ``weight`` and maps integer token ids to rows of that table.

    Attributes:
        weight: Learnable embedding table of shape
            ``(vocab_size, embedding_dim)``.
        neftune_noise_alpha: NEFTune noise strength. ``0.0`` (the default)
            disables noise; set it to a positive value to perturb the
            embeddings while the module is in training mode.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        """Create the embedding table and initialise it.

        Args:
            vocab_size: Number of rows (distinct token ids) in the table.
            embedding_dim: Size of each embedding vector.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
        # Noise is opt-in: callers assign a positive alpha to enable NEFTune.
        self.neftune_noise_alpha = 0.0
        # torch.empty returns uninitialised memory (may contain inf/NaN), so
        # the table must be initialised before first use.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise ``weight`` in place from N(0, 0.02^2)."""
        nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def forward(self, x: Tensor) -> Tensor:
        """Look up embeddings for ``x``, adding NEFTune noise when enabled.

        Args:
            x: Integer tensor of token ids with arbitrary shape ``(*)``.

        Returns:
            Tensor of shape ``(*, embedding_dim)``. In training mode with
            ``neftune_noise_alpha > 0`` every element is perturbed by noise
            drawn from ``Uniform(-b, b)`` with
            ``b = alpha / sqrt(seq_len * embedding_dim)``, where ``seq_len``
            is the second-to-last output dimension (1 for a scalar input).
            Otherwise the plain lookup is returned.
        """
        out = F.embedding(x, self.weight)

        # The numel() check avoids dividing by sqrt(0) for empty sequences.
        if self.training and self.neftune_noise_alpha > 0.0 and out.numel() > 0:
            # NEFTune scales the noise by alpha / sqrt(L * d): normalising by
            # both sequence length and embedding dim keeps the perturbation
            # small relative to the embeddings regardless of model width.
            seq_len = out.size(-2) if out.dim() >= 2 else 1
            bound = self.neftune_noise_alpha / math.sqrt(seq_len * out.size(-1))
            noise = torch.empty_like(out).uniform_(-bound, bound)
            out = out + noise

        return out