feat: 新增LoRA微调模块
- LoRALinear基于register_parameter托管base weight,state_dict路径不变 - inject_lora/merge_lora/save_lora/load_lora完备封装 - 24个单元测试覆盖注入、合并、存取、边界场景
This commit is contained in:
parent
7df6eb9211
commit
a4688021bf
|
|
@ -7,6 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.config.base import BaseConfig
|
from astrai.config.base import BaseConfig
|
||||||
|
from astrai.model.components.lora import LoRAConfig
|
||||||
|
|
||||||
|
|
||||||
def required(**kw):
|
def required(**kw):
|
||||||
|
|
@ -56,6 +57,12 @@ class TrainConfig(BaseConfig):
|
||||||
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
default=5000, metadata={"help": "Number of iterations between checkpoints."}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# lora setting
|
||||||
|
lora: Optional[LoRAConfig] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "LoRA config. None means full fine-tuning."},
|
||||||
|
)
|
||||||
|
|
||||||
# metric setting
|
# metric setting
|
||||||
log_dir: str = field(
|
log_dir: str = field(
|
||||||
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
default="./checkpoint/logs", metadata={"help": "Directory for metric logs."}
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,13 @@ from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.components.attention import GQA
|
from astrai.model.components.attention import GQA
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.model.components.linear import Linear
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.lora import (
|
||||||
|
LoRAConfig,
|
||||||
|
inject_lora,
|
||||||
|
load_lora,
|
||||||
|
merge_lora,
|
||||||
|
save_lora,
|
||||||
|
)
|
||||||
from astrai.model.components.mlp import MLP
|
from astrai.model.components.mlp import MLP
|
||||||
from astrai.model.components.norm import RMSNorm
|
from astrai.model.components.norm import RMSNorm
|
||||||
from astrai.model.encoder import EmbeddingEncoder
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
|
@ -18,4 +25,10 @@ __all__ = [
|
||||||
"AutoRegressiveLM",
|
"AutoRegressiveLM",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
|
# LoRA
|
||||||
|
"LoRAConfig",
|
||||||
|
"inject_lora",
|
||||||
|
"merge_lora",
|
||||||
|
"save_lora",
|
||||||
|
"load_lora",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional, Set
|
||||||
|
|
||||||
|
import safetensors.torch as st
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
||||||
|
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoRAConfig:
|
||||||
|
r: int = 16
|
||||||
|
alpha: int = 32
|
||||||
|
target_modules: tuple = ("q_proj", "v_proj")
|
||||||
|
|
||||||
|
|
||||||
|
class LoRALinear(nn.Module):
|
||||||
|
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
||||||
|
super().__init__()
|
||||||
|
self.register_parameter("weight", base.weight)
|
||||||
|
self.weight.requires_grad_(False)
|
||||||
|
self.bias = base.bias
|
||||||
|
if self.bias is not None:
|
||||||
|
self.bias.requires_grad_(False)
|
||||||
|
|
||||||
|
self.r = r
|
||||||
|
self.scaling = alpha / r
|
||||||
|
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
||||||
|
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
||||||
|
self._merged = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = F.linear(x, self.weight, self.bias)
|
||||||
|
if not self._merged:
|
||||||
|
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
||||||
|
return out
|
||||||
|
|
||||||
|
def merge(self):
|
||||||
|
if self._merged:
|
||||||
|
return
|
||||||
|
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
||||||
|
self._merged = True
|
||||||
|
del self.lora_A
|
||||||
|
del self.lora_B
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_lora_info(model: nn.Module) -> dict:
|
||||||
|
names = {}
|
||||||
|
for n, m in model.named_modules():
|
||||||
|
if isinstance(m, Linear):
|
||||||
|
_, _, child = n.rpartition(".")
|
||||||
|
names.setdefault(child, []).append(n)
|
||||||
|
return names
|
||||||
|
|
||||||
|
|
||||||
|
def _get_lora_count(model: nn.Module) -> int:
|
||||||
|
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
||||||
|
|
||||||
|
|
||||||
|
def inject_lora(
|
||||||
|
model: nn.Module,
|
||||||
|
r: int = 16,
|
||||||
|
alpha: int = 32,
|
||||||
|
target_modules: Optional[Set[str]] = None,
|
||||||
|
) -> LoRAConfig:
|
||||||
|
if target_modules is None:
|
||||||
|
target_modules = TARGET_MODULES_ATTN
|
||||||
|
|
||||||
|
available = _collect_lora_info(model)
|
||||||
|
injected = 0
|
||||||
|
|
||||||
|
for name, module in list(model.named_modules()):
|
||||||
|
if not isinstance(module, Linear):
|
||||||
|
continue
|
||||||
|
parent_name, _, child_name = name.rpartition(".")
|
||||||
|
if child_name not in target_modules:
|
||||||
|
continue
|
||||||
|
parent = model.get_submodule(parent_name) if parent_name else model
|
||||||
|
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
||||||
|
injected += 1
|
||||||
|
|
||||||
|
if injected == 0:
|
||||||
|
logger.warning(
|
||||||
|
"No LoRA layers injected. Available Linear child names: %s. "
|
||||||
|
"target_modules: %s. Check model type and target_modules.",
|
||||||
|
sorted(available),
|
||||||
|
sorted(target_modules),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
||||||
|
|
||||||
|
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_lora(model: nn.Module):
|
||||||
|
n = 0
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, LoRALinear):
|
||||||
|
module.merge()
|
||||||
|
n += 1
|
||||||
|
if n == 0:
|
||||||
|
logger.warning("No LoRA layers to merge.")
|
||||||
|
else:
|
||||||
|
logger.info("Merged %d LoRA layers", n)
|
||||||
|
|
||||||
|
|
||||||
|
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
||||||
|
lora_sd = {
|
||||||
|
k: v
|
||||||
|
for k, v in model.state_dict().items()
|
||||||
|
if k.endswith((".lora_A", ".lora_B"))
|
||||||
|
}
|
||||||
|
if not lora_sd:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA parameters found in model. "
|
||||||
|
"The model may not have been injected or was already merged."
|
||||||
|
)
|
||||||
|
|
||||||
|
path = Path(save_dir)
|
||||||
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
|
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
||||||
|
with open(path / "adapter_config.json", "w") as f:
|
||||||
|
json.dump(asdict(config), f, indent=2)
|
||||||
|
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
||||||
|
|
||||||
|
|
||||||
|
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
||||||
|
path = Path(load_dir)
|
||||||
|
with open(path / "adapter_config.json") as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
config = LoRAConfig(
|
||||||
|
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
||||||
|
)
|
||||||
|
|
||||||
|
existing = _get_lora_count(model)
|
||||||
|
if existing > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Model already has %d LoRA layers. Skipping injection, "
|
||||||
|
"loading weights onto existing layers only.",
|
||||||
|
existing,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inject_lora(
|
||||||
|
model,
|
||||||
|
r=config.r,
|
||||||
|
alpha=config.alpha,
|
||||||
|
target_modules=set(config.target_modules),
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
||||||
|
try:
|
||||||
|
missing, unexpected = model.load_state_dict(weights, strict=False)
|
||||||
|
except RuntimeError as e:
|
||||||
|
msg = str(e)
|
||||||
|
if "size mismatch" in msg:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight shapes do not match the model. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the injected layers. "
|
||||||
|
f"Original error: {msg}"
|
||||||
|
) from e
|
||||||
|
raise
|
||||||
|
|
||||||
|
injected = _get_lora_count(model)
|
||||||
|
if injected == 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
"No LoRA layers found after loading. "
|
||||||
|
"Inject LoRA before calling load_lora, or check the adapter config."
|
||||||
|
)
|
||||||
|
|
||||||
|
if missing:
|
||||||
|
lora_missing = [k for k in missing if "lora" in k]
|
||||||
|
if lora_missing:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LoRA weight keys not found in model: {lora_missing}. "
|
||||||
|
f"The adapter config (r={config.r}) may not match the model."
|
||||||
|
)
|
||||||
|
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
||||||
|
if unexpected:
|
||||||
|
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
||||||
|
|
||||||
|
logger.info("LoRA adapter loaded from %s", load_dir)
|
||||||
|
return config
|
||||||
|
|
@ -6,6 +6,7 @@ from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from astrai.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.dataset import ResumableDistributedSampler
|
from astrai.dataset import ResumableDistributedSampler
|
||||||
|
from astrai.model.components.lora import inject_lora
|
||||||
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
from astrai.parallel.executor import BaseExecutor, ExecutorFactory
|
||||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||||
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
from astrai.protocols import OptimizerProtocol, SchedulerProtocol
|
||||||
|
|
@ -77,6 +78,14 @@ class TrainContextBuilder:
|
||||||
state_dict=context.model.state_dict(),
|
state_dict=context.model.state_dict(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.lora is not None:
|
||||||
|
inject_lora(
|
||||||
|
context.model,
|
||||||
|
r=cfg.lora.r,
|
||||||
|
alpha=cfg.lora.alpha,
|
||||||
|
target_modules=set(cfg.lora.target_modules),
|
||||||
|
)
|
||||||
|
|
||||||
context.optimizer = cfg.optimizer_fn(context.model)
|
context.optimizer = cfg.optimizer_fn(context.model)
|
||||||
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
context.scheduler = cfg.scheduler_fn(context.optimizer)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,355 @@
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from astrai.config.model_config import AutoRegressiveLMConfig
|
||||||
|
from astrai.model import AutoRegressiveLM
|
||||||
|
from astrai.model.components.linear import Linear
|
||||||
|
from astrai.model.components.lora import (
|
||||||
|
LoRAConfig,
|
||||||
|
LoRALinear,
|
||||||
|
_collect_lora_info,
|
||||||
|
_get_lora_count,
|
||||||
|
inject_lora,
|
||||||
|
load_lora,
|
||||||
|
merge_lora,
|
||||||
|
save_lora,
|
||||||
|
)
|
||||||
|
|
||||||
|
MODEL_KWARGS = dict(
|
||||||
|
vocab_size=1000,
|
||||||
|
dim=64,
|
||||||
|
n_heads=4,
|
||||||
|
n_kv_heads=2,
|
||||||
|
dim_ffn=128,
|
||||||
|
n_layers=2,
|
||||||
|
max_len=32,
|
||||||
|
norm_eps=1e-5,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_model(**kwargs):
|
||||||
|
kw = {**MODEL_KWARGS, **kwargs}
|
||||||
|
config = AutoRegressiveLMConfig(**kw)
|
||||||
|
model = AutoRegressiveLM(config)
|
||||||
|
model.eval()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_init():
|
||||||
|
base = Linear(64, 128)
|
||||||
|
lora = LoRALinear(base, r=8, alpha=16)
|
||||||
|
|
||||||
|
assert lora.weight is base.weight
|
||||||
|
assert not lora.weight.requires_grad
|
||||||
|
assert lora.lora_A.shape == (8, 64)
|
||||||
|
assert lora.lora_B.shape == (128, 8)
|
||||||
|
assert lora.scaling == 2.0
|
||||||
|
assert not lora._merged
|
||||||
|
assert lora.lora_A.requires_grad
|
||||||
|
assert lora.lora_B.requires_grad
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_forward_init_zero_delta():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
base_out = base(x)
|
||||||
|
lora_out = lora(x)
|
||||||
|
|
||||||
|
assert torch.allclose(base_out, lora_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_forward_with_delta():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
base_out = base(x)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
lora_out = lora(x)
|
||||||
|
assert not torch.allclose(base_out, lora_out)
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_merge():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
x = torch.randn(2, 4)
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
out_before = lora(x).clone()
|
||||||
|
lora.merge()
|
||||||
|
out_after = lora(x)
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
assert lora._merged
|
||||||
|
assert not hasattr(lora, "lora_A")
|
||||||
|
|
||||||
|
|
||||||
|
def test_loralinear_merge_is_idempotent():
|
||||||
|
base = Linear(4, 4)
|
||||||
|
with torch.no_grad():
|
||||||
|
base.weight.zero_()
|
||||||
|
|
||||||
|
lora = LoRALinear(base, r=2, alpha=2)
|
||||||
|
with torch.no_grad():
|
||||||
|
lora.lora_B.fill_(1.0)
|
||||||
|
|
||||||
|
lora.merge()
|
||||||
|
lora.merge()
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_default_target():
|
||||||
|
model = _make_model()
|
||||||
|
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8)
|
||||||
|
|
||||||
|
lora_count = _get_lora_count(model)
|
||||||
|
assert lora_count > 0
|
||||||
|
assert lora_count < n_before
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_ffn():
|
||||||
|
model = _make_model()
|
||||||
|
from astrai.model.components.lora import TARGET_MODULES_FFN
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_returns_config():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=8, alpha=32)
|
||||||
|
assert isinstance(cfg, LoRAConfig)
|
||||||
|
assert cfg.r == 8
|
||||||
|
assert cfg.alpha == 32
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_no_matching_targets_warns(caplog):
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
||||||
|
assert "No LoRA layers injected" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_preserves_base_output():
|
||||||
|
model = _make_model()
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_before = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_after = model(x)["logits"]
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_does_not_reinject():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
first_count = _get_lora_count(model)
|
||||||
|
|
||||||
|
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
||||||
|
assert _get_lora_count(model) == first_count
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_adds_new_modules():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
first = _get_lora_count(model)
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
||||||
|
assert _get_lora_count(model) > first
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_on_mla_model():
|
||||||
|
model = _make_model(
|
||||||
|
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
||||||
|
)
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_lora_on_moe_model():
|
||||||
|
model = _make_model(
|
||||||
|
ffn_type="moe",
|
||||||
|
n_routed_experts=4,
|
||||||
|
n_shared_experts=1,
|
||||||
|
n_activated_experts=2,
|
||||||
|
dim_ffn=32,
|
||||||
|
)
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
||||||
|
assert _get_lora_count(model) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_key_format():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
sd = model.state_dict()
|
||||||
|
assert "layers.0.attention.q_proj.weight" in sd
|
||||||
|
assert "layers.0.attention.q_proj.lora_A" in sd
|
||||||
|
assert "layers.0.attention.q_proj.lora_B" in sd
|
||||||
|
|
||||||
|
|
||||||
|
def test_only_lora_params_trainable():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
||||||
|
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
||||||
|
assert param.requires_grad, f"lora param should be trainable: {name}"
|
||||||
|
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
||||||
|
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_state_dict_after_inject_consistent_with_original():
|
||||||
|
model = _make_model()
|
||||||
|
sd_before = {k: v for k, v in model.state_dict().items()}
|
||||||
|
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
sd_after = model.state_dict()
|
||||||
|
|
||||||
|
# original keys unchanged
|
||||||
|
for k in sd_before:
|
||||||
|
assert k in sd_after
|
||||||
|
assert sd_before[k].shape == sd_after[k].shape
|
||||||
|
|
||||||
|
# new lora keys present
|
||||||
|
lora_keys = [k for k in sd_after if "lora" in k]
|
||||||
|
assert len(lora_keys) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_load_roundtrip():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
with torch.no_grad():
|
||||||
|
out_src = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
load_lora(model2, tmpdir)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_dst = model2(x)["logits"]
|
||||||
|
|
||||||
|
torch.testing.assert_close(out_src, out_dst)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_after_merge_raises():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
merge_lora(model)
|
||||||
|
|
||||||
|
tmpdir2 = tempfile.mkdtemp()
|
||||||
|
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
||||||
|
save_lora(model, tmpdir2, cfg)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lora_on_already_injected():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
# load onto already-injected model
|
||||||
|
load_lora(model2, tmpdir)
|
||||||
|
assert _get_lora_count(model2) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_lora_mismatched_r_raises():
|
||||||
|
model = _make_model()
|
||||||
|
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
tmpdir = tempfile.mkdtemp()
|
||||||
|
save_lora(model, tmpdir, cfg)
|
||||||
|
|
||||||
|
model2 = _make_model()
|
||||||
|
model2.load_state_dict(model.state_dict(), strict=False)
|
||||||
|
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="size mismatch"):
|
||||||
|
load_lora(model2, tmpdir) # strict=False, only lora keys
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_preserves_output():
|
||||||
|
model = _make_model()
|
||||||
|
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for m in model.modules():
|
||||||
|
if isinstance(m, LoRALinear):
|
||||||
|
m.lora_B.fill_(0.5)
|
||||||
|
|
||||||
|
x = torch.randint(0, 1000, (2, 16))
|
||||||
|
with torch.no_grad():
|
||||||
|
out_before = model(x)["logits"].clone()
|
||||||
|
|
||||||
|
merge_lora(model)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
out_after = model(x)["logits"]
|
||||||
|
torch.testing.assert_close(out_before, out_after)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_no_lora_warns(caplog):
|
||||||
|
model = _make_model()
|
||||||
|
merge_lora(model)
|
||||||
|
assert "No LoRA layers to merge" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_lora_info():
|
||||||
|
model = _make_model()
|
||||||
|
info = _collect_lora_info(model)
|
||||||
|
assert "q_proj" in info
|
||||||
|
assert "o_proj" in info
|
||||||
|
assert "q_proj" in info # each layer has one
|
||||||
Loading…
Reference in New Issue