fix: 修复 MLA 多个 bug 并缩小测试模型参数
- MLA kv_b_proj 输出维度和 q_rope 切分偏移修复 - 打通 MLA 配置从 ModelConfig 到 DecoderBlock 的传递路径 - rope_theta 配置不再被忽略,MLA 使用 qk_rope_head_dim - tie_weight 使用 is True 避免 None 隐式生效 - norm_eps/rope base 类型标注修正 - 测试模型参数缩小 (dim=8, head_dim=4) - 新增 6 种架构配置 × 2 场景的前向传播测试
This commit is contained in:
parent
3d12a03909
commit
0ba8c70ce1
|
|
@ -106,6 +106,11 @@ class ModelConfig(BaseModelConfig):
|
||||||
use_qk_norm: Optional[bool] = None
|
use_qk_norm: Optional[bool] = None
|
||||||
use_gated_attention: Optional[bool] = None
|
use_gated_attention: Optional[bool] = None
|
||||||
|
|
||||||
|
# MLA
|
||||||
|
kv_lora_rank: Optional[int] = None
|
||||||
|
qk_nope_head_dim: Optional[int] = None
|
||||||
|
qk_rope_head_dim: Optional[int] = None
|
||||||
|
|
||||||
# MoE
|
# MoE
|
||||||
ffn_type: str = "mlp"
|
ffn_type: str = "mlp"
|
||||||
n_routed_experts: Optional[int] = None
|
n_routed_experts: Optional[int] = None
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
kv_lora_rank,
|
kv_lora_rank,
|
||||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
n_kv_heads * (2 * self.head_dim),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.o_proj = Linear(dim, dim, bias=False)
|
self.o_proj = Linear(dim, dim, bias=False)
|
||||||
|
|
@ -176,7 +176,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
q_nope, q_rope = (
|
q_nope, q_rope = (
|
||||||
q[..., : self.qk_nope_head_dim],
|
q[..., : self.qk_nope_head_dim],
|
||||||
q[..., self.qk_rope_head_dim :],
|
q[..., self.qk_nope_head_dim :],
|
||||||
)
|
)
|
||||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||||
|
|
|
||||||
|
|
@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
|
||||||
n_heads: int,
|
n_heads: int,
|
||||||
dim_ffn: int,
|
dim_ffn: int,
|
||||||
n_kv_heads: int,
|
n_kv_heads: int,
|
||||||
norm_eps: int,
|
norm_eps: float,
|
||||||
use_qk_norm: bool,
|
use_qk_norm: bool,
|
||||||
use_gated_attention: bool,
|
use_gated_attention: bool,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
attn_type: str = "gqa",
|
attn_type: str = "gqa",
|
||||||
ffn_type: str = "mlp",
|
ffn_type: str = "mlp",
|
||||||
**moe_kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attention = AttnFactory.create(
|
self.attention = AttnFactory.create(
|
||||||
|
|
@ -34,10 +34,11 @@ class DecoderBlock(nn.Module):
|
||||||
norm_eps=norm_eps,
|
norm_eps=norm_eps,
|
||||||
use_gated_attention=use_gated_attention,
|
use_gated_attention=use_gated_attention,
|
||||||
layer_id=layer_id,
|
layer_id=layer_id,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
self.input_norm = RMSNorm(dim, norm_eps)
|
self.input_norm = RMSNorm(dim, norm_eps)
|
||||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
|
|
||||||
class RotaryEmbedding(nn.Module):
|
class RotaryEmbedding(nn.Module):
|
||||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.max_len = max_len
|
self.max_len = max_len
|
||||||
|
|
|
||||||
|
|
@ -53,9 +53,13 @@ class Transformer(AutoModel):
|
||||||
def __init__(self, config: ModelConfig):
|
def __init__(self, config: ModelConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
self.rotary_embedding = RotaryEmbedding(
|
rope_dim = (
|
||||||
config.dim // config.n_heads, config.max_len
|
config.qk_rope_head_dim
|
||||||
|
if config.attn_type == "mla"
|
||||||
|
else config.dim // config.n_heads
|
||||||
)
|
)
|
||||||
|
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||||
|
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||||
|
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
|
|
@ -75,6 +79,9 @@ class Transformer(AutoModel):
|
||||||
n_shared_experts=config.n_shared_experts,
|
n_shared_experts=config.n_shared_experts,
|
||||||
n_activated_experts=config.n_activated_experts,
|
n_activated_experts=config.n_activated_experts,
|
||||||
topk_method=config.moe_topk_method,
|
topk_method=config.moe_topk_method,
|
||||||
|
kv_lora_rank=config.kv_lora_rank,
|
||||||
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||||
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||||
)
|
)
|
||||||
for layer_id in range(config.n_layers)
|
for layer_id in range(config.n_layers)
|
||||||
]
|
]
|
||||||
|
|
@ -83,7 +90,7 @@ class Transformer(AutoModel):
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_weights()
|
self._init_weights()
|
||||||
|
|
@ -99,7 +106,7 @@ class Transformer(AutoModel):
|
||||||
|
|
||||||
state_dict = dict(state_dict)
|
state_dict = dict(state_dict)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
# same tensor for embed and lm_head
|
# same tensor for embed and lm_head
|
||||||
if embed_key in state_dict:
|
if embed_key in state_dict:
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
|
|
@ -115,7 +122,7 @@ class Transformer(AutoModel):
|
||||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.tie_weight:
|
if self.config.tie_weight is True:
|
||||||
lm_head_key = prefix + "lm_head.weight"
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
|
||||||
|
|
@ -107,12 +107,12 @@ def test_model():
|
||||||
"""Session-scoped small Transformer model, created once."""
|
"""Session-scoped small Transformer model, created once."""
|
||||||
config = ModelConfig(
|
config = ModelConfig(
|
||||||
vocab_size=1000,
|
vocab_size=1000,
|
||||||
dim=16,
|
dim=8,
|
||||||
n_heads=4,
|
n_heads=2,
|
||||||
n_kv_heads=2,
|
n_kv_heads=1,
|
||||||
dim_ffn=32,
|
dim_ffn=16,
|
||||||
max_len=1024,
|
max_len=64,
|
||||||
n_layers=4,
|
n_layers=2,
|
||||||
norm_eps=1e-5,
|
norm_eps=1e-5,
|
||||||
)
|
)
|
||||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
|
||||||
json.dump(
|
json.dump(
|
||||||
{
|
{
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 16,
|
"dim": 8,
|
||||||
"n_heads": 4,
|
"n_heads": 2,
|
||||||
"n_kv_heads": 2,
|
"n_kv_heads": 1,
|
||||||
"dim_ffn": 32,
|
"dim_ffn": 16,
|
||||||
"max_len": 1024,
|
"max_len": 64,
|
||||||
"n_layers": 4,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
},
|
},
|
||||||
f,
|
f,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import ModelConfig
|
||||||
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
TINY_CONFIG = dict(
|
||||||
|
vocab_size=128,
|
||||||
|
dim=8,
|
||||||
|
n_heads=2,
|
||||||
|
n_kv_heads=1,
|
||||||
|
dim_ffn=16,
|
||||||
|
max_len=64,
|
||||||
|
n_layers=2,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
CONFIGS = [
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
|
||||||
|
id="gqa_mlp",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "mla",
|
||||||
|
"ffn_type": "mlp",
|
||||||
|
"kv_lora_rank": 4,
|
||||||
|
"qk_nope_head_dim": 2,
|
||||||
|
"qk_rope_head_dim": 2,
|
||||||
|
},
|
||||||
|
id="mla_mlp",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "gqa",
|
||||||
|
"ffn_type": "moe",
|
||||||
|
"n_routed_experts": 4,
|
||||||
|
"n_shared_experts": 1,
|
||||||
|
"n_activated_experts": 2,
|
||||||
|
"moe_topk_method": "greedy",
|
||||||
|
},
|
||||||
|
id="gqa_moe",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{
|
||||||
|
**TINY_CONFIG,
|
||||||
|
"attn_type": "gqa",
|
||||||
|
"ffn_type": "mlp",
|
||||||
|
"rope_theta": 100000.0,
|
||||||
|
},
|
||||||
|
id="gqa_rope_theta",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
|
||||||
|
id="gqa_qk_norm",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
{**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
|
||||||
|
id="gqa_tie_weight",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
|
def test_model_forward(config_kwargs):
|
||||||
|
config = ModelConfig(**config_kwargs)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = Transformer(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids)
|
||||||
|
|
||||||
|
assert "logits" in output
|
||||||
|
assert "hidden_states" in output
|
||||||
|
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||||
|
assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
|
||||||
|
assert not torch.isnan(output["logits"]).any()
|
||||||
|
assert not torch.isnan(output["hidden_states"]).any()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
|
||||||
|
def test_model_forward_with_padding(config_kwargs):
|
||||||
|
config = ModelConfig(**config_kwargs)
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model = Transformer(config).to(device=device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
batch_size, seq_len = 2, 8
|
||||||
|
input_ids = torch.randint(
|
||||||
|
0, config.vocab_size, (batch_size, seq_len), device=device
|
||||||
|
)
|
||||||
|
input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
|
||||||
|
input_mask[:, 4:] = False
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = model(input_ids, input_mask=input_mask)
|
||||||
|
|
||||||
|
assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
|
||||||
|
assert not torch.isnan(output["logits"]).any()
|
||||||
|
|
@ -17,10 +17,10 @@ def transformer_test_env():
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"vocab_size": 1000,
|
"vocab_size": 1000,
|
||||||
"dim": 128,
|
"dim": 8,
|
||||||
"n_heads": 4,
|
"n_heads": 2,
|
||||||
"n_kv_heads": 2,
|
"n_kv_heads": 1,
|
||||||
"dim_ffn": 256,
|
"dim_ffn": 16,
|
||||||
"max_len": 64,
|
"max_len": 64,
|
||||||
"n_layers": 2,
|
"n_layers": 2,
|
||||||
"norm_eps": 1e-5,
|
"norm_eps": 1e-5,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue