feat: BaseModelConfig + DeepSeekMoE + 工厂模式替代 if/else
- BaseModelConfig: fields() 精确字段匹配 + 类型矫正 + 未知key警告 - DeepSeekMoE: 共享专家 + 路由专家 + top-K 门控 - AttnFactory/FFNFactory: 装饰器注册,DecoderBlock 零分支 - config 用 attn_type/ffn_type 驱动组件选择
This commit is contained in:
parent
ef25efffa2
commit
e12f1a7ee5
|
|
@ -1,12 +1,92 @@
|
|||
import json
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Optional, Self
|
||||
import sys
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, Optional, Self, get_type_hints
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
# basic config
|
||||
class BaseModelConfig:
|
||||
"""Field-aware JSON load/save for dataclass configs.
|
||||
|
||||
Subclass with additional fields. The base ``model_type`` field
|
||||
enables ``AutoModel`` to pick the correct subclass.
|
||||
"""
|
||||
|
||||
model_type: Optional[str] = None
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
raw: Dict[str, Any] = {}
|
||||
with open(config_path, "r") as f:
|
||||
raw.update(json.load(f))
|
||||
|
||||
hints = get_type_hints(type(self))
|
||||
valid = {fld.name for fld in fields(self)}
|
||||
for key, value in raw.items():
|
||||
if key not in valid:
|
||||
sys.stderr.write(f"WARNING: unknown config key '{key}'\n")
|
||||
continue
|
||||
|
||||
target_type = self._unwrap_optional(hints.get(key))
|
||||
if target_type is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
value = self._coerce(value, target_type)
|
||||
except (TypeError, ValueError):
|
||||
sys.stderr.write(
|
||||
f"WARNING: cannot coerce '{key}' = {value!r} to {target_type}\n"
|
||||
)
|
||||
continue
|
||||
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict: Dict[str, Any] = {}
|
||||
for fld in fields(self):
|
||||
v = getattr(self, fld.name)
|
||||
if v is not None:
|
||||
config_dict[fld.name] = v
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_optional(tp: type) -> Optional[type]:
|
||||
if tp is None:
|
||||
return None
|
||||
origin = getattr(tp, "__origin__", None)
|
||||
if origin is not None:
|
||||
args = getattr(tp, "__args__", ())
|
||||
non_none = [a for a in args if a is not type(None)]
|
||||
return non_none[0] if non_none else None
|
||||
return tp
|
||||
|
||||
@staticmethod
|
||||
def _coerce(value: Any, target_type: type) -> Any:
|
||||
if target_type is bool and isinstance(value, bool):
|
||||
return value
|
||||
if (
|
||||
target_type is int
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return int(value)
|
||||
if (
|
||||
target_type is float
|
||||
and isinstance(value, (int, float))
|
||||
and not isinstance(value, bool)
|
||||
):
|
||||
return float(value)
|
||||
if target_type is str and isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, target_type):
|
||||
return value
|
||||
raise TypeError
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig(BaseModelConfig):
|
||||
vocab_size: Optional[int] = None
|
||||
dim: Optional[int] = None
|
||||
|
||||
|
|
@ -19,24 +99,16 @@ class ModelConfig:
|
|||
max_len: Optional[int] = None
|
||||
rope_theta: Optional[float] = None
|
||||
|
||||
# GQA
|
||||
# attention
|
||||
attn_type: str = "gqa"
|
||||
n_heads: Optional[int] = None
|
||||
n_kv_heads: Optional[int] = None
|
||||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
def load(self, config_path: str) -> Self:
|
||||
config = {}
|
||||
with open(config_path, "r") as f:
|
||||
config.update(json.load(f))
|
||||
|
||||
for key, value in config.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
|
||||
return self
|
||||
|
||||
def save(self, config_path: str):
|
||||
config_dict = {k: v for k, v in asdict(self).items() if v is not None}
|
||||
with open(config_path, "w") as f:
|
||||
json.dump(config_dict, f, indent=4)
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
n_shared_experts: Optional[int] = None
|
||||
n_activated_experts: Optional[int] = None
|
||||
moe_topk_method: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
|
|
@ -22,6 +23,13 @@ def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
|||
)
|
||||
|
||||
|
||||
class AttnFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, attn_type: str, **kwargs) -> nn.Module:
|
||||
return super().create(attn_type, **kwargs)
|
||||
|
||||
|
||||
@AttnFactory.register("gqa")
|
||||
class GQA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -32,6 +40,7 @@ class GQA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % n_heads == 0
|
||||
|
|
@ -101,6 +110,7 @@ class GQA(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@AttnFactory.register("mla")
|
||||
class MLA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -113,6 +123,7 @@ class MLA(nn.Module):
|
|||
norm_eps: float,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.inference.core.cache import KvcacheView
|
||||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.attention import AttnFactory
|
||||
from astrai.model.components.mlp import FFNFactory
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
|
||||
|
||||
|
|
@ -20,20 +20,24 @@ class DecoderBlock(nn.Module):
|
|||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
attn_type: str = "gqa",
|
||||
ffn_type: str = "mlp",
|
||||
**moe_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = GQA(
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
use_qk_norm,
|
||||
norm_eps,
|
||||
use_gated_attention,
|
||||
layer_id,
|
||||
self.attention = AttnFactory.create(
|
||||
attn_type,
|
||||
dim=dim,
|
||||
n_heads=n_heads,
|
||||
n_kv_heads=n_kv_heads,
|
||||
use_qk_norm=use_qk_norm,
|
||||
norm_eps=norm_eps,
|
||||
use_gated_attention=use_gated_attention,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = MLP(dim, dim_ffn)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,21 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
|
||||
class FFNFactory(BaseFactory[nn.Module]):
|
||||
@classmethod
|
||||
def create(cls, ffn_type: str, dim: int, dim_ffn: int, **kwargs) -> nn.Module:
|
||||
return super().create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
|
||||
@FFNFactory.register("mlp")
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, dim: int, dim_feed_forward: int):
|
||||
def __init__(self, dim: int, dim_feed_forward: int, **kwargs):
|
||||
super().__init__()
|
||||
self.up = Linear(dim, dim_feed_forward)
|
||||
self.gate = Linear(dim, dim_feed_forward)
|
||||
|
|
@ -16,3 +25,70 @@ class MLP(nn.Module):
|
|||
gated = self.up(x) * F.silu(self.gate(x))
|
||||
out = self.down(gated)
|
||||
return out
|
||||
|
||||
|
||||
@FFNFactory.register("moe")
|
||||
class DeepSeekMoE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
dim_feed_forward: int,
|
||||
n_routed_experts: int,
|
||||
n_shared_experts: int = 1,
|
||||
n_activated_experts: int = 2,
|
||||
topk_method: str = "greedy",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_activated_experts = n_activated_experts
|
||||
self.topk_method = topk_method
|
||||
|
||||
self.router = Linear(dim, n_routed_experts, bias=False)
|
||||
|
||||
self.shared_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_shared_experts)]
|
||||
)
|
||||
self.routed_experts = nn.ModuleList(
|
||||
[MLP(dim, dim_feed_forward) for _ in range(n_routed_experts)]
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
bsz, seq_len, dim = x.shape
|
||||
x_flat = x.view(-1, dim)
|
||||
|
||||
shared_out = self._shared_forward(x_flat)
|
||||
routed_out = self._routed_forward(x_flat)
|
||||
|
||||
out = (shared_out + routed_out).view(bsz, seq_len, dim)
|
||||
return out
|
||||
|
||||
def _shared_forward(self, x: Tensor) -> Tensor:
|
||||
if self.n_shared_experts == 0:
|
||||
return torch.zeros_like(x)
|
||||
return sum(e(x) for e in self.shared_experts) / self.n_shared_experts
|
||||
|
||||
def _routed_forward(self, x: Tensor) -> Tensor:
|
||||
N, D = x.shape
|
||||
K = self.n_activated_experts
|
||||
|
||||
router_logits = self.router(x)
|
||||
router_probs = torch.softmax(router_logits.float(), dim=-1).to(x.dtype)
|
||||
|
||||
topk_weights, topk_indices = torch.topk(router_probs, K, dim=-1)
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
output = torch.zeros(N, D, device=x.device, dtype=x.dtype)
|
||||
for expert_idx in range(self.n_routed_experts):
|
||||
expert_mask = topk_indices == expert_idx
|
||||
token_idx, k_idx = expert_mask.nonzero(as_tuple=True)
|
||||
if token_idx.numel() == 0:
|
||||
continue
|
||||
expert_input = x[token_idx]
|
||||
expert_output = self.routed_experts[expert_idx](expert_input)
|
||||
weights = topk_weights[token_idx, k_idx].unsqueeze(-1)
|
||||
output.index_add_(0, token_idx, expert_output * weights)
|
||||
|
||||
return output
|
||||
|
|
|
|||
|
|
@ -69,6 +69,12 @@ class Transformer(AutoModel):
|
|||
config.use_qk_norm,
|
||||
config.use_gated_attention,
|
||||
layer_id,
|
||||
attn_type=config.attn_type,
|
||||
ffn_type=config.ffn_type,
|
||||
n_routed_experts=config.n_routed_experts,
|
||||
n_shared_experts=config.n_shared_experts,
|
||||
n_activated_experts=config.n_activated_experts,
|
||||
topk_method=config.moe_topk_method,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue