import pytest import torch from astrai.config.model_config import ModelConfig from astrai.model.transformer import Transformer TINY_CONFIG = dict( vocab_size=128, dim=8, n_heads=2, n_kv_heads=1, dim_ffn=16, max_len=64, n_layers=2, norm_eps=1e-5, ) CONFIGS = [ pytest.param( {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp"}, id="gqa_mlp", ), pytest.param( { **TINY_CONFIG, "attn_type": "mla", "ffn_type": "mlp", "kv_lora_rank": 4, "qk_nope_head_dim": 2, "qk_rope_head_dim": 2, }, id="mla_mlp", ), pytest.param( { **TINY_CONFIG, "attn_type": "gqa", "ffn_type": "moe", "n_routed_experts": 4, "n_shared_experts": 1, "n_activated_experts": 2, "moe_topk_method": "greedy", }, id="gqa_moe", ), pytest.param( { **TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "rope_theta": 100000.0, }, id="gqa_rope_theta", ), pytest.param( {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "use_qk_norm": True}, id="gqa_qk_norm", ), pytest.param( {**TINY_CONFIG, "attn_type": "gqa", "ffn_type": "mlp", "tie_weight": True}, id="gqa_tie_weight", ), ] @pytest.mark.parametrize("config_kwargs", CONFIGS) def test_model_forward(config_kwargs): config = ModelConfig(**config_kwargs) device = "cuda" if torch.cuda.is_available() else "cpu" model = Transformer(config).to(device=device) model.eval() batch_size, seq_len = 2, 8 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seq_len), device=device ) with torch.no_grad(): output = model(input_ids) assert "logits" in output assert "hidden_states" in output assert output["logits"].shape == (batch_size, seq_len, config.vocab_size) assert output["hidden_states"].shape == (batch_size, seq_len, config.dim) assert not torch.isnan(output["logits"]).any() assert not torch.isnan(output["hidden_states"]).any() @pytest.mark.parametrize("config_kwargs", CONFIGS) def test_model_forward_with_padding(config_kwargs): config = ModelConfig(**config_kwargs) device = "cuda" if torch.cuda.is_available() else "cpu" model = Transformer(config).to(device=device) model.eval() batch_size, seq_len = 2, 8 input_ids = torch.randint( 0, config.vocab_size, (batch_size, seq_len), device=device ) input_mask = torch.ones(batch_size, seq_len, dtype=torch.bool, device=device) input_mask[:, 4:] = False with torch.no_grad(): output = model(input_ids, input_mask=input_mask) assert output["logits"].shape == (batch_size, seq_len, config.vocab_size) assert not torch.isnan(output["logits"]).any()