109 lines
2.9 KiB
Python
109 lines
2.9 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from astrai.config.model_config import ModelConfig
|
|
from astrai.model.transformer import Transformer
|
|
|
|
# Base keyword arguments for ModelConfig, kept deliberately tiny so each
# parametrized case builds and runs quickly. The per-case dictionaries in
# CONFIGS spread these values and add their own options on top.
TINY_CONFIG = dict(
    vocab_size=128,
    dim=8,
    n_heads=2,
    n_kv_heads=1,
    dim_ffn=16,
    max_len=64,
    n_layers=2,
    norm_eps=1e-5,
)
|
|
|
|
|
|
# Parametrization table shared by the tests in this module. Each entry is a
# full set of ModelConfig keyword arguments: TINY_CONFIG plus the options that
# distinguish the case. The `id` is the suffix pytest shows in the test name.
CONFIGS = [
    # attn_type "gqa" with ffn_type "mlp" and no extra options.
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"},
        id="gqa_mlp",
    ),
    # attn_type "mla" with its rank / head-dim options set explicitly.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "mla",
            "ffn_type": "mlp",
            "kv_lora_rank": 4,
            "qk_nope_head_dim": 2,
            "qk_rope_head_dim": 2,
        },
        id="mla_mlp",
    ),
    # ffn_type "moe" with 4 routed, 1 shared and 2 activated experts.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "gqa",
            "ffn_type": "moe",
            "n_routed_experts": 4,
            "n_shared_experts": 1,
            "n_activated_experts": 2,
            "moe_topk_method": "greedy",
        },
        id="gqa_moe",
    ),
    # Same as gqa_mlp but with an explicit rope_theta value.
    pytest.param(
        {
            **TINY_CONFIG,
            "attn_type": "gqa",
            "ffn_type": "mlp",
            "rope_theta": 100000.0,
        },
        id="gqa_rope_theta",
    ),
    # Same as gqa_mlp but with use_qk_norm switched on.
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
        id="gqa_qk_norm",
    ),
    # Same as gqa_mlp but with tie_weight switched on.
    pytest.param(
        {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
        id="gqa_tie_weight",
    ),
]
|
|
|
|
|
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
    """Run one forward pass and check the output keys, shapes and NaN-freeness.

    Args:
        config_kwargs: Keyword arguments for ``ModelConfig``, supplied by the
            ``CONFIGS`` parametrization.
    """
    config = ModelConfig(**config_kwargs)
    # Use the GPU when one is present; otherwise the test runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    # Random token ids drawn from [0, vocab_size).
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )

    with torch.no_grad():
        output = model(input_ids)

    assert "logits" in output
    assert "hidden_states" in output
    assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
    assert output["hidden_states"].shape == (batch_size, seq_len, config.dim)
    assert not torch.isnan(output["logits"]).any()
    assert not torch.isnan(output["hidden_states"]).any()
|
|
|
|
|
|
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
    """Run one forward pass with an ``input_mask`` and check the logits.

    The mask is True for positions 0-3 and False for positions 4-7 of every
    sequence. The test asserts only the logits shape and the absence of NaNs.

    Args:
        config_kwargs: Keyword arguments for ``ModelConfig``, supplied by the
            ``CONFIGS`` parametrization.
    """
    config = ModelConfig(**config_kwargs)
    # Use the GPU when one is present; otherwise the test runs on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    # Random token ids drawn from [0, vocab_size).
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )
    input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device)
    # NOTE(review): presumably False marks padded positions, going by the test
    # name -- confirm against Transformer.forward.
    input_mask[:, 4:] = False

    with torch.no_grad():
        output = model(input_ids, input_mask=input_mask)

    assert output["logits"].shape == (batch_size, seq_len, config.vocab_size)
    assert not torch.isnan(output["logits"]).any()
|