diff --git a/astrai/model/components/embedding.py b/astrai/model/components/embedding.py
index 5923816..3f03796 100644
--- a/astrai/model/components/embedding.py
+++ b/astrai/model/components/embedding.py
@@ -9,5 +9,9 @@ class Embedding(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
 
+    def reset_parameters(self):
+        """Re-initialise the embedding table in place from N(0, 0.02**2)."""
+        nn.init.normal_(self.weight, mean=0.0, std=0.02)
+
     def forward(self, x: Tensor) -> Tensor:
         return F.embedding(x, self.weight)
diff --git a/astrai/model/components/linear.py b/astrai/model/components/linear.py
index 1810562..c90b1a3 100644
--- a/astrai/model/components/linear.py
+++ b/astrai/model/components/linear.py
@@ -10,5 +10,20 @@ class Linear(nn.Module):
         self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
         self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
 
+    def reset_parameters(self):
+        """Re-initialise weight and bias in place.
+
+        The weight uses Kaiming-uniform init with a=sqrt(5); the bias, when
+        present, is drawn from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).
+        """
+        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
+        if self.bias is not None:
+            # weight is (out_dim, in_dim), so fan_in is its second dimension.
+            fan_in = self.weight.size(1)
+            # A zero-width input would make 1/sqrt(fan_in) raise
+            # ZeroDivisionError; fall back to a zero bound instead.
+            bound = 1 / (fan_in**0.5) if fan_in > 0 else 0
+            nn.init.uniform_(self.bias, -bound, bound)
+
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(x, self.weight, self.bias)
diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index 72d7b00..3621ff2 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -93,12 +93,22 @@ class Transformer(AutoModel):
         if self.config.tie_weight is True:
             self.lm_head.weight = self.embed_tokens.weight
 
-        self._init_weights()
+        self.apply(self._init_weights)
 
-    def _init_weights(self):
-        for param in self.parameters():
-            if param.dim() > 1:
-                nn.init.normal_(param, mean=0.0, std=0.006)
+    def _init_weights(self, module):
+        """Initialise one submodule; used as the callback for ``self.apply``.
+
+        Modules that expose ``reset_parameters`` are re-initialised through it;
+        every other module is left untouched.
+        """
+        # With tied weights lm_head shares embed_tokens' Parameter. Skip lm_head
+        # (leaving any bias at its constructor value) so the shared tensor keeps
+        # embed_tokens' init no matter which module apply() visits last.
+        shares_embedding = self.config.tie_weight is True and module is self.lm_head
+        if shares_embedding and hasattr(self.embed_tokens, "reset_parameters"):
+            return
+        if hasattr(module, "reset_parameters"):
+            module.reset_parameters()
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
         lm_head_key = "lm_head.weight"