import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from astrai.factory import BaseFactory from astrai.model.components.linear import Linear class FFNFactory(BaseFactory[nn.Module]): @classmethod def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module: return super().create(ffn_type, dim, dim_ffn, **kwargs) @FFNFactory.register("mlp") class MLP(nn.Module): def __init__(self, dim: int, dim_ffn: int): super().__init__() self.up = Linear(dim, dim_ffn) self.gate = Linear(dim, dim_ffn) self.down = Linear(dim_ffn, dim) def forward(self, x: Tensor) -> Tensor: gated = self.up(x) * F.silu(self.gate(x)) out = self.down(gated) return out @FFNFactory.register("moe") class DeepSeekMoE(nn.Module): def __init__( self, dim: int, dim_ffn: int, n_routed_experts: int, n_shared_experts: int = 1, n_activated_experts: int = 2, topk_method: str = "greedy", ): super().__init__() self.dim = dim self.n_routed_experts = n_routed_experts self.n_shared_experts = n_shared_experts self.n_activated_experts = n_activated_experts self.topk_method = topk_method self.router = Linear(dim, n_routed_experts, bias=False) self.shared_experts = nn.ModuleList( [MLP(dim, dim_ffn) for _ in range(n_shared_experts)] ) self.routed_experts = nn.ModuleList( [MLP(dim, dim_ffn) for _ in range(n_routed_experts)] ) def forward(self, x: Tensor) -> Tensor: bsz, seq_len, dim = x.shape x_flat = x.view(-1, dim) shared_out = self._shared_forward(x_flat) routed_out = self._routed_forward(x_flat) out = (shared_out + routed_out).view(bsz, seq_len, dim) return out def _shared_forward(self, x: Tensor) -> Tensor: if self.n_shared_experts == 0: return torch.zeros_like(x) return sum(e(x) for e in self.shared_experts) / self.n_shared_experts def _routed_forward(self, x: Tensor) -> Tensor: N, D = x.shape K = self.n_activated_experts router_logits = self.router(x) router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype) topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) output = torch.zeros(N, D, device=x.device, dtype=x.dtype) for expert_idx in range(self.n_routed_experts): expert_mask = topk_indices == expert_idx token_idx, k_idx = expert_mask.nonzero(as_tuple=True) if token_idx.numel() == 0: continue expert_input = x[token_idx] expert_output = self.routed_experts[expert_idx](expert_input) weights = topk_weights[token_idx, k_idx].unsqueeze(-1) output.index_add_(0, token_idx, expert_output * weights) return output