From 1d5449180962b2e63960442322bb25c42f8753e3 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Sun, 17 May 2026 10:44:18 +0800
Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=94=B9=E7=94=A8=E9=80=92?=
 =?UTF-8?q?=E5=BD=92=E5=AD=90=E6=A8=A1=E5=9D=97=20init=20=E6=9B=BF?=
 =?UTF-8?q?=E4=BB=A3=E7=BB=9F=E4=B8=80=20normal=5F(0.006)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Embedding.reset_parameters: normal_(std=0.02)
- Linear.reset_parameters: kaiming_uniform_ + uniform_ bias
- Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters
- 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
---
 astrai/model/components/embedding.py | 3 +++
 astrai/model/components/linear.py    | 7 +++++++
 astrai/model/transformer.py          | 9 ++++-----
 3 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/astrai/model/components/embedding.py b/astrai/model/components/embedding.py
index 5923816..3f03796 100644
--- a/astrai/model/components/embedding.py
+++ b/astrai/model/components/embedding.py
@@ -9,5 +9,8 @@ class Embedding(nn.Module):
         super().__init__()
         self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
 
+    def reset_parameters(self):
+        nn.init.normal_(self.weight, mean=0.0, std=0.02)
+
     def forward(self, x: Tensor) -> Tensor:
         return F.embedding(x, self.weight)
diff --git a/astrai/model/components/linear.py b/astrai/model/components/linear.py
index 1810562..c90b1a3 100644
--- a/astrai/model/components/linear.py
+++ b/astrai/model/components/linear.py
@@ -10,5 +10,12 @@ class Linear(nn.Module):
         self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
         self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
 
+    def reset_parameters(self):
+        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
+        if self.bias is not None:
+            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+            bound = 1 / (fan_in**0.5)
+            nn.init.uniform_(self.bias, -bound, bound)
+
     def forward(self, x: Tensor) -> Tensor:
         return F.linear(x, self.weight, self.bias)
diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py
index 72d7b00..3621ff2 100644
--- a/astrai/model/transformer.py
+++ b/astrai/model/transformer.py
@@ -93,12 +93,11 @@ class Transformer(AutoModel):
         if self.config.tie_weight is True:
             self.lm_head.weight = self.embed_tokens.weight
 
-        self._init_weights()
+        self.apply(self._init_weights)
 
-    def _init_weights(self):
-        for param in self.parameters():
-            if param.dim() > 1:
-                nn.init.normal_(param, mean=0.0, std=0.006)
+    def _init_weights(self, module):
+        if hasattr(module, "reset_parameters"):
+            module.reset_parameters()
 
     def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
         lm_head_key = "lm_head.weight"