refactor: 改用递归子模块 init 替代统一 normal_(0.006)
- Embedding.reset_parameters: normal_(std=0.02) - Linear.reset_parameters: kaiming_uniform_ + uniform_ bias - Transformer._init_weights 通过 apply 递归调用子模块 reset_parameters - 移除全局 normal_(0.006) 覆盖,各模块使用更合适的分布
This commit is contained in:
parent
ad9f4d9cf6
commit
1d54491809
|
|
@ -9,5 +9,8 @@ class Embedding(nn.Module):
|
|||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.empty((vocab_size, embedding_dim)))
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.normal_(self.weight, mean=0.0, std=0.02)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.embedding(x, self.weight)
|
||||
|
|
|
|||
|
|
@ -10,5 +10,12 @@ class Linear(nn.Module):
|
|||
self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
|
||||
self.bias = nn.Parameter(torch.zeros(out_dim)) if bias else None
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.weight, a=5**0.5)
|
||||
if self.bias is not None:
|
||||
fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
|
||||
bound = 1 / (fan_in**0.5)
|
||||
nn.init.uniform_(self.bias, -bound, bound)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight, self.bias)
|
||||
|
|
|
|||
|
|
@ -93,12 +93,11 @@ class Transformer(AutoModel):
|
|||
if self.config.tie_weight is True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self._init_weights()
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self):
|
||||
for param in self.parameters():
|
||||
if param.dim() > 1:
|
||||
nn.init.normal_(param, mean=0.0, std=0.006)
|
||||
def _init_weights(self, module):
|
||||
if hasattr(module, "reset_parameters"):
|
||||
module.reset_parameters()
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict=True, assign=False):
|
||||
lm_head_key = "lm_head.weight"
|
||||
|
|
|
|||
Loading…
Reference in New Issue