22 lines
744 B
Python
22 lines
744 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
|
|
class Linear(nn.Module):
    """Affine layer computing ``x @ weight.T + bias`` with ``nn.Linear``-style init.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        bias: If True, learn an additive bias of shape ``(out_dim,)``.
            Defaults to False (note: ``nn.Linear`` defaults to True).
    """

    def __init__(self, in_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        # torch.empty leaves the storage uninitialised; reset_parameters()
        # at the end of __init__ fills it before the layer can be used.
        self.weight = nn.Parameter(torch.empty((out_dim, in_dim)))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim))
        else:
            # Register the slot explicitly (as nn.Linear does) so `bias`
            # is a known-but-absent parameter rather than a plain attribute.
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise weight (and bias, if present) in place.

        Matches ``nn.Linear``: Kaiming-uniform with ``a=sqrt(5)`` on the
        weight, which works out to ``U(-1/sqrt(fan_in), 1/sqrt(fan_in))``,
        and the same bound for the bias.
        """
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # fan_in == 0 happens when in_dim == 0; avoid ZeroDivisionError.
            bound = 1 / fan_in**0.5 if fan_in > 0 else 0.0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``x @ weight.T`` (plus bias, if any) over the last dim of ``x``."""
        return F.linear(x, self.weight, self.bias)

    def extra_repr(self) -> str:
        """Summarise the layer's configuration for ``repr(module)``."""
        return (
            f"in_dim={self.in_dim}, out_dim={self.out_dim}, "
            f"bias={self.bias is not None}"
        )
|