94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
from astrai.factory import BaseFactory
|
|
from astrai.model.components.linear import Linear
|
|
|
|
|
|
class FFNFactory(BaseFactory[nn.Module]):
    """Factory for feed-forward modules, selected by a string key.

    Feed-forward classes in this module attach themselves with
    ``@FFNFactory.register(<key>)`` (the keys used here are ``"mlp"`` and
    ``"moe"``); ``create`` then builds one by key.
    """

    @classmethod
    def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
        """Build the feed-forward module registered under ``ffn_type``.

        Args:
            ffn_type: registry key the target class was registered with.
            dim: first size argument, passed through to ``BaseFactory.create``.
            dim_ffn: second size argument, passed through to
                ``BaseFactory.create``.
            **kwargs: any further constructor options, passed through unchanged.

        Returns:
            The result of ``BaseFactory.create`` for these arguments —
            presumably an instance of the registered class (TODO confirm
            against ``BaseFactory``).
        """
        # Explicit two-argument form; equivalent to zero-arg super() here.
        parent = super(FFNFactory, cls)
        return parent.create(ffn_type, dim, dim_ffn, **kwargs)
|
|
|
|
|
|
@FFNFactory.register("mlp")
class MLP(nn.Module):
    """Gated feed-forward block computing ``down(up(x) * silu(gate(x)))``.

    Registered with ``FFNFactory`` under the key ``"mlp"``.

    Args:
        dim: input and output feature size.
        dim_ffn: intermediate feature size of the gated projection.
    """

    def __init__(self, dim: int, dim_ffn: int):
        super().__init__()
        # Keep the order up -> gate -> down: it fixes submodule registration
        # (state_dict layout) and, if ``Linear`` initialises randomly, the
        # order in which random numbers are drawn.
        self.up = Linear(dim, dim_ffn)
        self.gate = Linear(dim, dim_ffn)
        self.down = Linear(dim_ffn, dim)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` up, gate it with SiLU of a parallel projection, project back down."""
        projected = self.up(x)
        gate_signal = F.silu(self.gate(x))
        return self.down(projected * gate_signal)
|
|
|
|
|
|
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
    """Mixture-of-experts FFN with always-on shared experts and top-k routed experts.

    Registered with ``FFNFactory`` under the key ``"moe"``. For every token the
    output is the mean of all shared-expert outputs plus a weighted sum of the
    ``n_activated_experts`` routed experts chosen by a linear router (weights
    are the softmax probabilities of the chosen experts, renormalised to sum
    to 1).

    Args:
        dim: size of the input/output feature (last) dimension.
        dim_ffn: intermediate size of every expert ``MLP``.
        n_routed_experts: number of routed experts the router chooses from.
        n_shared_experts: number of experts applied to every token; their
            outputs are averaged. ``0`` disables the shared path.
        n_activated_experts: number of routed experts (k) selected per token.
        topk_method: stored on the module but not read by this class; routing
            is always plain greedy top-k.

    Raises:
        ValueError: if ``n_shared_experts`` is negative, or
            ``n_activated_experts`` is outside ``[0, n_routed_experts]``.
    """

    def __init__(
        self,
        dim: int,
        dim_ffn: int,
        n_routed_experts: int,
        n_shared_experts: int = 1,
        n_activated_experts: int = 2,
        topk_method: str = "greedy",
    ):
        super().__init__()
        # Fail at construction instead of deep inside torch.topk at forward
        # time (or silently producing garbage for negative counts).
        if n_shared_experts < 0:
            raise ValueError(
                f"n_shared_experts must be >= 0, got {n_shared_experts}"
            )
        if not 0 <= n_activated_experts <= n_routed_experts:
            raise ValueError(
                "n_activated_experts must be in [0, n_routed_experts], got "
                f"{n_activated_experts} with n_routed_experts={n_routed_experts}"
            )

        self.dim = dim
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.n_activated_experts = n_activated_experts
        # NOTE: only greedy top-k is implemented; this value is kept for
        # config round-tripping and is not consulted during forward.
        self.topk_method = topk_method

        # Module creation order (router, shared, routed) is preserved so the
        # state_dict layout and any random-init draw order stay unchanged.
        self.router = Linear(dim, n_routed_experts, bias=False)

        self.shared_experts = nn.ModuleList(
            [MLP(dim, dim_ffn) for _ in range(n_shared_experts)]
        )
        self.routed_experts = nn.ModuleList(
            [MLP(dim, dim_ffn) for _ in range(n_routed_experts)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the MoE layer over the last dimension of ``x``.

        Any number of leading dimensions is accepted (e.g. ``(batch, seq,
        dim)``); they are flattened into one token axis and restored on
        output. ``reshape`` is used rather than ``view`` so non-contiguous
        inputs work too.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        x_flat = x.reshape(-1, orig_shape[-1])

        shared_out = self._shared_forward(x_flat)
        routed_out = self._routed_forward(x_flat)

        return (shared_out + routed_out).reshape(orig_shape)

    def _shared_forward(self, x: Tensor) -> Tensor:
        """Mean of all shared-expert outputs on flattened tokens ``x``.

        Returns zeros shaped like ``x`` when there are no shared experts.
        """
        if self.n_shared_experts == 0:
            return torch.zeros_like(x)
        return sum(e(x) for e in self.shared_experts) / self.n_shared_experts

    def _routed_forward(self, x: Tensor) -> Tensor:
        """Weighted sum of each token's top-k routed experts.

        Args:
            x: flattened tokens of shape ``(num_tokens, dim)``.

        Returns:
            Tensor of shape ``(num_tokens, dim)`` with dtype ``x.dtype``.
        """
        num_tokens, dim = x.shape
        k = self.n_activated_experts

        # Select and renormalise in float32, casting only the final weights:
        # in bf16/fp16 the rounded probabilities tie easily (changing which
        # experts topk picks) and the renormalising sum loses precision.
        router_probs = torch.softmax(self.router(x).float(), dim=-1)
        topk_weights, topk_indices = torch.topk(router_probs, k, dim=-1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(x.dtype)

        output = torch.zeros(num_tokens, dim, device=x.device, dtype=x.dtype)
        for expert_idx, expert in enumerate(self.routed_experts):
            # (token, slot) pairs that picked this expert. topk indices are
            # distinct per row, so each token appears at most once here.
            token_idx, slot_idx = (topk_indices == expert_idx).nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue
            weights = topk_weights[token_idx, slot_idx].unsqueeze(-1)
            # index_add_ scatters the weighted expert outputs back to their
            # token rows, accumulating across experts.
            output.index_add_(0, token_idx, expert(x[token_idx]) * weights)

        return output
|