167 lines
4.7 KiB
Python
167 lines
4.7 KiB
Python
import torch
|
|
|
|
from astrai.config.model_config import EncoderConfig
|
|
from astrai.model.encoder import EmbeddingEncoder
|
|
|
|
# Small hyperparameter set shared by the tests in this module; each test
# expands it (optionally with extra overrides) into an ``EncoderConfig``.
TINY_CONFIG = {
    "vocab_size": 128,
    "dim": 8,
    "n_heads": 2,
    "n_kv_heads": 1,
    "dim_ffn": 16,
    "max_len": 64,
    "n_layers": 2,
    "norm_eps": 1e-5,
}
|
|
|
|
|
|
def test_encoder_forward_mean():
    """Forward random token ids through an encoder built from ``TINY_CONFIG``.

    No pooling override is passed here, so the config's own default pooling
    applies (presumably "mean", going by the test name -- confirm against
    ``EncoderConfig``). The output must have shape ``(batch, config.dim)``
    and contain no NaN values.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)

    encoder = EmbeddingEncoder(cfg).to(device=target)
    encoder.eval()

    n_rows, n_tokens = 2, 8
    token_ids = torch.randint(
        low=0, high=cfg.vocab_size, size=(n_rows, n_tokens), device=target
    )

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_rows, cfg.dim)
    assert not torch.isnan(pooled).any()
|
|
|
|
|
|
def test_encoder_forward_cls():
    """Forward random token ids through an encoder with ``pooling_type="cls"``.

    The output must have shape ``(batch, config.dim)`` and contain no NaN
    values.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    # Copy the shared settings so the module-level dict is never mutated.
    settings = dict(TINY_CONFIG)
    settings["pooling_type"] = "cls"
    cfg = EncoderConfig(**settings)

    encoder = EmbeddingEncoder(cfg).to(device=target)
    encoder.eval()

    n_rows, n_tokens = 2, 8
    token_ids = torch.randint(
        low=0, high=cfg.vocab_size, size=(n_rows, n_tokens), device=target
    )

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_rows, cfg.dim)
    assert not torch.isnan(pooled).any()
|
|
|
|
|
|
def test_encoder_forward_last():
    """Forward random token ids through an encoder with ``pooling_type="last"``.

    The output must have shape ``(batch, config.dim)`` and contain no NaN
    values.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    # Copy the shared settings so the module-level dict is never mutated.
    settings = dict(TINY_CONFIG)
    settings["pooling_type"] = "last"
    cfg = EncoderConfig(**settings)

    encoder = EmbeddingEncoder(cfg).to(device=target)
    encoder.eval()

    n_rows, n_tokens = 2, 8
    token_ids = torch.randint(
        low=0, high=cfg.vocab_size, size=(n_rows, n_tokens), device=target
    )

    with torch.no_grad():
        pooled = encoder(token_ids)

    assert pooled.shape == (n_rows, cfg.dim)
    assert not torch.isnan(pooled).any()
|
|
|
|
|
|
def test_encoder_forward_with_padding():
    """Forward a batch whose trailing positions are masked out.

    A boolean ``input_mask`` is passed alongside the token ids, with
    positions 4 and onward set to ``False`` in every row. The output must
    still have shape ``(batch, config.dim)`` and contain no NaN values.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)

    encoder = EmbeddingEncoder(cfg).to(device=target)
    encoder.eval()

    n_rows, n_tokens = 2, 8
    token_ids = torch.randint(
        low=0, high=cfg.vocab_size, size=(n_rows, n_tokens), device=target
    )

    # Start from an all-True mask, then switch off the second half of each row.
    keep = torch.ones(n_rows, n_tokens, dtype=torch.bool, device=target)
    keep[:, 4:] = False

    with torch.no_grad():
        pooled = encoder(token_ids, input_mask=keep)

    assert pooled.shape == (n_rows, cfg.dim)
    assert not torch.isnan(pooled).any()
