AstrAI/astrai/model/components/mlp.py

19 lines
531 B
Python

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.model.components.linear import Linear
class MLP(nn.Module):
    """Gated feed-forward block: ``down(up(x) * silu(gate(x)))``.

    This is the SwiGLU-style layout: two parallel projections from ``dim`` to
    ``dim_feed_forward``. One is passed through SiLU and acts as an
    element-wise gate on the other. A final projection maps the result back
    to ``dim``.

    NOTE(review): ``Linear`` is the project's own layer
    (``astrai.model.components.linear``). Presumably it projects the last
    dimension of its input (in_features -> out_features) like
    ``torch.nn.Linear``; confirm against that module.
    """

    def __init__(self, dim: int, dim_feed_forward: int):
        """Create the three projections.

        Args:
            dim: Model (input/output) width.
            dim_feed_forward: Hidden width of the gated expansion.
        """
        super().__init__()
        # Attribute names and registration order determine state_dict keys;
        # keep them stable so existing checkpoints still load.
        self.up = Linear(dim, dim_feed_forward)
        self.gate = Linear(dim, dim_feed_forward)
        self.down = Linear(dim_feed_forward, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform to ``x``.

        Args:
            x: Input tensor whose feature size matches ``dim``
               (assumed last axis; depends on ``Linear``).

        Returns:
            Tensor produced by ``self.down`` on the gated hidden activations.
        """
        expanded = self.up(x)
        gate_signal = F.silu(self.gate(x))
        # Element-wise gating of the expanded branch by the SiLU-activated one.
        hidden = expanded * gate_signal
        return self.down(hidden)