fix: 修复 MLA 多个 bug 并缩小测试模型参数
- MLA kv_b_proj 输出维度和 q_rope 切分偏移修复
- 打通 MLA 配置从 ModelConfig 到 DecoderBlock 的传递路径
- rope_theta 配置不再被忽略,MLA 使用 qk_rope_head_dim
- tie_weight 使用 is True 避免 None 隐式生效
- norm_eps/rope base 类型标注修正
- 测试模型参数缩小 (dim=8, head_dim=4)
- 新增 6 种架构配置 × 2 场景的前向传播测试
This commit is contained in:
parent
3d12a03909
commit
0ba8c70ce1
|
|
@ -106,6 +106,11 @@ class ModelConfig(BaseModelConfig):
|
|||
use_qk_norm: Optional[bool] = None
|
||||
use_gated_attention: Optional[bool] = None
|
||||
|
||||
# MLA
|
||||
kv_lora_rank: Optional[int] = None
|
||||
qk_nope_head_dim: Optional[int] = None
|
||||
qk_rope_head_dim: Optional[int] = None
|
||||
|
||||
# MoE
|
||||
ffn_type: str = "mlp"
|
||||
n_routed_experts: Optional[int] = None
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class MLA(nn.Module):
|
|||
|
||||
self.kv_b_proj = Linear(
|
||||
kv_lora_rank,
|
||||
n_kv_heads * (self.head_dim + qk_rope_head_dim + self.head_dim),
|
||||
n_kv_heads * (2 * self.head_dim),
|
||||
)
|
||||
|
||||
self.o_proj = Linear(dim, dim, bias=False)
|
||||
|
|
@ -176,7 +176,7 @@ class MLA(nn.Module):
|
|||
|
||||
q_nope, q_rope = (
|
||||
q[..., : self.qk_nope_head_dim],
|
||||
q[..., self.qk_rope_head_dim :],
|
||||
q[..., self.qk_nope_head_dim :],
|
||||
)
|
||||
q_rope = apply_rotary_emb(q_rope, rotary_emb)
|
||||
k_rope = apply_rotary_emb(k_rope, rotary_emb)
|
||||
|
|
|
|||
|
|
@ -16,13 +16,13 @@ class DecoderBlock(nn.Module):
|
|||
n_heads: int,
|
||||
dim_ffn: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: int,
|
||||
norm_eps: float,
|
||||
use_qk_norm: bool,
|
||||
use_gated_attention: bool,
|
||||
layer_id: int,
|
||||
attn_type: str = "gqa",
|
||||
ffn_type: str = "mlp",
|
||||
**moe_kwargs,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = AttnFactory.create(
|
||||
|
|
@ -34,10 +34,11 @@ class DecoderBlock(nn.Module):
|
|||
norm_eps=norm_eps,
|
||||
use_gated_attention=use_gated_attention,
|
||||
layer_id=layer_id,
|
||||
**kwargs,
|
||||
)
|
||||
self.input_norm = RMSNorm(dim, norm_eps)
|
||||
self.post_attention_norm = RMSNorm(dim, norm_eps)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **moe_kwargs)
|
||||
self.mlp = FFNFactory.create(ffn_type, dim, dim_ffn, **kwargs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: Tensor) -> Tensor:
|
|||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim: int, max_len: int, base: int = 10000):
|
||||
def __init__(self, dim: int, max_len: int, base: float = 10000):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_len = max_len
|
||||
|
|
|
|||
|
|
@ -53,9 +53,13 @@ class Transformer(AutoModel):
|
|||
def __init__(self, config: ModelConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.rotary_embedding = RotaryEmbedding(
|
||||
config.dim // config.n_heads, config.max_len
|
||||
rope_dim = (
|
||||
config.qk_rope_head_dim
|
||||
if config.attn_type == "mla"
|
||||
else config.dim // config.n_heads
|
||||
)
|
||||
rope_base = config.rope_theta if config.rope_theta is not None else 10000
|
||||
self.rotary_embedding = RotaryEmbedding(rope_dim, config.max_len, rope_base)
|
||||
self.embed_tokens = Embedding(config.vocab_size, config.dim)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
|
|
@ -75,6 +79,9 @@ class Transformer(AutoModel):
|
|||
n_shared_experts=config.n_shared_experts,
|
||||
n_activated_experts=config.n_activated_experts,
|
||||
topk_method=config.moe_topk_method,
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
qk_nope_head_dim=config.qk_nope_head_dim,
|
||||
qk_rope_head_dim=config.qk_rope_head_dim,
|
||||
)
|
||||
for layer_id in range(config.n_layers)
|
||||
]
|
||||
|
|
@ -83,7 +90,7 @@ class Transformer(AutoModel):
|
|||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self._init_weights()
|
||||
|
|
@ -99,7 +106,7 @@ class Transformer(AutoModel):
|
|||
|
||||
state_dict = dict(state_dict)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
# same tensor for embed and lm_head
|
||||
if embed_key in state_dict:
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
|
|
@ -115,7 +122,7 @@ class Transformer(AutoModel):
|
|||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
if self.config.tie_weight:
|
||||
if self.config.tie_weight is True:
|
||||
lm_head_key = prefix + "lm_head.weight"
|
||||
if lm_head_key in state_dict:
|
||||
del state_dict[lm_head_key]
|
||||
|
|
|
|||
|
|
@ -107,12 +107,12 @@ def test_model():
|
|||
"""Session-scoped small Transformer model, created once."""
|
||||
config = ModelConfig(
|
||||
vocab_size=1000,
|
||||
dim=16,
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
dim_ffn=32,
|
||||
max_len=1024,
|
||||
n_layers=4,
|
||||
dim=8,
|
||||
n_heads=2,
|
||||
n_kv_heads=1,
|
||||
dim_ffn=16,
|
||||
max_len=64,
|
||||
n_layers=2,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
|
@ -137,12 +137,12 @@ def base_test_env(test_model, test_tokenizer):
|
|||
json.dump(
|
||||
{
|
||||
"vocab_size": 1000,
|
||||
"dim": 16,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 32,
|
||||
"max_len": 1024,
|
||||
"n_layers": 4,
|
||||
"dim": 8,
|
||||
"n_heads": 2,
|
||||
"n_kv_heads": 1,
|
||||
"dim_ffn": 16,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5,
|
||||
},
|
||||
f,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,108 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.model.transformer import Transformer
|
||||
|
||||
# Smallest shared ModelConfig keyword set used by every parametrized case
# below; individual cases extend it with attention / FFN specific options.
# With dim=8 and n_heads=2 the per-head width (dim // n_heads) is 4.
TINY_CONFIG = dict(
    vocab_size=128,
    dim=8,
    n_heads=2,
    n_kv_heads=1,
    dim_ffn=16,
    max_len=64,
    n_layers=2,
    norm_eps=1e-5,
)
|
||||
|
||||
|
||||
# One pytest.param per architecture variant. Each entry is the keyword dict
# passed to ModelConfig; the id is what pytest shows in the test name.
CONFIGS = [
    # Baseline: grouped-query attention with a dense MLP.
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
        id="gqa_mlp",
    ),
    # MLA attention; presumably qk_nope_head_dim + qk_rope_head_dim must
    # add up to the per-head width (2 + 2 = 4 here) -- TODO confirm.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "mla",
            "ffn_type": "mlp",
            "kv_lora_rank": 4,
            "qk_nope_head_dim": 2,
            "qk_rope_head_dim": 2,
        },
        id="mla_mlp",
    ),
    # Mixture-of-experts feed-forward.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "gqa",
            "ffn_type": "moe",
            "n_routed_experts": 4,
            "n_shared_experts": 1,
            "n_activated_experts": 2,
            "moe_topk_method": "greedy",
        },
        id="gqa_moe",
    ),
    # Explicit rope_theta instead of the default base.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "gqa",
            "ffn_type": "mlp",
            "rope_theta": 100000.0,
        },
        id="gqa_rope_theta",
    ),
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
        id="gqa_qk_norm",
    ),
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
        id="gqa_tie_weight",
    ),
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
    """Run one unmasked forward pass and check output keys, shapes and values.

    Builds a Transformer from ``config_kwargs``, feeds it random token ids
    and asserts that ``logits`` and ``hidden_states`` are present, have the
    expected shapes, and contain only finite numbers.
    """
    config = ModelConfig(**config_kwargs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )

    with torch.no_grad():
        output = model(input_ids)

    assert "logits" in output
    assert "hidden_states" in output
    assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
    assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
    # isfinite rejects +/-inf as well as NaN; a NaN-only check would let an
    # overflowing forward pass slip through.
    assert torch.isfinite(output["logits"]).all()
    assert torch.isfinite(output["hidden_states"]).all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
    """Run a forward pass with a boolean input mask and check the logits.

    The mask is True for the first four positions of every sequence and
    False for the rest; the logits must keep the full
    ``(batch, seq, vocab)`` shape and contain only finite numbers.
    """
    config = ModelConfig(**config_kwargs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )
    input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    input_mask[:, 4:] = False  # second half of each sequence is masked out

    with torch.no_grad():
        output = model(input_ids, input_mask=input_mask)

    assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
    # isfinite rejects +/-inf as well as NaN, which matters for masked
    # attention where large negative fill values are common.
    assert torch.isfinite(output["logits"]).all()
|
||||
|
|
@ -17,10 +17,10 @@ def transformer_test_env():
|
|||
|
||||
config = {
|
||||
"vocab_size": 1000,
|
||||
"dim": 128,
|
||||
"n_heads": 4,
|
||||
"n_kv_heads": 2,
|
||||
"dim_ffn": 256,
|
||||
"dim": 8,
|
||||
"n_heads": 2,
|
||||
"n_kv_heads": 1,
|
||||
"dim_ffn": 16,
|
||||
"max_len": 64,
|
||||
"n_layers": 2,
|
||||
"norm_eps": 1e-5,
|
||||
|
|
|
|||
Loading…
Reference in New Issue