AstrAI/tests/module/test_forward_configs.py

109 lines
2.9 KiB
Python

import pytest
import torch
from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
# Small base hyperparameters merged into every case in CONFIGS, keeping each
# forward pass cheap: 2 query heads over 1 KV head, 2 layers, a 64-token max
# length and a 128-token vocabulary.
TINY_CONFIG = {
    "vocab_size": 128,
    "dim": 8,
    "n_heads": 2,
    "n_kv_heads": 1,
    "dim_ffn": 16,
    "max_len": 64,
    "n_layers": 2,
    "norm_eps": 1e-5,
}
# One pytest case per model variant. Each entry is (test id, overrides);
# the overrides are layered on top of TINY_CONFIG, so only the keys that
# distinguish a variant need to be spelled out here.
CONFIGS = [
    pytest.param({**TINY_CONFIG, **overrides}, id=case_id)
    for case_id, overrides in (
        ("gqa_mlp", {"attn_type": "gqa", "ffn_type": "mlp"}),
        (
            "mla_mlp",
            {
                "attn_type": "mla",
                "ffn_type": "mlp",
                "kv_lora_rank": 4,
                "qk_nope_head_dim": 2,
                "qk_rope_head_dim": 2,
            },
        ),
        (
            "gqa_moe",
            {
                "attn_type": "gqa",
                "ffn_type": "moe",
                "n_routed_experts": 4,
                "n_shared_experts": 1,
                "n_activated_experts": 2,
                "moe_topk_method": "greedy",
            },
        ),
        (
            "gqa_rope_theta",
            {"attn_type": "gqa", "ffn_type": "mlp", "rope_theta": 100000.0},
        ),
        (
            "gqa_qk_norm",
            {"attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True},
        ),
        (
            "gqa_tie_weight",
            {"attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True},
        ),
    )
]
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward(config_kwargs):
    """Run an unmasked forward pass for one config and sanity-check the output.

    Builds a ``Transformer`` from ``config_kwargs``, feeds a random
    ``(batch_size, seq_len)`` batch of token ids, and asserts that the
    returned mapping has ``logits`` of shape ``(batch, seq, vocab_size)``
    and ``hidden_states`` of shape ``(batch, seq, dim)``, both fully finite.
    """
    # Seed before building the model so weight init and the random token
    # ids are identical on every run; otherwise a failure may not reproduce.
    torch.manual_seed(0)

    config = ModelConfig(**config_kwargs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )

    with torch.no_grad():
        output = model(input_ids)

    assert "logits" in output
    assert "hidden_states" in output
    logits = output["logits"]
    hidden_states = output["hidden_states"]
    assert logits.shape == (batch_size, seq_len, config.vocab_size)
    assert hidden_states.shape == (batch_size, seq_len, config.dim)
    # isfinite rejects NaN *and* +/-inf; an isnan-only check lets overflow pass.
    assert torch.isfinite(logits).all()
    assert torch.isfinite(hidden_states).all()
@pytest.mark.parametrize("config_kwargs", CONFIGS)
def test_model_forward_with_padding(config_kwargs):
    """Run a forward pass with a right-padding ``input_mask`` and check the output.

    Each row gets its own valid length (``True`` = real token,
    ``False`` = padding), so the mask differs across the batch. The test
    asserts the output shapes and that ``logits`` and ``hidden_states`` are
    fully finite.
    """
    # Seed so weight init and the random token ids are reproducible.
    torch.manual_seed(0)

    config = ModelConfig(**config_kwargs)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(config).to(device=device)
    model.eval()

    batch_size, seq_len = 2, 8
    input_ids = torch.randint(
        0, config.vocab_size, (batch_size, seq_len), device=device
    )

    # Per-row valid lengths (one per batch row). Positions at or past a row's
    # length are padding. Every row keeps a non-empty prefix of real tokens.
    # NOTE(review): that prefix presumably avoids fully-masked attention rows
    # (and hence NaNs) under causal attention; confirm against the model's
    # masking.
    valid_lengths = torch.tensor([6, 4], device=device)
    positions = torch.arange(seq_len, device=device)
    input_mask = positions.unsqueeze(0) < valid_lengths.unsqueeze(1)

    with torch.no_grad():
        output = model(input_ids, input_mask=input_mask)

    logits = output["logits"]
    hidden_states = output["hidden_states"]
    assert logits.shape == (batch_size, seq_len, config.vocab_size)
    assert hidden_states.shape == (batch_size, seq_len, config.dim)
    # isfinite rejects NaN *and* +/-inf; an isnan-only check lets overflow pass.
    assert torch.isfinite(logits).all()
    assert torch.isfinite(hidden_states).all()