AstrAI/tests/module/test_lora.py

356 lines
9.3 KiB
Python

import tempfile
import pytest
import torch
from astrai.config.model_config import AutoRegressiveLMConfig
from astrai.model import AutoRegressiveLM
from astrai.model.components.lora import (
LoRAConfig,
LoRALinear,
_get_lora_count,
_collect_lora_info,
inject_lora,
load_lora,
merge_lora,
save_lora,
)
from astrai.model.components.linear import Linear
MODEL_KWARGS = dict(
vocab_size=1000,
dim=64,
n_heads=4,
n_kv_heads=2,
dim_ffn=128,
n_layers=2,
max_len=32,
norm_eps=1e-5,
)
def _make_model(**kwargs):
kw = {**MODEL_KWARGS, **kwargs}
config = AutoRegressiveLMConfig(**kw)
model = AutoRegressiveLM(config)
model.eval()
return model
def test_loralinear_init():
base = Linear(64, 128)
lora = LoRALinear(base, r=8, alpha=16)
assert lora.weight is base.weight
assert not lora.weight.requires_grad
assert lora.lora_A.shape == (8, 64)
assert lora.lora_B.shape == (128, 8)
assert lora.scaling == 2.0
assert not lora._merged
assert lora.lora_A.requires_grad
assert lora.lora_B.requires_grad
def test_loralinear_forward_init_zero_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
lora_out = lora(x)
assert torch.allclose(base_out, lora_out)
def test_loralinear_forward_with_delta():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
base_out = base(x)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora_out = lora(x)
assert not torch.allclose(base_out, lora_out)
def test_loralinear_merge():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
x = torch.randn(2, 4)
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
out_before = lora(x).clone()
lora.merge()
out_after = lora(x)
torch.testing.assert_close(out_before, out_after)
assert lora._merged
assert not hasattr(lora, "lora_A")
def test_loralinear_merge_is_idempotent():
base = Linear(4, 4)
with torch.no_grad():
base.weight.zero_()
lora = LoRALinear(base, r=2, alpha=2)
with torch.no_grad():
lora.lora_B.fill_(1.0)
lora.merge()
lora.merge()
def test_inject_lora_default_target():
model = _make_model()
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
inject_lora(model, r=4, alpha=8)
lora_count = _get_lora_count(model)
assert lora_count > 0
assert lora_count < n_before
def test_inject_lora_ffn():
model = _make_model()
from astrai.model.components.lora import TARGET_MODULES_FFN
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
assert _get_lora_count(model) > 0
def test_inject_lora_returns_config():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=32)
assert isinstance(cfg, LoRAConfig)
assert cfg.r == 8
assert cfg.alpha == 32
def test_inject_lora_no_matching_targets_warns(caplog):
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
assert "No LoRA layers injected" in caplog.text
def test_inject_lora_preserves_base_output():
model = _make_model()
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
inject_lora(model, r=4, alpha=8)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_inject_lora_does_not_reinject():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first_count = _get_lora_count(model)
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
assert _get_lora_count(model) == first_count
def test_inject_lora_adds_new_modules():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
first = _get_lora_count(model)
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
assert _get_lora_count(model) > first
def test_inject_lora_on_mla_model():
model = _make_model(
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
)
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
assert _get_lora_count(model) > 0
def test_inject_lora_on_moe_model():
model = _make_model(
ffn_type="moe",
n_routed_experts=4,
n_shared_experts=1,
n_activated_experts=2,
dim_ffn=32,
)
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
assert _get_lora_count(model) > 0
def test_state_dict_key_format():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd = model.state_dict()
assert "layers.0.attention.q_proj.weight" in sd
assert "layers.0.attention.q_proj.lora_A" in sd
assert "layers.0.attention.q_proj.lora_B" in sd
def test_only_lora_params_trainable():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
for name, param in model.named_parameters():
if isinstance(name.split(".")[-1], str) and "lora" in name:
assert param.requires_grad, f"lora param should be trainable: {name}"
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
assert not param.requires_grad, f"injected weight should be frozen: {name}"
def test_state_dict_after_inject_consistent_with_original():
model = _make_model()
sd_before = {k: v for k, v in model.state_dict().items()}
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
sd_after = model.state_dict()
# original keys unchanged
for k in sd_before:
assert k in sd_after
assert sd_before[k].shape == sd_after[k].shape
# new lora keys present
lora_keys = [k for k in sd_after if "lora" in k]
assert len(lora_keys) > 0
def test_save_load_roundtrip():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_src = model(x)["logits"].clone()
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
load_lora(model2, tmpdir)
with torch.no_grad():
out_dst = model2(x)["logits"]
torch.testing.assert_close(out_src, out_dst)
def test_save_after_merge_raises():
model = _make_model()
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
merge_lora(model)
tmpdir2 = tempfile.mkdtemp()
with pytest.raises(RuntimeError, match="No LoRA parameters"):
save_lora(model, tmpdir2, cfg)
def test_load_lora_on_already_injected():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
# load onto already-injected model
load_lora(model2, tmpdir)
assert _get_lora_count(model2) > 0
def test_load_lora_mismatched_r_raises():
model = _make_model()
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
tmpdir = tempfile.mkdtemp()
save_lora(model, tmpdir, cfg)
model2 = _make_model()
model2.load_state_dict(model.state_dict(), strict=False)
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
with pytest.raises(RuntimeError, match="size mismatch"):
load_lora(model2, tmpdir) # strict=False, only lora keys
def test_merge_preserves_output():
model = _make_model()
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
with torch.no_grad():
for m in model.modules():
if isinstance(m, LoRALinear):
m.lora_B.fill_(0.5)
x = torch.randint(0, 1000, (2, 16))
with torch.no_grad():
out_before = model(x)["logits"].clone()
merge_lora(model)
with torch.no_grad():
out_after = model(x)["logits"]
torch.testing.assert_close(out_before, out_after)
def test_merge_no_lora_warns(caplog):
model = _make_model()
merge_lora(model)
assert "No LoRA layers to merge" in caplog.text
def test_collect_lora_info():
model = _make_model()
info = _collect_lora_info(model)
assert "q_proj" in info
assert "o_proj" in info
assert "q_proj" in info # each layer has one