feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else

- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告
- DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控
- AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支
- config 用 attn_type/ffn_type 驱动组件选择
This commit is contained in:
ViperEkura 2026-05-15 20:34:52 +08:00
parent ef25efffa2
commit e12f1a7ee5
5 changed files with 201 additions and 32 deletions

View File

@ -1,12 +1,92 @@
import json import json
from dataclasses import asdict, dataclass import sys
from typing import Optional, Self from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass @dataclass
class ModelConfig: class BaseModelConfig:
# basic config """Field-aware JSON load/save for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
with open(config_path, "r") as f:
raw.update(json.load(f))
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
vocab_size: Optional[int] = None vocab_size: Optional[int] = None
dim: Optional[int] = None dim: Optional[int] = None
@ -19,24 +99,16 @@ class ModelConfig:
max_len: Optional[int] = None max_len: Optional[int] = None
rope_theta: Optional[float] = None rope_theta: Optional[float] = None
# GQA # attention
attn_type: str = "gqa"
n_heads: Optional[int] = None n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self: # MoE
config = {} ffn_type: str = "mlp"
with open(config_path, "r") as f: n_routed_experts: Optional[int] = None
config.update(json.load(f)) n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
for key, value in config.items(): moe_topk_method: Optional[str] = None
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm from astrai.model.components.norm import RMSNorm
@ -22,6 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
) )
class AttnFactory(BaseFactory[nn.Module]):
    """Factory for attention modules, selected by an ``attn_type`` string.

    Implementations are attached with ``@AttnFactory.register(<name>)``.
    """

    @classmethod
    def create(cls, attn_type: str, **kwargs) -> nn.Module:
        """Build the attention module associated with ``attn_type``.

        Args:
            attn_type: Name handed to ``BaseFactory.create`` to select the
                implementation.
            **kwargs: Forwarded unchanged to ``BaseFactory.create``
                (presumably on to the implementation's constructor - TODO
                confirm against ``BaseFactory``).

        Returns:
            Whatever ``BaseFactory.create`` returns for this name.
        """
        attention = super().create(attn_type, **kwargs)
        return attention
