# 356 lines · 9.3 KiB · Python
import tempfile
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
|
from astrai.model import AutoRegressiveLM
|
|
from astrai.model.components.linear import Linear
|
|
from astrai.model.components.lora import (
|
|
LoRAConfig,
|
|
LoRALinear,
|
|
_collect_lora_info,
|
|
_get_lora_count,
|
|
inject_lora,
|
|
load_lora,
|
|
merge_lora,
|
|
save_lora,
|
|
)
|
|
|
|
# Default configuration values shared by the model-building helper below.
MODEL_KWARGS = {
    "vocab_size": 1000,
    "dim": 64,
    "n_heads": 4,
    "n_kv_heads": 2,
    "dim_ffn": 128,
    "n_layers": 2,
    "max_len": 32,
    "norm_eps": 1e-5,
}
|
|
|
|
|
|
def _make_model(**kwargs):
    """Build an ``AutoRegressiveLM`` in eval mode.

    Keyword arguments override (or extend) the defaults in ``MODEL_KWARGS``
    before they are handed to ``AutoRegressiveLMConfig``.
    """
    settings = dict(MODEL_KWARGS)
    settings.update(kwargs)
    lm = AutoRegressiveLM(AutoRegressiveLMConfig(**settings))
    lm.eval()
    return lm
|
|
|
|
|
|
def test_loralinear_init():
    """A fresh LoRALinear reuses and freezes the base weight and adds trainable factors."""
    wrapped = Linear(64, 128)
    adapter = LoRALinear(wrapped, r=8, alpha=16)

    # The base weight is shared (same object) and frozen.
    assert adapter.weight is wrapped.weight
    assert not adapter.weight.requires_grad

    # Factor shapes for r=8 on a Linear(64, 128).
    assert adapter.lora_A.shape == (8, 64)
    assert adapter.lora_B.shape == (128, 8)

    # 2.0 equals alpha / r for the values used here (16 / 8).
    assert adapter.scaling == 2.0
    assert not adapter._merged

    # Both low-rank factors must be trainable.
    assert adapter.lora_A.requires_grad
    assert adapter.lora_B.requires_grad
|
|
|
|
|
|
def test_loralinear_forward_init_zero_delta():
    """A freshly wrapped layer produces the same output as its base layer."""
    layer = Linear(4, 4)
    with torch.no_grad():
        layer.weight.zero_()

    inputs = torch.randn(2, 4)
    wrapped = LoRALinear(layer, r=2, alpha=2)

    # Base output is computed first, then the wrapped output, as arguments.
    assert torch.allclose(layer(inputs), wrapped(inputs))
|
|
|
|
|
|
def test_loralinear_forward_with_delta():
    """Filling ``lora_B`` with ones makes the wrapped output differ from the base output."""
    layer = Linear(4, 4)
    with torch.no_grad():
        layer.weight.zero_()

    inputs = torch.randn(2, 4)
    wrapped = LoRALinear(layer, r=2, alpha=2)
    reference = layer(inputs)

    with torch.no_grad():
        wrapped.lora_B.fill_(1.0)

    assert not torch.allclose(reference, wrapped(inputs))
|
|
|
|
|
|
def test_loralinear_merge():
    """``merge()`` keeps the forward output, sets ``_merged`` and removes ``lora_A``."""
    layer = Linear(4, 4)
    with torch.no_grad():
        layer.weight.zero_()

    inputs = torch.randn(2, 4)
    wrapped = LoRALinear(layer, r=2, alpha=2)
    with torch.no_grad():
        wrapped.lora_B.fill_(1.0)

    # Snapshot the un-merged output; clone so later in-place changes cannot alias it.
    pre_merge = wrapped(inputs).clone()
    wrapped.merge()
    post_merge = wrapped(inputs)

    torch.testing.assert_close(pre_merge, post_merge)
    assert wrapped._merged
    assert not hasattr(wrapped, "lora_A")
|
|
|
|
|
|
def test_loralinear_merge_is_idempotent():
    """A second ``merge()`` call must neither raise nor change the merged weight."""
    base = Linear(4, 4)
    with torch.no_grad():
        base.weight.zero_()

    lora = LoRALinear(base, r=2, alpha=2)
    with torch.no_grad():
        # Non-zero B so that merging actually has a delta to apply.
        lora.lora_B.fill_(1.0)

    lora.merge()
    merged_once = lora.weight.detach().clone()

    # The second call must be a no-op: no exception and no delta applied twice.
    lora.merge()

    assert lora._merged
    torch.testing.assert_close(lora.weight.detach(), merged_once)
|
|
|
|
|
|
def test_inject_lora_default_target():
    """Default injection wraps some, but not all, of the model's ``Linear`` layers."""
    model = _make_model()
    linear_total = len([mod for mod in model.modules() if isinstance(mod, Linear)])

    inject_lora(model, r=4, alpha=8)

    injected = _get_lora_count(model)
    assert 0 < injected < linear_total
