feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else

- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告
- DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控
- AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支
- config 用 attn_type/ffn_type 驱动组件选择
This commit is contained in:
ViperEkura 2026-05-15 20:34:52 +08:00
parent ef25efffa2
commit e12f1a7ee5
5 changed files with 201 additions and 32 deletions

View File

@ -1,12 +1,92 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Self
import sys
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional, Self, get_type_hints
@dataclass
class ModelConfig:
# basic config
class BaseModelConfig:
"""Field-aware JSON load/save for dataclass configs.
Subclass with additional fields. The base ``model_type`` field
enables ``AutoModel`` to pick the correct subclass.
"""
model_type: Optional[str] = None
def load(self, config_path: str) -> Self:
raw: Dict[str, Any] = {}
with open(config_path, "r") as f:
raw.update(json.load(f))
hints = get_type_hints(type(self))
valid = {fld.name for fld in fields(self)}
for key, value in raw.items():
if key not in valid:
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
continue
target_type = self._unwrap_optional(hints.get(key))
if target_type is None:
continue
try:
value = self._coerce(value, target_type)
except (TypeError, ValueError):
sys.stderr.write(
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
)
continue
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict: Dict[str, Any] = {}
for fld in fields(self):
v = getattr(self, fld.name)
if v is not None:
config_dict[fld.name] = v
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
@staticmethod
def _unwrap_optional(tp: type) -> Optional[type]:
if tp is None:
return None
origin = getattr(tp, "__origin__", None)
if origin is not None:
args = getattr(tp, "__args__", ())
non_none = [a for a in args if a is not type(None)]
return non_none[0] if non_none else None
return tp
@staticmethod
def _coerce(value: Any, target_type: type) -> Any:
if target_type is bool and isinstance(value, bool):
return value
if (
target_type is int
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return int(value)
if (
target_type is float
and isinstance(value, (int, float))
and not isinstance(value, bool)
):
return float(value)
if target_type is str and isinstance(value, str):
return value
if isinstance(value, target_type):
return value
raise TypeError
@dataclass
class ModelConfig(BaseModelConfig):
vocab_size: Optional[int] = None
dim: Optional[int] = None
@ -19,24 +99,16 @@ class ModelConfig:
max_len: Optional[int] = None
rope_theta: Optional[float] = None
# GQA
# attention
attn_type: str = "gqa"
n_heads: Optional[int] = None
n_kv_heads: Optional[int] = None
use_qk_norm: Optional[bool] = None
use_gated_attention: Optional[bool] = None
def load(self, config_path: str) -> Self:
config = {}
with open(config_path, "r") as f:
config.update(json.load(f))
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str):
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
with open(config_path, "w") as f:
json.dump(config_dict, f, indent=4)
# MoE
ffn_type: str = "mlp"
n_routed_experts: Optional[int] = None
n_shared_experts: Optional[int] = None
n_activated_experts: Optional[int] = None
moe_topk_method: Optional[str] = None

View File

@ -5,6 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.linear import Linear
from astrai.model.components.norm import RMSNorm
@ -22,6 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
)
class AttnFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, attn_type: str, **kwargs) -> nn.Module:
return super().create(attn_type, **kwargs)
@AttnFactory.register("gqa")
class GQA(nn.Module):
def __init__(
self,
@ -32,6 +40,7 @@ class GQA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
assert dim % n_heads == 0
@ -101,6 +110,7 @@ class GQA(nn.Module):
return out
@AttnFactory.register("mla")
class MLA(nn.Module):
def __init__(
self,
@ -113,6 +123,7 @@ class MLA(nn.Module):
norm_eps: float,
use_gated_attention: bool,
layer_id: int,
**kwargs,
):
super().__init__()
self.dim = dim

View File

@ -4,8 +4,8 @@ import torch.nn as nn
from torch import Tensor
from astrai.inference.core.cache import KvcacheView
from astrai.model.components.attention import GQA
from astrai.model.components.mlp import MLP
from astrai.model.components.attention import AttnFactory
from astrai.model.components.mlp import FFNFactory
from astrai.model.components.norm import RMSNorm
@ -20,20 +20,24 @@ class DecoderBlock(nn.Module):
use_qk_norm: bool,
use_gated_attention: bool,
layer_id: int,
attn_type: str = "gqa",
ffn_type: str = "mlp",
**moe_kwargs,
):
super().__init__()
self.attention = GQA(
dim,
n_heads,
n_kv_heads,
use_qk_norm,
norm_eps,
use_gated_attention,
layer_id,
self.attention = AttnFactory.create(
attn_type,
dim=dim,
n_heads=n_heads,
n_kv_heads=n_kv_heads,
use_qk_norm=use_qk_norm,
norm_eps=norm_eps,
use_gated_attention=use_gated_attention,
layer_id=layer_id,
)
self.input_norm = RMSNorm(dim, norm_eps)
self.mlp = MLP(dim, dim_ffn)
self.post_attention_norm = RMSNorm(dim, norm_eps)
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
def forward(
self,

View File

@ -1,12 +1,21 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from astrai.factory import BaseFactory
from astrai.model.components.linear import Linear
class FFNFactory(BaseFactory[nn.Module]):
@classmethod
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
return super().create(ffn_type, dim, dim_ffn, **kwargs)
@FFNFactory.register("mlp")
class MLP(nn.Module):
def __init__(self, dim: int, dim_feed_forward: int):
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
super().__init__()
self.up = Linear(dim, dim_feed_forward)
self.gate = Linear(dim, dim_feed_forward)
@ -16,3 +25,70 @@ class MLP(nn.Module):
gated = self.up(x) * F.silu(self.gate(x))
out = self.down(gated)
return out
@FFNFactory.register("moe")
class DeepSeekMoE(nn.Module):
def __init__(
self,
dim: int,
dim_feed_forward: int,
n_routed_experts: int,
n_shared_experts: int = 1,
n_activated_experts: int = 2,
topk_method: str = "greedy",
**kwargs,
):
super().__init__()
self.dim = dim
self.n_routed_experts = n_routed_experts
self.n_shared_experts = n_shared_experts
self.n_activated_experts = n_activated_experts
self.topk_method = topk_method
self.router = Linear(dim, n_routed_experts, bias=False)
self.shared_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
)
self.routed_experts = nn.ModuleList(
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
)
def forward(self, x: Tensor) -> Tensor:
bsz, seq_len, dim = x.shape
x_flat = x.view(-1, dim)
shared_out = self._shared_forward(x_flat)
routed_out = self._routed_forward(x_flat)
out = (shared_out + routed_out).view(bsz, seq_len, dim)
return out
def _shared_forward(self, x: Tensor) -> Tensor:
if self.n_shared_experts == 0:
return torch.zeros_like(x)
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
def _routed_forward(self, x: Tensor) -> Tensor:
N, D = x.shape
K = self.n_activated_experts
router_logits = self.router(x)
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
for expert_idx in range(self.n_routed_experts):
expert_mask = topk_indices == expert_idx
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
if token_idx.numel() == 0:
continue
expert_input = x[token_idx]
expert_output = self.routed_experts[expert_idx](expert_input)
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
output.index_add_(0, token_idx, expert_output * weights)
return output

View File

@ -69,6 +69,12 @@ class Transformer(AutoModel):
config.use_qk_norm,
config.use_gated_attention,
layer_id,
attn_type=config.attn_type,
ffn_type=config.ffn_type,
n_routed_experts=config.n_routed_experts,
n_shared_experts=config.n_shared_experts,
n_activated_experts=config.n_activated_experts,
topk_method=config.moe_topk_method,
)
for layer_id in range(config.n_layers)
]