refactor: 改用递归子模块 init 替代统一 normal_(0.006)

- Embedding.reset_parameters: normal_(std=0.02)
- Linear.reset_parameters: kaiming_uniform_ + uniform_ bias
- Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters
- 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
This commit is contained in:
ViperEkura 2026-05-17 10:44:18 +08:00
parent ad9f4d9cf6
commit 1d54491809
3 changed files with 14 additions and 5 deletions

View File

@ -9,5 +9,8 @@ class Embedding(nn.Module):
super().__init__()
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
def reset_parameters(self):
nn.init.normal_(self.weight, mean=0.0, std=0.02)
def forward(self, x: Tensor) -> Tensor:
return F.embedding(x, self.weight)

View File

@ -10,5 +10,12 @@ class Linear(nn.Module):
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
if self.bias is not None:
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / (fan_in**0.5)
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, x: Tensor) -> Tensor:
return F.linear(x, self.weight, self.bias)

View File

@ -93,12 +93,11 @@ class Transformer(AutoModel):
if self.config.tie_weight is True:
self.lm_head.weight = self.embed_tokens.weight
self._init_weights()
self.apply(self._init_weights)
def _init_weights(self):
for param in self.parameters():
if param.dim() > 1:
nn.init.normal_(param, mean=0.0, std=0.006)
def _init_weights(self, module):
if hasattr(module, "reset_parameters"):
module.reset_parameters()
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
lm_head_key = "lm_head.weight"