|
|
|
|
|
|
def test_inject_lora_ffn():
    """Injecting with ``TARGET_MODULES_FFN`` wraps at least one layer."""
    from astrai.model.components.lora import TARGET_MODULES_FFN

    model = _make_model()
    inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)

    injected = _get_lora_count(model)
    assert injected > 0
|
|
|
|
|
|
def test_inject_lora_returns_config():
    """``inject_lora`` returns a ``LoRAConfig`` carrying the requested ``r`` and ``alpha``."""
    model = _make_model()
    returned = inject_lora(model, r=8, alpha=32)

    assert isinstance(returned, LoRAConfig)
    rank, alpha = returned.r, returned.alpha
    assert rank == 8
    assert alpha == 32
|
|
|
|
|
|
def test_inject_lora_no_matching_targets_warns(caplog):
    """A target set that matches no module yields a "No LoRA layers injected" log record."""
    model = _make_model()
    unmatched_targets = {"nonexistent"}

    inject_lora(model, r=4, alpha=8, target_modules=unmatched_targets)

    captured = caplog.text
    assert "No LoRA layers injected" in captured
|
|
|
|
|
|
def test_inject_lora_preserves_base_output():
    """Logits are unchanged immediately after LoRA injection."""
    model = _make_model()
    tokens = torch.randint(0, 1000, (2, 16))

    def current_logits():
        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            return model(tokens)["logits"]

    reference = current_logits().clone()

    inject_lora(model, r=4, alpha=8)

    torch.testing.assert_close(reference, current_logits())
|
|
|
|
|
|
def test_inject_lora_does_not_reinject():
    """Injecting the same target twice leaves the LoRA layer count unchanged."""
    model = _make_model()

    counts = []
    # Second pass uses different hyper-parameters on the already-wrapped target.
    for rank, alpha in ((4, 8), (2, 4)):
        inject_lora(model, r=rank, alpha=alpha, target_modules={"q_proj"})
        counts.append(_get_lora_count(model))

    count_after_first, count_after_second = counts
    assert count_after_second == count_after_first
|
|
|
|
|
|
def test_inject_lora_adds_new_modules():
    """Injecting a second, different target increases the LoRA layer count."""
    model = _make_model()

    inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
    after_q = _get_lora_count(model)

    inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
    after_v = _get_lora_count(model)

    assert after_v > after_q
|
|
|
|
|
|
def test_inject_lora_on_mla_model():
    """Injection into a model built with ``attn_type="mla"`` wraps at least one layer."""
    mla_options = {
        "attn_type": "mla",
        "kv_lora_rank": 16,
        "qk_nope_head_dim": 16,
        "qk_rope_head_dim": 16,
    }
    model = _make_model(**mla_options)

    inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})

    injected = _get_lora_count(model)
    assert injected > 0
|
|
|
|
|
|
def test_inject_lora_on_moe_model():
    """Injection into a model built with ``ffn_type="moe"`` wraps at least one layer."""
    moe_options = {
        "ffn_type": "moe",
        "n_routed_experts": 4,
        "n_shared_experts": 1,
        "n_activated_experts": 2,
        "dim_ffn": 32,
    }
    model = _make_model(**moe_options)

    inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})

    injected = _get_lora_count(model)
    assert injected > 0
|
|
|
|
|
|
def test_state_dict_key_format():
    """An injected ``q_proj`` exposes ``weight``, ``lora_A`` and ``lora_B`` under its own prefix."""
    model = _make_model()
    inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    state = model.state_dict()
    expected_keys = (
        "layers.0.attention.q_proj.weight",
        "layers.0.attention.q_proj.lora_A",
        "layers.0.attention.q_proj.lora_B",
    )
    for key in expected_keys:
        assert key in state
|
|
|
|
|
|
def test_only_lora_params_trainable():
    """After injection, LoRA parameters train while the wrapped weights stay frozen.

    Parameters whose name contains ``"lora"`` must require grad; the
    ``weight`` of each injected ``q_proj`` / ``v_proj`` must not. The test
    also fails if neither kind of parameter is found, so it cannot pass
    vacuously.
    """
    model = _make_model()
    inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})

    # str.endswith accepts a tuple, so both targets are checked in one call.
    frozen_suffixes = (".q_proj.weight", ".v_proj.weight")

    lora_seen = 0
    frozen_seen = 0
    for name, param in model.named_parameters():
        if "lora" in name:
            lora_seen += 1
            assert param.requires_grad, f"lora param should be trainable: {name}"
        elif name.endswith(frozen_suffixes):
            frozen_seen += 1
            assert not param.requires_grad, f"injected weight should be frozen: {name}"

    # Guard against a vacuous pass when no parameter matched a branch above.
    assert lora_seen > 0, "no LoRA parameters found after injection"
    assert frozen_seen > 0, "no injected base weights found after injection"
