193 lines
5.9 KiB
Python
193 lines
5.9 KiB
Python
import json
|
|
import logging
|
|
from dataclasses import asdict, dataclass
|
|
from pathlib import Path
|
|
from typing import Optional, Set
|
|
|
|
import safetensors.torch as st
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from astrai.model.components.linear import Linear
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Leaf module names that inject_lora targets when no target_modules are given.
TARGET_MODULES_ATTN = {
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
}

# Alternative leaf names (presumably the feed-forward projections — confirm
# against the model); pass explicitly via ``target_modules`` to use them.
TARGET_MODULES_FFN = {
    "up",
    "gate",
    "down",
}
|
|
|
|
|
|
@dataclass
class LoRAConfig:
    """Hyper-parameters that describe a LoRA adapter.

    ``save_lora`` serializes an instance with ``dataclasses.asdict`` into
    ``adapter_config.json``; ``load_lora`` rebuilds one from that file.

    NOTE(review): the ``target_modules`` default here (``q_proj``/``v_proj``)
    differs from the targets ``inject_lora`` uses when given none
    (``TARGET_MODULES_ATTN``). ``inject_lora`` always fills this field
    explicitly, so this default only matters for direct construction.
    """

    # Rank of the low-rank factors (lora_A is r x in, lora_B is out x r).
    r: int = 16
    # Scaling numerator; LoRALinear multiplies the update by alpha / r.
    alpha: int = 32
    # Leaf module names (last dotted path component) that receive adapters.
    target_modules: tuple = ("q_proj", "v_proj")
|
|
|
|
|
|
class LoRALinear(nn.Module):
    """Frozen linear projection plus a trainable low-rank update.

    Computes ``F.linear(x, W, b) + F.linear(F.linear(x, A), B) * (alpha / r)``.
    ``W``/``b`` are the wrapped layer's own ``weight``/``bias`` Parameter
    objects (shared, not copied, and frozen). ``A`` (``r x in_features``) and
    ``B`` (``out_features x r``) are the trainable LoRA factors. ``B`` starts
    at zero, so a freshly wrapped layer reproduces ``F.linear(x, W, b)``.

    Args:
        base: Layer to wrap. Only its ``weight`` and ``bias`` attributes are
            read; its own ``forward`` is never called.
        r: Rank of the low-rank update; must be positive.
        alpha: Scaling numerator; the update is multiplied by ``alpha / r``.

    Raises:
        ValueError: If ``r`` is not positive.
    """

    def __init__(self, base: "Linear", r: int = 16, alpha: int = 32):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r}")

        # Re-register the same Parameter objects so state_dict keys
        # ("<name>.weight" / "<name>.bias") match the wrapped layer's.
        self.register_parameter("weight", base.weight)
        self.weight.requires_grad_(False)
        self.bias = base.bias
        if self.bias is not None:
            self.bias.requires_grad_(False)

        self.r = r
        self.scaling = alpha / r

        # Allocate the factors next to the base weight (same device, and same
        # dtype when it is floating point). Otherwise injecting into a model
        # that already lives on GPU / in reduced precision would yield CPU
        # default-dtype factors and forward() would fail on the mismatch.
        factory = {"device": self.weight.device}
        if self.weight.is_floating_point():
            factory["dtype"] = self.weight.dtype
        out_features, in_features = self.weight.shape[0], self.weight.shape[1]
        self.lora_A = nn.Parameter(torch.randn(r, in_features, **factory) / r)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r, **factory))
        self._merged = False

    def forward(self, x):
        """Return the base projection plus the scaled low-rank update (if not merged)."""
        out = F.linear(x, self.weight, self.bias)
        if not self._merged:
            # x -> rank-r space via A, back to out_features via B, then scale.
            out += F.linear(F.linear(x, self.lora_A), self.lora_B) * self.scaling
        return out

    def merge(self):
        """Fold ``B @ A * scaling`` into ``weight`` and drop the LoRA factors.

        ``weight`` is the wrapped layer's own Parameter, so that layer's
        weight is modified in place too. Calling this again is a no-op.
        """
        if self._merged:
            return
        # No autograd graph is needed for a one-off weight update.
        with torch.no_grad():
            delta = (self.lora_B @ self.lora_A) * self.scaling
            self.weight.add_(
                delta.to(device=self.weight.device, dtype=self.weight.dtype)
            )
        self._merged = True
        del self.lora_A
        del self.lora_B
