diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index cf1ebb8..d707213 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -1,12 +1,92 @@ import json -from dataclasses import asdict, dataclass -from typing import Optional, Self +import sys +from dataclasses import dataclass, fields +from typing import Any, Dict, Optional, Self, get_type_hints @dataclass -class ModelConfig: - # basic config +class BaseModelConfig: + """Field-aware JSON load/save for dataclass configs. + + Subclass with additional fields. The base ``model_type`` field + enables ``AutoModel`` to pick the correct subclass. + """ + model_type: Optional[str] = None + + def load(self, config_path: str) -> Self: + raw: Dict[str, Any] = {} + with open(config_path, "r") as f: + raw.update(json.load(f)) + + hints = get_type_hints(type(self)) + valid = {fld.name for fld in fields(self)} + for key, value in raw.items(): + if key not in valid: + sys.stderr.write(f"WARNING: unknown config key '{key}'\n") + continue + + target_type = self._unwrap_optional(hints.get(key)) + if target_type is None: + continue + + try: + value = self._coerce(value, target_type) + except (TypeError, ValueError): + sys.stderr.write( + f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n" + ) + continue + + setattr(self, key, value) + + return self + + def save(self, config_path: str): + config_dict: Dict[str, Any] = {} + for fld in fields(self): + v = getattr(self, fld.name) + if v is not None: + config_dict[fld.name] = v + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=4) + + @staticmethod + def _unwrap_optional(tp: type) -> Optional[type]: + if tp is None: + return None + origin = getattr(tp, "__origin__", None) + if origin is not None: + args = getattr(tp, "__args__", ()) + non_none = [a for a in args if a is not type(None)] + return non_none[0] if non_none else None + return tp + + @staticmethod + def _coerce(value: Any, target_type: type) -> Any: + if target_type is bool and isinstance(value, bool): + return value + if ( + target_type is int + and isinstance(value, (int, float)) + and not isinstance(value, bool) + ): + return int(value) + if ( + target_type is float + and isinstance(value, (int, float)) + and not isinstance(value, bool) + ): + return float(value) + if target_type is str and isinstance(value, str): + return value + if isinstance(value, target_type): + return value + raise TypeError + + +@dataclass +class ModelConfig(BaseModelConfig): vocab_size: Optional[int] = None dim: Optional[int] = None @@ -19,24 +99,16 @@ class ModelConfig: max_len: Optional[int] = None rope_theta: Optional[float] = None - # GQA + # attention + attn_type: str = "gqa" n_heads: Optional[int] = None n_kv_heads: Optional[int] = None use_qk_norm: Optional[bool] = None use_gated_attention: Optional[bool] = None - def load(self, config_path: str) -> Self: - config = {} - with open(config_path, "r") as f: - config.update(json.load(f)) - - for key, value in config.items(): - if hasattr(self, key): - setattr(self, key, value) - - return self - - def save(self, config_path: str): - config_dict = {k: v for k, v in asdict(self).items() if v is not None} - with open(config_path, "w") as f: - json.dump(config_dict, f, indent=4) + # MoE + ffn_type: str = "mlp" + n_routed_experts: Optional[int] = None + n_shared_experts: Optional[int] = None + n_activated_experts: Optional[int] = None + moe_topk_method: Optional[str] = None diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index f82a416..0ff8268 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from astrai.factory import BaseFactory from astrai.inference.core.cache import KvcacheView from astrai.model.components.linear import Linear from astrai.model.components.norm import RMSNorm @@ -22,6 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor: ) +class AttnFactory(BaseFactory[nn.Module]): + @classmethod + def create(cls, attn_type: str, **kwargs) -> nn.Module: + return super().create(attn_type, **kwargs) + + +@AttnFactory.register("gqa") class GQA(nn.Module): def __init__( self, @@ -32,6 +40,7 @@ class GQA(nn.Module): norm_eps: float, use_gated_attention: bool, layer_id: int, + **kwargs, ): super().__init__() assert dim % n_heads == 0 @@ -101,6 +110,7 @@ class GQA(nn.Module): return out +@AttnFactory.register("mla") class MLA(nn.Module): def __init__( self, @@ -113,6 +123,7 @@ class MLA(nn.Module): norm_eps: float, use_gated_attention: bool, layer_id: int, + **kwargs, ): super().__init__() self.dim = dim diff --git a/astrai/model/components/decoder_block.py b/astrai/model/components/decoder_block.py index c7c9cca..63cf9b4 100644 --- a/astrai/model/components/decoder_block.py +++ b/astrai/model/components/decoder_block.py @@ -4,8 +4,8 @@ import torch.nn as nn from torch import Tensor from astrai.inference.core.cache import KvcacheView -from astrai.model.components.attention import GQA -from astrai.model.components.mlp import MLP +from astrai.model.components.attention import AttnFactory +from astrai.model.components.mlp import FFNFactory from astrai.model.components.norm import RMSNorm @@ -20,20 +20,24 @@ class DecoderBlock(nn.Module): use_qk_norm: bool, use_gated_attention: bool, layer_id: int, + attn_type: str = "gqa", + ffn_type: str = "mlp", + **moe_kwargs, ): super().__init__() - self.attention = GQA( - dim, - n_heads, - n_kv_heads, - use_qk_norm, - norm_eps, - use_gated_attention, - layer_id, + self.attention = AttnFactory.create( + attn_type, + dim=dim, + n_heads=n_heads, + n_kv_heads=n_kv_heads, + use_qk_norm=use_qk_norm, + norm_eps=norm_eps, + use_gated_attention=use_gated_attention, + layer_id=layer_id, ) self.input_norm = RMSNorm(dim, norm_eps) - self.mlp = MLP(dim, dim_ffn) self.post_attention_norm = RMSNorm(dim, norm_eps) + self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs) def forward( self, diff --git a/astrai/model/components/mlp.py b/astrai/model/components/mlp.py index 9f342f4..de7e06b 100644 --- a/astrai/model/components/mlp.py +++ b/astrai/model/components/mlp.py @@ -1,12 +1,21 @@ +import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor +from astrai.factory import BaseFactory from astrai.model.components.linear import Linear +class FFNFactory(BaseFactory[nn.Module]): + @classmethod + def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module: + return super().create(ffn_type, dim, dim_ffn, **kwargs) + + +@FFNFactory.register("mlp") class MLP(nn.Module): - def __init__(self, dim: int, dim_feed_forward: int): + def __init__(self, dim: int, dim_feed_forward: int, **kwargs): super().__init__() self.up = Linear(dim, dim_feed_forward) self.gate = Linear(dim, dim_feed_forward) @@ -16,3 +25,70 @@ class MLP(nn.Module): gated = self.up(x) * F.silu(self.gate(x)) out = self.down(gated) return out + + +@FFNFactory.register("moe") +class DeepSeekMoE(nn.Module): + def __init__( + self, + dim: int, + dim_feed_forward: int, + n_routed_experts: int, + n_shared_experts: int = 1, + n_activated_experts: int = 2, + topk_method: str = "greedy", + **kwargs, + ): + super().__init__() + self.dim = dim + self.n_routed_experts = n_routed_experts + self.n_shared_experts = n_shared_experts + self.n_activated_experts = n_activated_experts + self.topk_method = topk_method + + self.router = Linear(dim, n_routed_experts, bias=False) + + self.shared_experts = nn.ModuleList( + [MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)] + ) + self.routed_experts = nn.ModuleList( + [MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)] + ) + + def forward(self, x: Tensor) -> Tensor: + bsz, seq_len, dim = x.shape + x_flat = x.view(-1, dim) + + shared_out = self._shared_forward(x_flat) + routed_out = self._routed_forward(x_flat) + + out = (shared_out + routed_out).view(bsz, seq_len, dim) + return out + + def _shared_forward(self, x: Tensor) -> Tensor: + if self.n_shared_experts == 0: + return torch.zeros_like(x) + return sum(e(x) for e in self.shared_experts) / self.n_shared_experts + + def _routed_forward(self, x: Tensor) -> Tensor: + N, D = x.shape + K = self.n_activated_experts + + router_logits = self.router(x) + router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype) + + topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1) + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + output = torch.zeros(N, D, device=x.device, dtype=x.dtype) + for expert_idx in range(self.n_routed_experts): + expert_mask = topk_indices == expert_idx + token_idx, k_idx = expert_mask.nonzero(as_tuple=True) + if token_idx.numel() == 0: + continue + expert_input = x[token_idx] + expert_output = self.routed_experts[expert_idx](expert_input) + weights = topk_weights[token_idx, k_idx].unsqueeze(-1) + output.index_add_(0, token_idx, expert_output * weights) + + return output diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index b297513..8fbd7ed 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -69,6 +69,12 @@ class Transformer(AutoModel): config.use_qk_norm, config.use_gated_attention, layer_id, + attn_type=config.attn_type, + ffn_type=config.ffn_type, + n_routed_experts=config.n_routed_experts, + n_shared_experts=config.n_shared_experts, + n_activated_experts=config.n_activated_experts, + topk_method=config.moe_topk_method, ) for layer_id in range(config.n_layers) ]