|
|
|
|
|
|
def test_state_dict_after_inject_consistent_with_original():
    """Injection keeps every original state-dict key (same shape) and adds LoRA keys."""
    model = _make_model()
    original = dict(model.state_dict())

    inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
    current = model.state_dict()

    # Every pre-injection key is still present with an unchanged shape.
    for key, tensor in original.items():
        assert key in current
        assert tensor.shape == current[key].shape

    # Injection added at least one key containing "lora".
    assert any("lora" in key for key in current)
|
|
|
|
|
|
def test_save_load_roundtrip():
    """LoRA weights saved from one model and loaded into another reproduce the logits."""
    model = _make_model()
    cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    # Give every adapter a non-zero lora_B before the reference logits are taken.
    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, LoRALinear):
                m.lora_B.fill_(0.5)

    x = torch.randint(0, 1000, (2, 16))
    with torch.no_grad():
        out_src = model(x)["logits"].clone()

    # The context manager removes the directory on exit, even if an assertion
    # or the save/load itself fails, so no temp directory is left behind.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_lora(model, tmpdir, cfg)

        model2 = _make_model()
        # model2 has no LoRA layers yet, hence strict=False when copying
        # the source state dict (which also contains lora_* entries).
        model2.load_state_dict(model.state_dict(), strict=False)
        load_lora(model2, tmpdir)

    with torch.no_grad():
        out_dst = model2(x)["logits"]

    torch.testing.assert_close(out_src, out_dst)
|
|
|
|
|
|
def test_save_after_merge_raises():
    """``save_lora`` succeeds before merging and raises ``RuntimeError`` afterwards."""
    model = _make_model()
    cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, LoRALinear):
                m.lora_B.fill_(0.5)

    # Saving before the merge must work; the directory is removed on exit.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_lora(model, tmpdir, cfg)

    merge_lora(model)

    # After merging, saving is expected to fail with the message matched below.
    with tempfile.TemporaryDirectory() as tmpdir2:
        with pytest.raises(RuntimeError, match="No LoRA parameters"):
            save_lora(model, tmpdir2, cfg)
|
|
|
|
|
|
def test_load_lora_on_already_injected():
    """``load_lora`` works on a model that already has LoRA layers injected."""
    model = _make_model()
    inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, LoRALinear):
                m.lora_B.fill_(0.5)

    # The context manager guarantees the checkpoint directory is cleaned up.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))

        model2 = _make_model()
        model2.load_state_dict(model.state_dict(), strict=False)
        inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})

        # load onto already-injected model
        load_lora(model2, tmpdir)

    assert _get_lora_count(model2) > 0
|
|
|
|
|
|
def test_load_lora_mismatched_r_raises():
    """Loading an ``r=8`` checkpoint into a model injected with ``r=4`` raises ``RuntimeError``."""
    model = _make_model()
    cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})

    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, LoRALinear):
                m.lora_B.fill_(0.5)

    # The context manager removes the directory even though load_lora raises.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_lora(model, tmpdir, cfg)

        model2 = _make_model()
        model2.load_state_dict(model.state_dict(), strict=False)
        # Deliberately inject a different rank than the one that was saved.
        inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})

        with pytest.raises(RuntimeError, match="size mismatch"):
            load_lora(model2, tmpdir)  # strict=False, only lora keys
|
|
|
|
|
|
def test_merge_preserves_output():
    """``merge_lora`` leaves the model's logits unchanged."""
    model = _make_model()
    inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    # Give every adapter a non-zero lora_B before comparing outputs.
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, LoRALinear):
                continue
            module.lora_B.fill_(0.5)

    tokens = torch.randint(0, 1000, (2, 16))

    def current_logits():
        # Forward pass without autograd bookkeeping.
        with torch.no_grad():
            return model(tokens)["logits"]

    reference = current_logits().clone()

    merge_lora(model)

    torch.testing.assert_close(reference, current_logits())
|
|
|
|
|
|
def test_merge_no_lora_warns(caplog):
    """Merging a model without LoRA layers yields a "No LoRA layers to merge" log record."""
    plain_model = _make_model()

    merge_lora(plain_model)

    captured = caplog.text
    assert "No LoRA layers to merge" in captured
|
|
|
|
|
|
def test_collect_lora_info():
    """``_collect_lora_info`` reports ``q_proj`` and ``o_proj`` for a default model."""
    # NOTE(review): no LoRA is injected before this call; presumably
    # _collect_lora_info describes candidate modules - confirm.
    model = _make_model()
    info = _collect_lora_info(model)

    assert "q_proj" in info
    assert "o_proj" in info
    # NOTE(review): a second, identical `"q_proj" in info` check (annotated
    # "each layer has one") was redundant and has been dropped. If a per-layer
    # count was meant, add it once the structure of `info` is confirmed.
|