|
|
|
|
|
|
def _collect_lora_info(model: nn.Module) -> dict:
    """Group the qualified names of every ``Linear`` submodule by leaf name.

    Returns a dict mapping the last dotted component of a module path
    (e.g. ``"q_proj"``) to the list of full paths ending in it, in
    ``named_modules`` traversal order.
    """
    grouped = {}
    for qualified_name, submodule in model.named_modules():
        if not isinstance(submodule, Linear):
            continue
        leaf = qualified_name.split(".")[-1]
        if leaf not in grouped:
            grouped[leaf] = []
        grouped[leaf].append(qualified_name)
    return grouped
|
|
|
|
|
|
def _get_lora_count(model: nn.Module) -> int:
    """Return how many unmerged ``LoRALinear`` layers ``model`` contains.

    Merged layers have folded their update into the base weight and deleted
    ``lora_A``/``lora_B``, so they can neither receive adapter weights nor be
    saved. Counting them would make ``load_lora`` skip injection on a merged
    model and then load nothing, only warning about unexpected keys.
    """
    return sum(
        1
        for m in model.modules()
        if isinstance(m, LoRALinear) and not m._merged
    )
|
|
|
|
|
|
def inject_lora(
    model: nn.Module,
    r: int = 16,
    alpha: int = 32,
    target_modules: Optional[Set[str]] = None,
) -> LoRAConfig:
    """Replace matching ``Linear`` submodules of ``model`` with ``LoRALinear``, in place.

    A ``Linear`` is wrapped when the last component of its dotted module path
    (e.g. ``"q_proj"`` in ``"layers.0.attn.q_proj"``) is in ``target_modules``.
    ``LoRALinear`` does not subclass ``Linear``, so layers that are already
    wrapped are skipped on repeated calls.

    Args:
        model: Model to modify in place.
        r: LoRA rank for every injected layer.
        alpha: LoRA scaling numerator for every injected layer.
        target_modules: Leaf names to wrap; defaults to ``TARGET_MODULES_ATTN``.
            A single string is treated as one name, not matched as a substring.

    Returns:
        A ``LoRAConfig`` for the injection. ``target_modules`` is sorted so the
        config (and any JSON written from it) is deterministic.
    """
    if target_modules is None:
        target_modules = TARGET_MODULES_ATTN
    elif isinstance(target_modules, str):
        target_modules = {target_modules}
    targets = set(target_modules)

    injected = 0
    # Snapshot the module list because the tree is mutated while walking it.
    for name, module in list(model.named_modules()):
        if not isinstance(module, Linear):
            continue
        parent_name, _, child_name = name.rpartition(".")
        if child_name not in targets:
            continue
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
        injected += 1

    if injected == 0:
        # Nothing changed, so walking the model now gives the same names it
        # would have before; only pay for this walk when reporting a miss.
        available = _collect_lora_info(model)
        logger.warning(
            "No LoRA layers injected. Available Linear child names: %s. "
            "target_modules: %s. Check model type and target_modules.",
            sorted(available),
            sorted(targets),
        )
    else:
        logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)

    return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(sorted(targets)))
|
|
|
|
|
|
def merge_lora(model: nn.Module) -> int:
    """Merge every not-yet-merged ``LoRALinear`` in ``model`` into its base weight.

    Layers that were already merged are skipped and not counted, so the log
    reports only the layers merged by this call. A warning is logged when
    there was nothing to merge.

    Returns:
        The number of layers merged by this call.
    """
    n = 0
    for module in model.modules():
        if isinstance(module, LoRALinear) and not module._merged:
            module.merge()
            n += 1
    if n == 0:
        logger.warning("No LoRA layers to merge.")
    else:
        logger.info("Merged %d LoRA layers", n)
    return n
