19 lines
531 B
Python
19 lines
531 B
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
from astrai.model.components.linear import Linear
|
|
|
|
|
|
class MLP(nn.Module):
    """Gated feed-forward block with a SiLU-activated gate.

    The input goes through two parallel projections, ``up`` and ``gate``,
    both built as ``Linear(dim, dim_feed_forward)``. The ``up`` result is
    multiplied element-wise by ``silu`` of the ``gate`` result, and the
    product is passed through ``down``, built as
    ``Linear(dim_feed_forward, dim)``.

    NOTE(review): ``Linear`` is the project's own layer from
    ``astrai.model.components.linear``; its arguments are presumed to be
    (in_features, out_features) -- confirm against that module.
    """

    def __init__(self, dim: int, dim_feed_forward: int) -> None:
        """Create the ``up``, ``gate`` and ``down`` projection layers.

        Args:
            dim: Size passed as the first argument to ``up``/``gate`` and as
                the second argument to ``down``.
            dim_feed_forward: Size passed as the second argument to
                ``up``/``gate`` and as the first argument to ``down``.
        """
        super().__init__()
        self.up = Linear(dim, dim_feed_forward)
        self.gate = Linear(dim, dim_feed_forward)
        self.down = Linear(dim_feed_forward, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection to ``x`` and project the result back.

        Args:
            x: Input tensor handed unchanged to both ``up`` and ``gate``.
                Presumably its last dimension equals ``dim`` -- TODO confirm.

        Returns:
            ``down(up(x) * silu(gate(x)))``.
        """
        # The SiLU-activated gate branch scales the up branch element-wise.
        gated = self.up(x) * F.silu(self.gate(x))
        return self.down(gated)