class LoRALinear(nn.Module):
    """Linear layer with a frozen base weight plus a trainable low-rank update.

    ``forward`` computes ``F.linear(x, weight, bias)`` with the base layer's
    parameters and, until :meth:`merge` is called, adds
    ``(x @ lora_A.T @ lora_B.T) * (alpha / r)``. ``lora_B`` starts at zero, so
    a freshly wrapped layer produces the same output as the base layer.

    Args:
        base: Layer whose ``weight`` (indexed as ``(out_features, in_features)``)
            and optional ``bias`` are reused. The parameters are shared with
            ``base`` (not copied) and are frozen in place.
        r: Rank of the low-rank update; must be positive.
        alpha: Numerator of the scaling factor ``alpha / r``.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, base: "Linear", r: int = 16, alpha: int = 32):
        super().__init__()
        if r <= 0:
            # Fail with a clear message instead of ZeroDivisionError / a torch
            # shape error further down.
            raise ValueError(f"LoRA rank r must be positive, got {r}")

        # Re-register the very same Parameter object: freezing it here also
        # freezes it on ``base``.
        self.register_parameter("weight", base.weight)
        self.weight.requires_grad_(False)
        self.bias = base.bias
        if self.bias is not None:
            self.bias.requires_grad_(False)

        self.r = r
        self.scaling = alpha / r

        # Create the adapter on the base weight's device and dtype; otherwise
        # forward() mixes devices/dtypes as soon as the base layer is not a
        # default-dtype CPU module.
        factory_kwargs = {"device": self.weight.device, "dtype": self.weight.dtype}
        self.lora_A = nn.Parameter(
            torch.randn(r, self.weight.shape[1], **factory_kwargs) / r
        )
        self.lora_B = nn.Parameter(
            torch.zeros(self.weight.shape[0], r, **factory_kwargs)
        )
        self._merged = False

    def forward(self, x):
        """Return the base projection of ``x`` plus the scaled LoRA delta."""
        out = F.linear(x, self.weight, self.bias)
        if self._merged:
            # The delta already lives inside ``weight``.
            return out
        delta = F.linear(F.linear(x, self.lora_A), self.lora_B)
        return out + delta * self.scaling

    def merge(self):
        """Fold the low-rank update into ``weight`` and drop the adapter.

        After the first call ``lora_A`` and ``lora_B`` no longer exist on the
        module; further calls are no-ops.
        """
        if self._merged:
            return
        with torch.no_grad():
            self.weight.add_((self.lora_B @ self.lora_A) * self.scaling)
        self._merged = True
        del self.lora_A
        del self.lora_B
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
    """Load a LoRA adapter from ``load_dir`` into ``model``.

    Reads ``adapter_config.json`` and ``adapter_model.safetensors`` from
    ``load_dir``. If the model has no ``LoRALinear`` layers yet, they are
    injected according to the stored config; otherwise the existing layers are
    reused and only their weights are overwritten.

    Args:
        model: Model to receive the adapter weights (modified in place).
        load_dir: Directory containing the two adapter files.

    Returns:
        The ``LoRAConfig`` reconstructed from ``adapter_config.json``.

    Raises:
        KeyError: If the config file lacks ``r``, ``alpha`` or
            ``target_modules``.
        RuntimeError: If the model ends up with no LoRA layers, if adapter
            tensor shapes do not match the model's LoRA layers, or if the
            adapter file lacks weights for some of the model's LoRA layers.
    """
    adapter_dir = Path(load_dir)
    with open(adapter_dir / "adapter_config.json", encoding="utf-8") as f:
        raw = json.load(f)
    config = LoRAConfig(
        r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
    )

    existing = _get_lora_count(model)
    if existing > 0:
        logger.warning(
            "Model already has %d LoRA layers. Skipping injection, "
            "loading weights onto existing layers only.",
            existing,
        )
    else:
        inject_lora(
            model,
            r=config.r,
            alpha=config.alpha,
            target_modules=set(config.target_modules),
        )

    # Fail before reading any weights: with no LoRA layers every adapter key
    # would merely be reported as "unexpected" by load_state_dict.
    if _get_lora_count(model) == 0:
        raise RuntimeError(
            "No LoRA layers found after loading. "
            "Inject LoRA before calling load_lora, or check the adapter config."
        )

    weights = st.load_file(str(adapter_dir / "adapter_model.safetensors"))
    try:
        # strict=False: the adapter file holds only lora_A / lora_B tensors,
        # so every base-model key is legitimately absent from it.
        missing, unexpected = model.load_state_dict(weights, strict=False)
    except RuntimeError as e:
        msg = str(e)
        if "size mismatch" in msg:
            raise RuntimeError(
                f"LoRA weight shapes do not match the model. "
                f"The adapter config (r={config.r}) may not match the injected layers. "
                f"Original error: {msg}"
            ) from e
        raise

    # ``missing`` lists keys the MODEL expects that the adapter FILE lacks.
    lora_missing = [k for k in missing if "lora" in k]
    if lora_missing:
        raise RuntimeError(
            f"LoRA weight keys missing from the adapter file: {lora_missing}. "
            f"The adapter config (r={config.r}) may not match the model."
        )
    if missing:
        logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))

    # ``unexpected`` lists adapter keys with no matching parameter in the
    # model; those tensors were not loaded, so name them for the user.
    if unexpected:
        logger.warning(
            "LoRA load: %d unexpected keys (ignored), e.g. %s",
            len(unexpected),
            sorted(unexpected)[:5],
        )

    logger.info("LoRA adapter loaded from %s", load_dir)
    return config
def test_loralinear_merge_is_idempotent():
    """A second ``merge()`` must neither raise nor change the merged weight."""
    base = Linear(4, 4)
    with torch.no_grad():
        base.weight.zero_()

    lora = LoRALinear(base, r=2, alpha=2)
    with torch.no_grad():
        # Non-zero B so the first merge actually changes the weight.
        lora.lora_B.fill_(1.0)

    lora.merge()
    merged_weight = lora.weight.detach().clone()

    # The adapter tensors are gone after the first merge, so this call must
    # take the early-return path instead of re-adding the delta.
    lora.merge()

    assert lora._merged
    torch.testing.assert_close(lora.weight.detach(), merged_weight)
def test_only_lora_params_trainable():
    """Adapter tensors are trainable; the wrapped base weights are frozen."""
    model = _make_model()
    targets = ("q_proj", "v_proj")
    inject_lora(model, r=4, alpha=8, target_modules=set(targets))

    lora_seen = 0
    frozen_seen = 0
    for name, param in model.named_parameters():
        # Match on the final path component so an unrelated parameter whose
        # path merely contains "lora" is not mistaken for an adapter tensor.
        leaf = name.rpartition(".")[2]
        if leaf in ("lora_A", "lora_B"):
            assert param.requires_grad, f"lora param should be trainable: {name}"
            lora_seen += 1
        elif any(name.endswith(f".{t}.weight") for t in targets):
            assert not param.requires_grad, f"injected weight should be frozen: {name}"
            frozen_seen += 1

    # Guard against the loop passing vacuously when nothing matched.
    assert lora_seen > 0
    assert frozen_seen > 0
def test_save_load_roundtrip():
    """An adapter saved from one model reproduces its logits on a fresh model."""
    model = _make_model()
    cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})

    with torch.no_grad():
        for m in model.modules():
            if isinstance(m, LoRALinear):
                # lora_B starts at zero; make the adapter contribute.
                m.lora_B.fill_(0.5)

    x = torch.randint(0, 1000, (2, 16))
    with torch.no_grad():
        out_src = model(x)["logits"].clone()

    model2 = _make_model()
    # Copy the base weights; the lora_* keys have no counterpart in model2 yet
    # and are ignored because strict=False.
    model2.load_state_dict(model.state_dict(), strict=False)

    # TemporaryDirectory removes the adapter files when the block exits, so
    # the test no longer leaves a directory behind on every run.
    with tempfile.TemporaryDirectory() as tmpdir:
        save_lora(model, tmpdir, cfg)
        load_lora(model2, tmpdir)

    with torch.no_grad():
        out_dst = model2(x)["logits"]

    torch.testing.assert_close(out_src, out_dst)
def test_collect_lora_info():
    """``_collect_lora_info`` groups Linear module paths by their child name."""
    model = _make_model()
    info = _collect_lora_info(model)

    # The original asserted "q_proj" twice; check distinct projection names.
    for child in ("q_proj", "v_proj", "o_proj"):
        assert child in info

    # Each entry lists at least one qualified module path, and every path
    # ends in the child name it is filed under.
    for child, names in info.items():
        assert names
        assert all(n.rpartition(".")[2] == child for n in names)