From 0ba8c70ce1c6fe44ea15042759696186071de50c Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 16 May 2026 14:56:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BF=AE=E5=A4=8D=20MLA=20=E5=A4=9A?= =?UTF-8?q?=E4=B8=AA=20bug=20=E5=B9=B6=E7=BC=A9=E5=B0=8F=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=8F=82=E6=95=B0=20-=20MLA=20kv=5Fb=5Fproj?= =?UTF-8?q?=20=E8=BE=93=E5=87=BA=E7=BB=B4=E5=BA=A6=E5=92=8C=20q=5Frope=20?= =?UTF-8?q?=E5=88=87=E5=88=86=E5=81=8F=E7=A7=BB=E4=BF=AE=E5=A4=8D=20-=20?= =?UTF-8?q?=E6=89=93=E9=80=9A=20MLA=20=E9=85=8D=E7=BD=AE=E4=BB=8E=20ModelC?= =?UTF-8?q?onfig=20=E5=88=B0=20DecoderBlock=20=E7=9A=84=E4=BC=A0=E9=80=92?= =?UTF-8?q?=E8=B7=AF=E5=BE=84=20-=20rope=5Ftheta=20=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E4=B8=8D=E5=86=8D=E8=A2=AB=E5=BF=BD=E7=95=A5=EF=BC=8CMLA=20?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=20qk=5Frope=5Fhead=5Fdim=20-=20tie=5Fweight?= =?UTF-8?q?=20=E4=BD=BF=E7=94=A8=20is=20True=20=E9=81=BF=E5=85=8D=20None?= =?UTF-8?q?=20=E9=9A=90=E5=BC=8F=E7=94=9F=E6=95=88=20-=20norm=5Feps/rope?= =?UTF-8?q?=20base=20=E7=B1=BB=E5=9E=8B=E6=A0=87=E6=B3=A8=E4=BF=AE?= =?UTF-8?q?=E6=AD=A3=20-=20=E6=B5=8B=E8=AF=95=E6=A8=A1=E5=9E=8B=E5=8F=82?= =?UTF-8?q?=E6=95=B0=E7=BC=A9=E5=B0=8F=20(dim=3D8,=20head=5Fdim=3D4)=20-?= =?UTF-8?q?=20=E6=96=B0=E5=A2=9E=206=20=E7=A7=8D=E6=9E=B6=E6=9E=84?= =?UTF-8?q?=E9=85=8D=E7=BD=AE=20=C3=97=202=20=E5=9C=BA=E6=99=AF=E7=9A=84?= =?UTF-8?q?=E5=89=8D=E5=90=91=E4=BC=A0=E6=92=AD=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/config/model_config.py | 5 ++ astrai/model/components/attention.py | 4 +- astrai/model/components/decoder_block.py | 7 +- astrai/model/components/rope.py | 2 +- astrai/model/transformer.py | 17 ++-- tests/conftest.py | 24 ++--- tests/module/test_forward_configs.py | 108 +++++++++++++++++++++++ tests/module/test_tie_weight.py | 8 +- 8 files changed, 148 insertions(+), 27 deletions(-) create mode 100644 tests/module/test_forward_configs.py diff --git a/astrai/config/model_config.py b/astrai/config/model_config.py index d707213..0e02428 100644 --- a/astrai/config/model_config.py +++ b/astrai/config/model_config.py @@ -106,6 +106,11 @@ class ModelConfig(BaseModelConfig): use_qk_norm: Optional[bool] = None use_gated_attention: Optional[bool] = None + # MLA + kv_lora_rank: Optional[int] = None + qk_nope_head_dim: Optional[int] = None + qk_rope_head_dim: Optional[int] = None + # MoE ffn_type: str = "mlp" n_routed_experts: Optional[int] = None diff --git a/astrai/model/components/attention.py b/astrai/model/components/attention.py index 0ff8268..2ad0ea5 100644 --- a/astrai/model/components/attention.py +++ b/astrai/model/components/attention.py @@ -143,7 +143,7 @@ class MLA(nn.Module): self.kv_b_proj = Linear( kv_lora_rank, - n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim), + n_kv_heads * (2 * self.head_dim), ) self.o_proj = Linear(dim, dim, bias=False) @@ -176,7 +176,7 @@ class MLA(nn.Module): q_nope, q_rope = ( q[..., : self.qk_nope_head_dim], - q[..., self.qk_rope_head_dim :], + q[..., self.qk_nope_head_dim :], ) q_rope = apply_rotary_emb(q_rope, rotary_emb) k_rope = apply_rotary_emb(k_rope, rotary_emb) diff --git a/astrai/model/components/decoder_block.py b/astrai/model/components/decoder_block.py index 63cf9b4..60b263f 100644 --- a/astrai/model/components/decoder_block.py +++ b/astrai/model/components/decoder_block.py @@ -16,13 +16,13 @@ class DecoderBlock(nn.Module): n_heads: int, dim_ffn: int, n_kv_heads: int, - norm_eps: int, + norm_eps: float, use_qk_norm: bool, use_gated_attention: bool, layer_id: int, attn_type: str = "gqa", ffn_type: str = "mlp", - **moe_kwargs, + **kwargs, ): super().__init__() self.attention = AttnFactory.create( @@ -34,10 +34,11 @@ class DecoderBlock(nn.Module): norm_eps=norm_eps, use_gated_attention=use_gated_attention, layer_id=layer_id, + **kwargs, ) self.input_norm = RMSNorm(dim, norm_eps) self.post_attention_norm = RMSNorm(dim, norm_eps) - self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs) + self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs) def forward( self, diff --git a/astrai/model/components/rope.py b/astrai/model/components/rope.py index 1df5d92..f7aaff7 100644 --- a/astrai/model/components/rope.py +++ b/astrai/model/components/rope.py @@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor: class RotaryEmbedding(nn.Module): - def __init__(self, dim: int, max_len: int, base: int = 10000): + def __init__(self, dim: int, max_len: int, base: float = 10000): super().__init__() self.dim = dim self.max_len = max_len diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 8fbd7ed..72d7b00 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -53,9 +53,13 @@ class Transformer(AutoModel): def __init__(self, config: ModelConfig): super().__init__(config) self.config = config - self.rotary_embedding = RotaryEmbedding( - config.dim // config.n_heads, config.max_len + rope_dim = ( + config.qk_rope_head_dim + if config.attn_type == "mla" + else config.dim // config.n_heads ) + rope_base = config.rope_theta if config.rope_theta is not None else 10000 + self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base) self.embed_tokens = Embedding(config.vocab_size, config.dim) self.layers = nn.ModuleList( @@ -75,6 +79,9 @@ class Transformer(AutoModel): n_shared_experts=config.n_shared_experts, n_activated_experts=config.n_activated_experts, topk_method=config.moe_topk_method, + kv_lora_rank=config.kv_lora_rank, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, ) for layer_id in range(config.n_layers) ] @@ -83,7 +90,7 @@ class Transformer(AutoModel): self.norm = RMSNorm(config.dim, config.norm_eps) self.lm_head = Linear(config.dim, config.vocab_size) - if self.config.tie_weight: + if self.config.tie_weight is True: self.lm_head.weight = self.embed_tokens.weight self._init_weights() @@ -99,7 +106,7 @@ class Transformer(AutoModel): state_dict = dict(state_dict) - if self.config.tie_weight: + if self.config.tie_weight is True: # same tensor for embed and lm_head if embed_key in state_dict: state_dict[lm_head_key] = state_dict[embed_key] @@ -115,7 +122,7 @@ class Transformer(AutoModel): destination=destination, prefix=prefix, keep_vars=keep_vars ) - if self.config.tie_weight: + if self.config.tie_weight is True: lm_head_key = prefix + "lm_head.weight" if lm_head_key in state_dict: del state_dict[lm_head_key] diff --git a/tests/conftest.py b/tests/conftest.py index 3f17c12..5d0149c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -107,12 +107,12 @@ def test_model(): """Session-scoped small Transformer model, created once.""" config = ModelConfig( vocab_size=1000, - dim=16, - n_heads=4, - n_kv_heads=2, - dim_ffn=32, - max_len=1024, - n_layers=4, + dim=8, + n_heads=2, + n_kv_heads=1, + dim_ffn=16, + max_len=64, + n_layers=2, norm_eps=1e-5, ) device = "cuda" if torch.cuda.is_available() else "cpu" @@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer): json.dump( { "vocab_size": 1000, - "dim": 16, - "n_heads": 4, - "n_kv_heads": 2, - "dim_ffn": 32, - "max_len": 1024, - "n_layers": 4, + "dim": 8, + "n_heads": 2, + "n_kv_heads": 1, + "dim_ffn": 16, + "max_len": 64, + "n_layers": 2, "norm_eps": 1e-5, }, f, diff --git a/tests/module/test_forward_configs.py b/tests/module/test_forward_configs.py new file mode 100644 index 0000000..fa8fdd2 --- /dev/null +++ b/tests/module/test_forward_configs.py @@ -0,0 +1,108 @@ +import pytest +import torch + +from astrai.config.model_config import ModelConfig +from astrai.model.transformer import Transformer + +TINY_CONFIG = dict( + vocab_size=128, + dim=8, + n_heads=2, + n_kv_heads=1, + dim_ffn=16, + max_len=64, + n_layers=2, + norm_eps=1e-5, +) + + +CONFIGS = [ + pytest.param( + {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"}, + id="gqa_mlp", + ), + pytest.param( + { + **TINY_CONFIG, + "attn_type": "mla", + "ffn_type": "mlp", + "kv_lora_rank": 4, + "qk_nope_head_dim": 2, + "qk_rope_head_dim": 2, + }, + id="mla_mlp", + ), + pytest.param( + { + **TINY_CONFIG, + "attn_type": "gqa", + "ffn_type": "moe", + "n_routed_experts": 4, + "n_shared_experts": 1, + "n_activated_experts": 2, + "moe_topk_method": "greedy", + }, + id="gqa_moe", + ), + pytest.param( + { + **TINY_CONFIG, + "attn_type": "gqa", + "ffn_type": "mlp", + "rope_theta": 100000.0, + }, + id="gqa_rope_theta", + ), + pytest.param( + {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True}, + id="gqa_qk_norm", + ), + pytest.param( + {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True}, + id="gqa_tie_weight", + ), +] + + +@pytest.mark.parametrize("config_kwargs", CONFIGS) +def test_model_forward(config_kwargs): + config = ModelConfig(**config_kwargs) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = Transformer(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + + with torch.no_grad(): + output = model(input_ids) + + assert "logits" in output + assert "hidden_states" in output + assert output["logits"].shape == (batch_size, seq_len, config.vocab_size) + assert output["hidden_states"].shape == (batch_size, seq_len, config.dim) + assert not torch.isnan(output["logits"]).any() + assert not torch.isnan(output["hidden_states"]).any() + + +@pytest.mark.parametrize("config_kwargs", CONFIGS) +def test_model_forward_with_padding(config_kwargs): + config = ModelConfig(**config_kwargs) + device = "cuda" if torch.cuda.is_available() else "cpu" + model = Transformer(config).to(device=device) + model.eval() + + batch_size, seq_len = 2, 8 + input_ids = torch.randint( + 0, config.vocab_size, (batch_size, seq_len), device=device + ) + input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) + input_mask[:, 4:] = False + + with torch.no_grad(): + output = model(input_ids, input_mask=input_mask) + + assert output["logits"].shape == (batch_size, seq_len, config.vocab_size) + assert not torch.isnan(output["logits"]).any() diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index e15b4ac..73b763b 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -17,10 +17,10 @@ def transformer_test_env(): config = { "vocab_size": 1000, - "dim": 128, - "n_heads": 4, - "n_kv_heads": 2, - "dim_ffn": 256, + "dim": 8, + "n_heads": 2, + "n_kv_heads": 1, + "dim_ffn": 16, "max_len": 64, "n_layers": 2, "norm_eps": 1e-5,