|
|
|
|
|
|
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
    """Write the LoRA factors of ``model`` and ``config`` into ``save_dir``.

    Only state-dict entries whose key ends in ``.lora_A`` or ``.lora_B`` are
    saved, to ``adapter_model.safetensors``. ``config`` is written with
    ``asdict`` as indented JSON to ``adapter_config.json``. ``save_dir`` (and
    its parents) are created if missing.

    Raises:
        RuntimeError: If the model holds no LoRA parameters.
    """
    suffixes = (".lora_A", ".lora_B")
    lora_sd = {}
    for key, tensor in model.state_dict().items():
        if key.endswith(suffixes):
            lora_sd[key] = tensor

    if len(lora_sd) == 0:
        raise RuntimeError(
            "No LoRA parameters found in model. "
            "The model may not have been injected or was already merged."
        )

    out_dir = Path(save_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    weights_file = out_dir / "adapter_model.safetensors"
    st.save_file(lora_sd, str(weights_file))

    config_file = out_dir / "adapter_config.json"
    with open(config_file, "w") as fh:
        json.dump(asdict(config), fh, indent=2)

    logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
|
|
|
|
|
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
    """Load a LoRA adapter from ``load_dir`` into ``model``, in place.

    Reads ``adapter_config.json`` and ``adapter_model.safetensors``. If
    ``_get_lora_count`` reports no LoRA layers, they are injected first using
    the adapter's ``r``, ``alpha`` and ``target_modules``. Otherwise the
    existing layers are reused and only the weights are loaded.

    Args:
        model: Model that receives the adapter weights.
        load_dir: Directory containing the adapter files.

    Returns:
        The ``LoRAConfig`` parsed from ``adapter_config.json``.

    Raises:
        RuntimeError: If adapter tensor shapes do not match the model's LoRA
            layers, if the model has no LoRA layers after loading, or if a
            LoRA parameter of the model has no corresponding adapter tensor.
            Errors from opening/parsing the files (or a missing config field)
            propagate unchanged.
    """
    path = Path(load_dir)
    with open(path / "adapter_config.json", encoding="utf-8") as f:
        raw = json.load(f)
    config = LoRAConfig(
        r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
    )

    existing = _get_lora_count(model)
    if existing > 0:
        logger.warning(
            "Model already has %d LoRA layers. Skipping injection, "
            "loading weights onto existing layers only.",
            existing,
        )
    else:
        inject_lora(
            model,
            r=config.r,
            alpha=config.alpha,
            target_modules=set(config.target_modules),
        )

    weights = st.load_file(str(path / "adapter_model.safetensors"))
    try:
        # strict=False: the adapter only holds LoRA factors, so every base
        # weight of the model is expected to show up as "missing".
        missing, unexpected = model.load_state_dict(weights, strict=False)
    except RuntimeError as e:
        msg = str(e)
        if "size mismatch" in msg:
            raise RuntimeError(
                f"LoRA weight shapes do not match the model. "
                f"The adapter config (r={config.r}) may not match the injected layers. "
                f"Original error: {msg}"
            ) from e
        raise

    injected = _get_lora_count(model)
    if injected == 0:
        raise RuntimeError(
            "No LoRA layers found after loading. "
            "Inject LoRA before calling load_lora, or check the adapter config."
        )

    # Same key test as save_lora, so base modules whose names merely contain
    # "lora" are not mistaken for adapter parameters.
    lora_suffixes = (".lora_A", ".lora_B")
    if missing:
        lora_missing = [k for k in missing if k.endswith(lora_suffixes)]
        if lora_missing:
            raise RuntimeError(
                f"LoRA weight keys not found in model: {lora_missing}. "
                f"The adapter config (r={config.r}) may not match the model."
            )
        logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
    if unexpected:
        # These adapter tensors matched no model parameter and were dropped;
        # name a few so a target_modules mismatch is easy to diagnose.
        logger.warning(
            "LoRA load: %d unexpected keys (ignored), e.g. %s",
            len(unexpected),
            list(unexpected)[:5],
        )

    logger.info("LoRA adapter loaded from %s", load_dir)
    return config
|