import json import logging from dataclasses import asdict, dataclass from pathlib import Path from typing import Optional, Set import safetensors.torch as st import torch import torch.nn as nn import torch.nn.functional as F from astrai.model.components.linear import Linear logger = logging.getLogger(__name__) TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"} TARGET_MODULES_FFN = {"up", "gate", "down"} @dataclass class LoRAConfig: r: int = 16 alpha: int = 32 target_modules: tuple = ("q_proj", "v_proj") class LoRALinear(nn.Module): def __init__(self, base: Linear, r: int = 16, alpha: int = 32): super().__init__() self.register_parameter("weight", base.weight) self.weight.requires_grad_(False) self.bias = base.bias if self.bias is not None: self.bias.requires_grad_(False) self.r = r self.scaling = alpha / r self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r) self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r)) self._merged = False def forward(self, x): out = F.linear(x, self.weight, self.bias) if not self._merged: out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling return out def merge(self): if self._merged: return self.weight.data += (self.lora_B @ self.lora_A) * self.scaling self._merged = True del self.lora_A del self.lora_B def _collect_lora_info(model: nn.Module) -> dict: names = {} for n, m in model.named_modules(): if isinstance(m, Linear): _, _, child = n.rpartition(".") names.setdefault(child, []).append(n) return names def _get_lora_count(model: nn.Module) -> int: return sum(1 for m in model.modules() if isinstance(m, LoRALinear)) def inject_lora( model: nn.Module, r: int = 16, alpha: int = 32, target_modules: Optional[Set[str]] = None, ) -> LoRAConfig: if target_modules is None: target_modules = TARGET_MODULES_ATTN available = _collect_lora_info(model) injected = 0 for name, module in list(model.named_modules()): if not isinstance(module, Linear): continue parent_name, _, child_name = name.rpartition(".") if child_name not in target_modules: continue parent = model.get_submodule(parent_name) if parent_name else model setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha)) injected += 1 if injected == 0: logger.warning( "No LoRA layers injected. Available Linear child names: %s. " "target_modules: %s. Check model type and target_modules.", sorted(available), sorted(target_modules), ) else: logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha) return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules)) def merge_lora(model: nn.Module): n = 0 for module in model.modules(): if isinstance(module, LoRALinear): module.merge() n += 1 if n == 0: logger.warning("No LoRA layers to merge.") else: logger.info("Merged %d LoRA layers", n) def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig): lora_sd = { k: v for k, v in model.state_dict().items() if k.endswith((".lora_A", ".lora_B")) } if not lora_sd: raise RuntimeError( "No LoRA parameters found in model. " "The model may not have been injected or was already merged." ) path = Path(save_dir) path.mkdir(parents=True, exist_ok=True) st.save_file(lora_sd, str(path / "adapter_model.safetensors")) with open(path / "adapter_config.json", "w") as f: json.dump(asdict(config), f, indent=2) logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd)) def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig: path = Path(load_dir) with open(path / "adapter_config.json") as f: raw = json.load(f) config = LoRAConfig( r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"]) ) existing = _get_lora_count(model) if existing > 0: logger.warning( "Model already has %d LoRA layers. Skipping injection, " "loading weights onto existing layers only.", existing, ) else: inject_lora( model, r=config.r, alpha=config.alpha, target_modules=set(config.target_modules), ) weights = st.load_file(str(path / "adapter_model.safetensors")) try: missing, unexpected = model.load_state_dict(weights, strict=False) except RuntimeError as e: msg = str(e) if "size mismatch" in msg: raise RuntimeError( f"LoRA weight shapes do not match the model. " f"The adapter config (r={config.r}) may not match the injected layers. " f"Original error: {msg}" ) from e raise injected = _get_lora_count(model) if injected == 0: raise RuntimeError( "No LoRA layers found after loading. " "Inject LoRA before calling load_lora, or check the adapter config." ) if missing: lora_missing = [k for k in missing if "lora" in k] if lora_missing: raise RuntimeError( f"LoRA weight keys not found in model: {lora_missing}. " f"The adapter config (r={config.r}) may not match the model." ) logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing)) if unexpected: logger.warning("LoRA load: %d unexpected keys", len(unexpected)) logger.info("LoRA adapter loaded from %s", load_dir) return config