feat: 新增LoRA微调模块
- LoRALinear基于register_parameter托管base weight,state_dict路径不变 - inject_lora/merge_lora/save_lora/load_lora完备封装 - 24个单元测试覆盖注入、合并、存取、边界场景
This commit is contained in:
parent
7df6eb9211
commit
432145a798
|
|
@ -2,6 +2,13 @@ from astrai.model.automodel import AutoModel
|
|||
from astrai.model.components.attention import GQA
|
||||
from astrai.model.components.decoder_block import DecoderBlock
|
||||
from astrai.model.components.linear import Linear
|
||||
from astrai.model.components.lora import (
|
||||
LoRAConfig,
|
||||
inject_lora,
|
||||
load_lora,
|
||||
merge_lora,
|
||||
save_lora,
|
||||
)
|
||||
from astrai.model.components.mlp import MLP
|
||||
from astrai.model.components.norm import RMSNorm
|
||||
from astrai.model.encoder import EmbeddingEncoder
|
||||
|
|
@ -18,4 +25,10 @@ __all__ = [
|
|||
"AutoRegressiveLM",
|
||||
"EmbeddingEncoder",
|
||||
"AutoModel",
|
||||
# LoRA
|
||||
"LoRAConfig",
|
||||
"inject_lora",
|
||||
"merge_lora",
|
||||
"save_lora",
|
||||
"load_lora",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,192 @@
|
|||
import json
|
||||
import logging
|
||||
from dataclasses import asdict, dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Set
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAConfig:
|
||||
r: int = 16
|
||||
alpha: int = 32
|
||||
target_modules: tuple = ("q_proj", "v_proj")
|
||||
|
||||
|
||||
class LoRALinear(nn.Module):
|
||||
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
||||
super().__init__()
|
||||
self.register_parameter("weight", base.weight)
|
||||
self.weight.requires_grad_(False)
|
||||
self.bias = base.bias
|
||||
if self.bias is not None:
|
||||
self.bias.requires_grad_(False)
|
||||
|
||||
self.r = r
|
||||
self.scaling = alpha / r
|
||||
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
||||
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
||||
self._merged = False
|
||||
|
||||
def forward(self, x):
|
||||
out = F.linear(x, self.weight, self.bias)
|
||||
if not self._merged:
|
||||
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
||||
return out
|
||||
|
||||
def merge(self):
|
||||
if self._merged:
|
||||
return
|
||||
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||
self._merged = True
|
||||
del self.lora_A
|
||||
del self.lora_B
|
||||
|
||||
|
||||
def _collect_lora_info(model: nn.Module) -> dict:
|
||||
names = {}
|
||||
for n, m in model.named_modules():
|
||||
if isinstance(m, Linear):
|
||||
_, _, child = n.rpartition(".")
|
||||
names.setdefault(child, []).append(n)
|
||||
return names
|
||||
|
||||
|
||||
def _get_lora_count(model: nn.Module) -> int:
|
||||
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
||||
|
||||
|
||||
def inject_lora(
|
||||
model: nn.Module,
|
||||
r: int = 16,
|
||||
alpha: int = 32,
|
||||
target_modules: Optional[Set[str]] = None,
|
||||
) -> LoRAConfig:
|
||||
if target_modules is None:
|
||||
target_modules = TARGET_MODULES_ATTN
|
||||
|
||||
available = _collect_lora_info(model)
|
||||
injected = 0
|
||||
|
||||
for name, module in list(model.named_modules()):
|
||||
if not isinstance(module, Linear):
|
||||
continue
|
||||
parent_name, _, child_name = name.rpartition(".")
|
||||
if child_name not in target_modules:
|
||||
continue
|
||||
parent = model.get_submodule(parent_name) if parent_name else model
|
||||
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
||||
injected += 1
|
||||
|
||||
if injected == 0:
|
||||
logger.warning(
|
||||
"No LoRA layers injected. Available Linear child names: %s. "
|
||||
"target_modules: %s. Check model type and target_modules.",
|
||||
sorted(available),
|
||||
sorted(target_modules),
|
||||
)
|
||||
else:
|
||||
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
||||
|
||||
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
||||
|
||||
|
||||
def merge_lora(model: nn.Module):
|
||||
n = 0
|
||||
for module in model.modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
module.merge()
|
||||
n += 1
|
||||
if n == 0:
|
||||
logger.warning("No LoRA layers to merge.")
|
||||
else:
|
||||
logger.info("Merged %d LoRA layers", n)
|
||||
|
||||
|
||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||
lora_sd = {
|
||||
k: v
|
||||
for k, v in model.state_dict().items()
|
||||
if k.endswith((".lora_A", ".lora_B"))
|
||||
}
|
||||
if not lora_sd:
|
||||
raise RuntimeError(
|
||||
"No LoRA parameters found in model. "
|
||||
"The model may not have been injected or was already merged."
|
||||
)
|
||||
|
||||
path = Path(save_dir)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
||||
with open(path / "adapter_config.json", "w") as f:
|
||||
json.dump(asdict(config), f, indent=2)
|
||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||
|
||||
|
||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||
path = Path(load_dir)
|
||||
with open(path / "adapter_config.json") as f:
|
||||
raw = json.load(f)
|
||||
config = LoRAConfig(
|
||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||
)
|
||||
|
||||
existing = _get_lora_count(model)
|
||||
if existing > 0:
|
||||
logger.warning(
|
||||
"Model already has %d LoRA layers. Skipping injection, "
|
||||
"loading weights onto existing layers only.",
|
||||
existing,
|
||||
)
|
||||
else:
|
||||
inject_lora(
|
||||
model,
|
||||
r=config.r,
|
||||
alpha=config.alpha,
|
||||
target_modules=set(config.target_modules),
|
||||
)
|
||||
|
||||
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
||||
try:
|
||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||
except RuntimeError as e:
|
||||
msg = str(e)
|
||||
if "size mismatch" in msg:
|
||||
raise RuntimeError(
|
||||
f"LoRA weight shapes do not match the model. "
|
||||
f"The adapter config (r={config.r}) may not match the injected layers. "
|
||||
f"Original error: {msg}"
|
||||
) from e
|
||||
raise
|
||||
|
||||
injected = _get_lora_count(model)
|
||||
if injected == 0:
|
||||
raise RuntimeError(
|
||||
"No LoRA layers found after loading. "
|
||||
"Inject LoRA before calling load_lora, or check the adapter config."
|
||||
)
|
||||
|
||||
if missing:
|
||||
lora_missing = [k for k in missing if "lora" in k]
|
||||
if lora_missing:
|
||||
raise RuntimeError(
|
||||
f"LoRA weight keys not found in model: {lora_missing}. "
|
||||
f"The adapter config (r={config.r}) may not match the model."
|
||||
)
|
||||
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
||||
if unexpected:
|
||||
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
||||
|
||||
logger.info("LoRA adapter loaded from %s", load_dir)
|
||||
return config
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||
from astrai.model import AutoRegressiveLM
|
||||
from astrai.model.components.lora import (
|
||||
LoRAConfig,
|
||||
LoRALinear,
|
||||
_get_lora_count,
|
||||
_collect_lora_info,
|
||||
inject_lora,
|
||||
load_lora,
|
||||
merge_lora,
|
||||
save_lora,
|
||||
)
|
||||
from astrai.model.components.linear import Linear
|
||||
|
||||
MODEL_KWARGS = dict(
|
||||
vocab_size=1000,
|
||||
dim=64,
|
||||
n_heads=4,
|
||||
n_kv_heads=2,
|
||||
dim_ffn=128,
|
||||
n_layers=2,
|
||||
max_len=32,
|
||||
norm_eps=1e-5,
|
||||
)
|
||||
|
||||
|
||||
def _make_model(**kwargs):
|
||||
kw = {**MODEL_KWARGS, **kwargs}
|
||||
config = AutoRegressiveLMConfig(**kw)
|
||||
model = AutoRegressiveLM(config)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
|
||||
def test_loralinear_init():
|
||||
base = Linear(64, 128)
|
||||
lora = LoRALinear(base, r=8, alpha=16)
|
||||
|
||||
assert lora.weight is base.weight
|
||||
assert not lora.weight.requires_grad
|
||||
assert lora.lora_A.shape == (8, 64)
|
||||
assert lora.lora_B.shape == (128, 8)
|
||||
assert lora.scaling == 2.0
|
||||
assert not lora._merged
|
||||
assert lora.lora_A.requires_grad
|
||||
assert lora.lora_B.requires_grad
|
||||
|
||||
|
||||
def test_loralinear_forward_init_zero_delta():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
base_out = base(x)
|
||||
lora_out = lora(x)
|
||||
|
||||
assert torch.allclose(base_out, lora_out)
|
||||
|
||||
|
||||
def test_loralinear_forward_with_delta():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
base_out = base(x)
|
||||
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
lora_out = lora(x)
|
||||
assert not torch.allclose(base_out, lora_out)
|
||||
|
||||
|
||||
def test_loralinear_merge():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
x = torch.randn(2, 4)
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
out_before = lora(x).clone()
|
||||
lora.merge()
|
||||
out_after = lora(x)
|
||||
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
assert lora._merged
|
||||
assert not hasattr(lora, "lora_A")
|
||||
|
||||
|
||||
def test_loralinear_merge_is_idempotent():
|
||||
base = Linear(4, 4)
|
||||
with torch.no_grad():
|
||||
base.weight.zero_()
|
||||
|
||||
lora = LoRALinear(base, r=2, alpha=2)
|
||||
with torch.no_grad():
|
||||
lora.lora_B.fill_(1.0)
|
||||
|
||||
lora.merge()
|
||||
lora.merge()
|
||||
|
||||
|
||||
def test_inject_lora_default_target():
|
||||
model = _make_model()
|
||||
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
||||
|
||||
inject_lora(model, r=4, alpha=8)
|
||||
|
||||
lora_count = _get_lora_count(model)
|
||||
assert lora_count > 0
|
||||
assert lora_count < n_before
|
||||
|
||||
|
||||
def test_inject_lora_ffn():
|
||||
model = _make_model()
|
||||
from astrai.model.components.lora import TARGET_MODULES_FFN
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_inject_lora_returns_config():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=8, alpha=32)
|
||||
assert isinstance(cfg, LoRAConfig)
|
||||
assert cfg.r == 8
|
||||
assert cfg.alpha == 32
|
||||
|
||||
|
||||
def test_inject_lora_no_matching_targets_warns(caplog):
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
||||
assert "No LoRA layers injected" in caplog.text
|
||||
|
||||
|
||||
def test_inject_lora_preserves_base_output():
|
||||
model = _make_model()
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
|
||||
with torch.no_grad():
|
||||
out_before = model(x)["logits"].clone()
|
||||
|
||||
inject_lora(model, r=4, alpha=8)
|
||||
|
||||
with torch.no_grad():
|
||||
out_after = model(x)["logits"]
|
||||
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
|
||||
|
||||
def test_inject_lora_does_not_reinject():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
first_count = _get_lora_count(model)
|
||||
|
||||
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
||||
assert _get_lora_count(model) == first_count
|
||||
|
||||
|
||||
def test_inject_lora_adds_new_modules():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
first = _get_lora_count(model)
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
||||
assert _get_lora_count(model) > first
|
||||
|
||||
|
||||
def test_inject_lora_on_mla_model():
|
||||
model = _make_model(
|
||||
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
||||
)
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_inject_lora_on_moe_model():
|
||||
model = _make_model(
|
||||
ffn_type="moe",
|
||||
n_routed_experts=4,
|
||||
n_shared_experts=1,
|
||||
n_activated_experts=2,
|
||||
dim_ffn=32,
|
||||
)
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
||||
assert _get_lora_count(model) > 0
|
||||
|
||||
|
||||
def test_state_dict_key_format():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
sd = model.state_dict()
|
||||
assert "layers.0.attention.q_proj.weight" in sd
|
||||
assert "layers.0.attention.q_proj.lora_A" in sd
|
||||
assert "layers.0.attention.q_proj.lora_B" in sd
|
||||
|
||||
|
||||
def test_only_lora_params_trainable():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
||||
assert param.requires_grad, f"lora param should be trainable: {name}"
|
||||
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
||||
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
||||
|
||||
|
||||
def test_state_dict_after_inject_consistent_with_original():
|
||||
model = _make_model()
|
||||
sd_before = {k: v for k, v in model.state_dict().items()}
|
||||
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
sd_after = model.state_dict()
|
||||
|
||||
# original keys unchanged
|
||||
for k in sd_before:
|
||||
assert k in sd_after
|
||||
assert sd_before[k].shape == sd_after[k].shape
|
||||
|
||||
# new lora keys present
|
||||
lora_keys = [k for k in sd_after if "lora" in k]
|
||||
assert len(lora_keys) > 0
|
||||
|
||||
|
||||
def test_save_load_roundtrip():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
with torch.no_grad():
|
||||
out_src = model(x)["logits"].clone()
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
load_lora(model2, tmpdir)
|
||||
|
||||
with torch.no_grad():
|
||||
out_dst = model2(x)["logits"]
|
||||
|
||||
torch.testing.assert_close(out_src, out_dst)
|
||||
|
||||
|
||||
def test_save_after_merge_raises():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
merge_lora(model)
|
||||
|
||||
tmpdir2 = tempfile.mkdtemp()
|
||||
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
||||
save_lora(model, tmpdir2, cfg)
|
||||
|
||||
|
||||
def test_load_lora_on_already_injected():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
# load onto already-injected model
|
||||
load_lora(model2, tmpdir)
|
||||
assert _get_lora_count(model2) > 0
|
||||
|
||||
|
||||
def test_load_lora_mismatched_r_raises():
|
||||
model = _make_model()
|
||||
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
tmpdir = tempfile.mkdtemp()
|
||||
save_lora(model, tmpdir, cfg)
|
||||
|
||||
model2 = _make_model()
|
||||
model2.load_state_dict(model.state_dict(), strict=False)
|
||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with pytest.raises(RuntimeError, match="size mismatch"):
|
||||
load_lora(model2, tmpdir) # strict=False, only lora keys
|
||||
|
||||
|
||||
def test_merge_preserves_output():
|
||||
model = _make_model()
|
||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||
|
||||
with torch.no_grad():
|
||||
for m in model.modules():
|
||||
if isinstance(m, LoRALinear):
|
||||
m.lora_B.fill_(0.5)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 16))
|
||||
with torch.no_grad():
|
||||
out_before = model(x)["logits"].clone()
|
||||
|
||||
merge_lora(model)
|
||||
|
||||
with torch.no_grad():
|
||||
out_after = model(x)["logits"]
|
||||
torch.testing.assert_close(out_before, out_after)
|
||||
|
||||
|
||||
def test_merge_no_lora_warns(caplog):
|
||||
model = _make_model()
|
||||
merge_lora(model)
|
||||
assert "No LoRA layers to merge" in caplog.text
|
||||
|
||||
|
||||
def test_collect_lora_info():
|
||||
model = _make_model()
|
||||
info = _collect_lora_info(model)
|
||||
assert "q_proj" in info
|
||||
assert "o_proj" in info
|
||||
assert "q_proj" in info # each layer has one
|
||||
Loading…
Reference in New Issue