|
|
|
|
|
|
def test_encoder_normalize():
    """Check that ``normalize_embeddings=True`` yields unit-length outputs.

    The encoder is configured with mean pooling and normalization enabled;
    the L2 norm of every output row must equal 1 within ``atol=1e-4``.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    # Copy the shared settings, then layer the two overrides on top.
    settings = dict(TINY_CONFIG)
    settings.update(pooling_type="mean", normalize_embeddings=True)
    cfg = EncoderConfig(**settings)

    encoder = EmbeddingEncoder(cfg).to(device=target)
    encoder.eval()

    n_rows, n_tokens = 2, 8
    token_ids = torch.randint(
        low=0, high=cfg.vocab_size, size=(n_rows, n_tokens), device=target
    )

    with torch.no_grad():
        pooled = encoder(token_ids)

    row_norms = pooled.norm(p=2, dim=-1)
    expected = torch.ones_like(row_norms)
    assert torch.allclose(row_norms, expected, atol=1e-4)
|
|
|
|
|
|
def test_encoder_register():
    """Check that ``AutoModel`` maps the name "embedding" to ``EmbeddingEncoder``."""
    from astrai.model.automodel import AutoModel

    assert AutoModel.is_registered("embedding")
    assert AutoModel.get_component_class("embedding") is EmbeddingEncoder
|
|
|
|
|
|
def test_encoder_from_transformer_checkpoint():
    """Load a state dict carrying an extra ``lm_head.weight`` entry.

    The source encoder's state dict is extended with a random
    ``lm_head.weight`` tensor and loaded into a fresh encoder with
    ``strict=True``. Every tensor in the source encoder's own state dict
    must then be equal in the freshly loaded encoder.

    NOTE(review): this presumably mimics a full transformer checkpoint whose
    language-model head the encoder is expected to tolerate -- confirm
    against ``EmbeddingEncoder.load_state_dict``.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"
    cfg = EncoderConfig(**TINY_CONFIG)
    source = EmbeddingEncoder(cfg).to(device=target)

    checkpoint = source.state_dict()
    checkpoint["lm_head.weight"] = torch.randn(
        cfg.vocab_size, cfg.dim, device=target
    )

    restored = EmbeddingEncoder(cfg).to(device=target)
    restored.load_state_dict(checkpoint, strict=True)

    # Compare against a fresh snapshot of the source model, not ``checkpoint``,
    # so the injected ``lm_head.weight`` key is not part of the comparison.
    expected_state = source.state_dict()
    restored_state = restored.state_dict()
    for name, tensor in expected_state.items():
        assert torch.equal(restored_state[name], tensor)
|
|
|
|
|
|
def test_encoder_save_load():
    """Round-trip an encoder's config and weights through files on disk.

    Steps:
      1. Write ``TINY_CONFIG`` plus ``pooling_type="mean"`` to ``config.json``
         and read it back with ``EncoderConfig.from_file``.
      2. Save a freshly built encoder's state dict to ``model.safetensors``.
      3. Load those weights into a second encoder built from the same config.
      4. Assert that every tensor in the original state dict is equal in the
         loaded encoder.

    All files live in a temporary directory that is removed when the test
    finishes, whether it passes or fails.
    """
    import json
    import os
    import tempfile

    import safetensors.torch as st

    # TemporaryDirectory deletes the whole tree on exit. The previous manual
    # listdir/remove/rmdir cleanup raised on any subdirectory, and an error
    # raised there would have replaced the test's own failure.
    with tempfile.TemporaryDirectory(prefix="encoder_test_") as test_dir:
        config_path = os.path.join(test_dir, "config.json")
        weights_path = os.path.join(test_dir, "model.safetensors")

        config_data = {**TINY_CONFIG, "pooling_type": "mean"}
        with open(config_path, "w", encoding="utf-8") as config_file:
            json.dump(config_data, config_file)

        config = EncoderConfig.from_file(config_path)
        original = EmbeddingEncoder(config)
        st.save_file(original.state_dict(), weights_path)

        loaded = EmbeddingEncoder(config)
        loaded.load_state_dict(st.load_file(weights_path))

        # Snapshot both state dicts once instead of rebuilding them per key.
        original_state = original.state_dict()
        loaded_state = loaded.state_dict()
        for key, tensor in original_state.items():
            assert torch.equal(tensor, loaded_state[key])
|