@AttnFactory.register("gqa")
class GQA(nn.Module): class GQA(nn.Module):
def __init__( def __init__(
self, self,
@ -32,6 +40,7 @@ class GQA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
assert dim % n_heads == 0 assert dim % n_heads == 0
@ -101,6 +110,7 @@ class GQA(nn.Module):
return out return out
@AttnFactory.register("mla")
class MLA(nn.Module): class MLA(nn.Module):
def __init__( def __init__(
self, self,
@ -113,6 +123,7 @@ class MLA(nn.Module):
norm_eps: float, norm_eps: float,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
**kwargs,
): ):
super().__init__() super().__init__()
self.dim = dim self.dim = dim

View File

@ -4,8 +4,8 @@ import torch.nn as nn
from torch import Tensor from torch import Tensor
from astrai.inference.core.cache import KvcacheView from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import GQA from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import MLP from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm from astrai.model.components.norm import RMSNorm
@ -20,20 +20,24 @@ class DecoderBlock(nn.Module):
use_qk_norm: bool, use_qk_norm: bool,
use_gated_attention: bool, use_gated_attention: bool,
layer_id: int, layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**moe_kwargs,
): ):
super().__init__() super().__init__()
self.attention = GQA( self.attention = AttnFactory.create(
dim, attn_type,
n_heads, dim=dim,
n_kv_heads, n_heads=n_heads,
use_qk_norm, n_kv_heads=n_kv_heads,
norm_eps, use_qk_norm=use_qk_norm,
use_gated_attention, norm_eps=norm_eps,
layer_id, use_gated_attention=use_gated_attention,
layer_id=layer_id,
) )
self.input_norm = RMSNorm(dim, norm_eps) self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps) self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
def forward( def forward(
self, self,

View File

@ -1,12 +1,21 @@
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
    """Factory for feed-forward modules, selected by an ``ffn_type`` string.

    Implementations are attached with ``@FFNFactory.register(<name>)``.
    """

    @classmethod
    def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
        """Build the feed-forward module associated with ``ffn_type``.

        Args:
            ffn_type: Name handed to ``BaseFactory.create`` to select the
                implementation.
            dim: Forwarded as the first positional argument.
            dim_ffn: Forwarded as the second positional argument.
            **kwargs: Forwarded unchanged to ``BaseFactory.create``
                (presumably on to the implementation's constructor - TODO
                confirm against ``BaseFactory``).

        Returns:
            Whatever ``BaseFactory.create`` returns for this name.
        """
        ffn = super().create(ffn_type, dim, dim_ffn, **kwargs)
        return ffn
@FFNFactory.register("mlp")
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int): def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
super().__init__() super().__init__()
self.up = Linear(dim, dim_feed_forward) self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward) self.gate = Linear(dim, dim_feed_forward)
@ -16,3 +25,70 @@ class MLP(nn.Module):
gated = self.up(x) * F.silu(self.gate(x)) gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated) out = self.down(gated)
return out return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
    """Mixture-of-experts feed-forward layer with shared and routed experts.

    Every token passes through all shared experts (their outputs are
    averaged) and through its top-K routed experts, chosen from a softmax
    over the router logits. The K selected probabilities are renormalised
    to sum to one and used to weight the routed expert outputs. The shared
    and routed results are added.
    """

    def __init__(
        self,
        dim: int,
        dim_feed_forward: int,
        n_routed_experts: int,
        n_shared_experts: int = 1,
        n_activated_experts: int = 2,
        topk_method: str = "greedy",
        **kwargs,
    ):
        """Build the router and the expert MLPs.

        Args:
            dim: Size of the last (feature) axis of the input.
            dim_feed_forward: Hidden width handed to each expert ``MLP``.
            n_routed_experts: Number of routed experts; must be positive.
            n_shared_experts: Number of always-on experts; ``None`` means 1.
            n_activated_experts: Routed experts used per token (K);
                ``None`` means 2. Must satisfy ``1 <= K <= n_routed_experts``.
            topk_method: Stored on the module; ``None`` means ``"greedy"``.
                The routing code in this class does not consult it - plain
                top-K selection is always used.
            **kwargs: Accepted and ignored so callers can forward extra
                options without filtering them.

        Raises:
            ValueError: If the expert counts are missing or inconsistent.
        """
        super().__init__()

        # Callers may forward unset (None) config values verbatim; treat an
        # explicit None the same as omitting the argument.
        if n_shared_experts is None:
            n_shared_experts = 1
        if n_activated_experts is None:
            n_activated_experts = 2
        if topk_method is None:
            topk_method = "greedy"

        # Fail at construction time, not deep inside torch.topk at forward.
        if n_routed_experts is None or n_routed_experts <= 0:
            raise ValueError(
                f"n_routed_experts must be a positive int, got {n_routed_experts!r}"
            )
        if n_shared_experts < 0:
            raise ValueError(
                f"n_shared_experts must be >= 0, got {n_shared_experts!r}"
            )
        if not 0 < n_activated_experts <= n_routed_experts:
            raise ValueError(
                "n_activated_experts must be in [1, n_routed_experts], got "
                f"{n_activated_experts!r} with n_routed_experts={n_routed_experts!r}"
            )

        self.dim = dim
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.n_activated_experts = n_activated_experts
        self.topk_method = topk_method

        self.router = Linear(dim, n_routed_experts, bias=False)
        self.shared_experts = nn.ModuleList(
            [MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
        )
        self.routed_experts = nn.ModuleList(
            [MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
        )

    def forward(self, x: Tensor) -> Tensor:
        """Apply the layer to ``x`` and return a tensor of the same shape.

        All leading axes are flattened into a token axis before routing, so
        any ``(..., dim)`` input is accepted; ``reshape`` is used so
        non-contiguous inputs work as well.
        """
        x_flat = x.reshape(-1, x.shape[-1])
        combined = self._shared_forward(x_flat) + self._routed_forward(x_flat)
        return combined.reshape(x.shape)

    def _shared_forward(self, x: Tensor) -> Tensor:
        """Return the mean output of the shared experts (zeros if there are none)."""
        if self.n_shared_experts == 0:
            return torch.zeros_like(x)
        return sum(e(x) for e in self.shared_experts) / self.n_shared_experts

    def _routed_forward(self, x: Tensor) -> Tensor:
        """Return the weighted sum of each token's top-K routed expert outputs.

        Args:
            x: Flattened tokens of shape ``(N, D)``.
        """
        N, D = x.shape
        K = self.n_activated_experts

        router_logits = self.router(x)
        # Softmax in float32, then cast back to the input dtype.
        router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)

        topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
        # Renormalise so the K selected weights sum to one per token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
        for expert_idx in range(self.n_routed_experts):
            # Rows: tokens that selected this expert; cols: which of the K
            # slots it occupies (needed to pick the matching weight).
            expert_mask = topk_indices == expert_idx
            token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
            if token_idx.numel() == 0:
                continue

            expert_input = x[token_idx]
            expert_output = self.routed_experts[expert_idx](expert_input)
            weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
            # Scatter-add the weighted outputs back onto their token rows.
            output.index_add_(0, token_idx, expert_output * weights)

        return output

View File

@ -69,6 +69,12 @@ class Transformer(AutoModel):
config.use_qk_norm, config.use_qk_norm,
config.use_gated_attention, config.use_gated_attention,
layer_id, layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
) )
for layer_id in range(config.n_layers) for layer_id in range(config.n_layers)
] ]