Compare commits
No commits in common. "432145a798847cd50431c0064fa4d1ec135ddadc" and "82a3f2626f56a8cb22bae3b2f03ed3dc0af4b7d7" have entirely different histories.
432145a798
...
82a3f2626f
|
|
@ -97,7 +97,7 @@ class TrainConfig(BaseConfig):
|
||||||
)
|
)
|
||||||
parallel_mode: str = field(
|
parallel_mode: str = field(
|
||||||
default="none",
|
default="none",
|
||||||
metadata={"help": "Parallel strategy: none, ddp, fsdp."},
|
metadata={"help": "Parallel strategy: none, ddp."},
|
||||||
)
|
)
|
||||||
start_method: str = field(
|
start_method: str = field(
|
||||||
default="spawn",
|
default="spawn",
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,6 @@ from astrai.model.automodel import AutoModel
|
||||||
from astrai.model.components.attention import GQA
|
from astrai.model.components.attention import GQA
|
||||||
from astrai.model.components.decoder_block import DecoderBlock
|
from astrai.model.components.decoder_block import DecoderBlock
|
||||||
from astrai.model.components.linear import Linear
|
from astrai.model.components.linear import Linear
|
||||||
from astrai.model.components.lora import (
|
|
||||||
LoRAConfig,
|
|
||||||
inject_lora,
|
|
||||||
load_lora,
|
|
||||||
merge_lora,
|
|
||||||
save_lora,
|
|
||||||
)
|
|
||||||
from astrai.model.components.mlp import MLP
|
from astrai.model.components.mlp import MLP
|
||||||
from astrai.model.components.norm import RMSNorm
|
from astrai.model.components.norm import RMSNorm
|
||||||
from astrai.model.encoder import EmbeddingEncoder
|
from astrai.model.encoder import EmbeddingEncoder
|
||||||
|
|
@ -25,10 +18,4 @@ __all__ = [
|
||||||
"AutoRegressiveLM",
|
"AutoRegressiveLM",
|
||||||
"EmbeddingEncoder",
|
"EmbeddingEncoder",
|
||||||
"AutoModel",
|
"AutoModel",
|
||||||
# LoRA
|
|
||||||
"LoRAConfig",
|
|
||||||
"inject_lora",
|
|
||||||
"merge_lora",
|
|
||||||
"save_lora",
|
|
||||||
"load_lora",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,192 +0,0 @@
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Set
|
|
||||||
|
|
||||||
import safetensors.torch as st
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
TARGET_MODULES_ATTN = {"q_proj", "k_proj", "v_proj", "o_proj"}
|
|
||||||
TARGET_MODULES_FFN = {"up", "gate", "down"}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LoRAConfig:
|
|
||||||
r: int = 16
|
|
||||||
alpha: int = 32
|
|
||||||
target_modules: tuple = ("q_proj", "v_proj")
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALinear(nn.Module):
|
|
||||||
def __init__(self, base: Linear, r: int = 16, alpha: int = 32):
|
|
||||||
super().__init__()
|
|
||||||
self.register_parameter("weight", base.weight)
|
|
||||||
self.weight.requires_grad_(False)
|
|
||||||
self.bias = base.bias
|
|
||||||
if self.bias is not None:
|
|
||||||
self.bias.requires_grad_(False)
|
|
||||||
|
|
||||||
self.r = r
|
|
||||||
self.scaling = alpha / r
|
|
||||||
self.lora_A = nn.Parameter(torch.randn(r, self.weight.shape[1]) / r)
|
|
||||||
self.lora_B = nn.Parameter(torch.zeros(self.weight.shape[0], r))
|
|
||||||
self._merged = False
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
out = F.linear(x, self.weight, self.bias)
|
|
||||||
if not self._merged:
|
|
||||||
out += (F.linear(x, self.lora_A) @ self.lora_B.T) * self.scaling
|
|
||||||
return out
|
|
||||||
|
|
||||||
def merge(self):
|
|
||||||
if self._merged:
|
|
||||||
return
|
|
||||||
self.weight.data += (self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
self._merged = True
|
|
||||||
del self.lora_A
|
|
||||||
del self.lora_B
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_lora_info(model: nn.Module) -> dict:
|
|
||||||
names = {}
|
|
||||||
for n, m in model.named_modules():
|
|
||||||
if isinstance(m, Linear):
|
|
||||||
_, _, child = n.rpartition(".")
|
|
||||||
names.setdefault(child, []).append(n)
|
|
||||||
return names
|
|
||||||
|
|
||||||
|
|
||||||
def _get_lora_count(model: nn.Module) -> int:
|
|
||||||
return sum(1 for m in model.modules() if isinstance(m, LoRALinear))
|
|
||||||
|
|
||||||
|
|
||||||
def inject_lora(
|
|
||||||
model: nn.Module,
|
|
||||||
r: int = 16,
|
|
||||||
alpha: int = 32,
|
|
||||||
target_modules: Optional[Set[str]] = None,
|
|
||||||
) -> LoRAConfig:
|
|
||||||
if target_modules is None:
|
|
||||||
target_modules = TARGET_MODULES_ATTN
|
|
||||||
|
|
||||||
available = _collect_lora_info(model)
|
|
||||||
injected = 0
|
|
||||||
|
|
||||||
for name, module in list(model.named_modules()):
|
|
||||||
if not isinstance(module, Linear):
|
|
||||||
continue
|
|
||||||
parent_name, _, child_name = name.rpartition(".")
|
|
||||||
if child_name not in target_modules:
|
|
||||||
continue
|
|
||||||
parent = model.get_submodule(parent_name) if parent_name else model
|
|
||||||
setattr(parent, child_name, LoRALinear(module, r=r, alpha=alpha))
|
|
||||||
injected += 1
|
|
||||||
|
|
||||||
if injected == 0:
|
|
||||||
logger.warning(
|
|
||||||
"No LoRA layers injected. Available Linear child names: %s. "
|
|
||||||
"target_modules: %s. Check model type and target_modules.",
|
|
||||||
sorted(available),
|
|
||||||
sorted(target_modules),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.info("LoRA injected: %d layers (r=%d, alpha=%d)", injected, r, alpha)
|
|
||||||
|
|
||||||
return LoRAConfig(r=r, alpha=alpha, target_modules=tuple(target_modules))
|
|
||||||
|
|
||||||
|
|
||||||
def merge_lora(model: nn.Module):
|
|
||||||
n = 0
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, LoRALinear):
|
|
||||||
module.merge()
|
|
||||||
n += 1
|
|
||||||
if n == 0:
|
|
||||||
logger.warning("No LoRA layers to merge.")
|
|
||||||
else:
|
|
||||||
logger.info("Merged %d LoRA layers", n)
|
|
||||||
|
|
||||||
|
|
||||||
def save_lora(model: nn.Module, save_dir: str, config: LoRAConfig):
|
|
||||||
lora_sd = {
|
|
||||||
k: v
|
|
||||||
for k, v in model.state_dict().items()
|
|
||||||
if k.endswith((".lora_A", ".lora_B"))
|
|
||||||
}
|
|
||||||
if not lora_sd:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LoRA parameters found in model. "
|
|
||||||
"The model may not have been injected or was already merged."
|
|
||||||
)
|
|
||||||
|
|
||||||
path = Path(save_dir)
|
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
|
||||||
st.save_file(lora_sd, str(path / "adapter_model.safetensors"))
|
|
||||||
with open(path / "adapter_config.json", "w") as f:
|
|
||||||
json.dump(asdict(config), f, indent=2)
|
|
||||||
logger.info("LoRA adapter saved to %s (%d keys)", save_dir, len(lora_sd))
|
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model: nn.Module, load_dir: str) -> LoRAConfig:
|
|
||||||
path = Path(load_dir)
|
|
||||||
with open(path / "adapter_config.json") as f:
|
|
||||||
raw = json.load(f)
|
|
||||||
config = LoRAConfig(
|
|
||||||
r=raw["r"], alpha=raw["alpha"], target_modules=tuple(raw["target_modules"])
|
|
||||||
)
|
|
||||||
|
|
||||||
existing = _get_lora_count(model)
|
|
||||||
if existing > 0:
|
|
||||||
logger.warning(
|
|
||||||
"Model already has %d LoRA layers. Skipping injection, "
|
|
||||||
"loading weights onto existing layers only.",
|
|
||||||
existing,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
inject_lora(
|
|
||||||
model,
|
|
||||||
r=config.r,
|
|
||||||
alpha=config.alpha,
|
|
||||||
target_modules=set(config.target_modules),
|
|
||||||
)
|
|
||||||
|
|
||||||
weights = st.load_file(str(path / "adapter_model.safetensors"))
|
|
||||||
try:
|
|
||||||
missing, unexpected = model.load_state_dict(weights, strict=False)
|
|
||||||
except RuntimeError as e:
|
|
||||||
msg = str(e)
|
|
||||||
if "size mismatch" in msg:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"LoRA weight shapes do not match the model. "
|
|
||||||
f"The adapter config (r={config.r}) may not match the injected layers. "
|
|
||||||
f"Original error: {msg}"
|
|
||||||
) from e
|
|
||||||
raise
|
|
||||||
|
|
||||||
injected = _get_lora_count(model)
|
|
||||||
if injected == 0:
|
|
||||||
raise RuntimeError(
|
|
||||||
"No LoRA layers found after loading. "
|
|
||||||
"Inject LoRA before calling load_lora, or check the adapter config."
|
|
||||||
)
|
|
||||||
|
|
||||||
if missing:
|
|
||||||
lora_missing = [k for k in missing if "lora" in k]
|
|
||||||
if lora_missing:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"LoRA weight keys not found in model: {lora_missing}. "
|
|
||||||
f"The adapter config (r={config.r}) may not match the model."
|
|
||||||
)
|
|
||||||
logger.debug("LoRA load: %d missing base-weight keys (expected)", len(missing))
|
|
||||||
if unexpected:
|
|
||||||
logger.warning("LoRA load: %d unexpected keys", len(unexpected))
|
|
||||||
|
|
||||||
logger.info("LoRA adapter loaded from %s", load_dir)
|
|
||||||
return config
|
|
||||||
|
|
@ -2,9 +2,7 @@ from astrai.parallel.executor import (
|
||||||
AccumOptimizer,
|
AccumOptimizer,
|
||||||
AccumScheduler,
|
AccumScheduler,
|
||||||
BaseExecutor,
|
BaseExecutor,
|
||||||
DDPExecutor,
|
|
||||||
ExecutorFactory,
|
ExecutorFactory,
|
||||||
FSDPExecutor,
|
|
||||||
GradientState,
|
GradientState,
|
||||||
NoneExecutor,
|
NoneExecutor,
|
||||||
)
|
)
|
||||||
|
|
@ -33,6 +31,4 @@ __all__ = [
|
||||||
"AccumOptimizer",
|
"AccumOptimizer",
|
||||||
"AccumScheduler",
|
"AccumScheduler",
|
||||||
"NoneExecutor",
|
"NoneExecutor",
|
||||||
"DDPExecutor",
|
|
||||||
"FSDPExecutor",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
@ -199,33 +198,3 @@ class DDPExecutor(BaseExecutor):
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module
|
return model.module
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@ExecutorFactory.register("fsdp")
|
|
||||||
class FSDPExecutor(BaseExecutor):
|
|
||||||
def __init__(self, grad_accum_steps: int = 1, **fsdp_kwargs):
|
|
||||||
super().__init__(grad_accum_steps=grad_accum_steps)
|
|
||||||
self._fsdp_kwargs = fsdp_kwargs
|
|
||||||
self._original_model: Optional[nn.Module] = None
|
|
||||||
|
|
||||||
def _prepare_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
if not self.use_distributed:
|
|
||||||
logger.warning("FSDP backend selected but world_size=1, model not wrapped")
|
|
||||||
return model
|
|
||||||
self._original_model = model
|
|
||||||
device_id = torch.device("cuda", get_rank())
|
|
||||||
model = FSDP(model, device_id=device_id, **self._fsdp_kwargs)
|
|
||||||
logger.info("Model wrapped with FSDP (world_size=%d)", get_world_size())
|
|
||||||
return model
|
|
||||||
|
|
||||||
def _no_sync(self, model: nn.Module):
|
|
||||||
if isinstance(model, FSDP):
|
|
||||||
return model.no_sync()
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
def unwrap_model(self, model: nn.Module) -> nn.Module:
|
|
||||||
if self._original_model is not None:
|
|
||||||
return self._original_model
|
|
||||||
if isinstance(model, FSDP):
|
|
||||||
return model._fsdp_wrapped_module
|
|
||||||
return model
|
|
||||||
|
|
|
||||||
|
|
@ -8,17 +8,15 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
|
"""Unwrap DDP wrapper if present to get the original model."""
|
||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
return model.module
|
return model.module
|
||||||
if isinstance(model, FSDP):
|
|
||||||
return model._fsdp_wrapped_module
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,355 +0,0 @@
|
||||||
import tempfile
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from astrai.config.model_config import AutoRegressiveLMConfig
|
|
||||||
from astrai.model import AutoRegressiveLM
|
|
||||||
from astrai.model.components.lora import (
|
|
||||||
LoRAConfig,
|
|
||||||
LoRALinear,
|
|
||||||
_get_lora_count,
|
|
||||||
_collect_lora_info,
|
|
||||||
inject_lora,
|
|
||||||
load_lora,
|
|
||||||
merge_lora,
|
|
||||||
save_lora,
|
|
||||||
)
|
|
||||||
from astrai.model.components.linear import Linear
|
|
||||||
|
|
||||||
MODEL_KWARGS = dict(
|
|
||||||
vocab_size=1000,
|
|
||||||
dim=64,
|
|
||||||
n_heads=4,
|
|
||||||
n_kv_heads=2,
|
|
||||||
dim_ffn=128,
|
|
||||||
n_layers=2,
|
|
||||||
max_len=32,
|
|
||||||
norm_eps=1e-5,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_model(**kwargs):
|
|
||||||
kw = {**MODEL_KWARGS, **kwargs}
|
|
||||||
config = AutoRegressiveLMConfig(**kw)
|
|
||||||
model = AutoRegressiveLM(config)
|
|
||||||
model.eval()
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_init():
|
|
||||||
base = Linear(64, 128)
|
|
||||||
lora = LoRALinear(base, r=8, alpha=16)
|
|
||||||
|
|
||||||
assert lora.weight is base.weight
|
|
||||||
assert not lora.weight.requires_grad
|
|
||||||
assert lora.lora_A.shape == (8, 64)
|
|
||||||
assert lora.lora_B.shape == (128, 8)
|
|
||||||
assert lora.scaling == 2.0
|
|
||||||
assert not lora._merged
|
|
||||||
assert lora.lora_A.requires_grad
|
|
||||||
assert lora.lora_B.requires_grad
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_forward_init_zero_delta():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
base_out = base(x)
|
|
||||||
lora_out = lora(x)
|
|
||||||
|
|
||||||
assert torch.allclose(base_out, lora_out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_forward_with_delta():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
base_out = base(x)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
lora_out = lora(x)
|
|
||||||
assert not torch.allclose(base_out, lora_out)
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_merge():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
x = torch.randn(2, 4)
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
out_before = lora(x).clone()
|
|
||||||
lora.merge()
|
|
||||||
out_after = lora(x)
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
assert lora._merged
|
|
||||||
assert not hasattr(lora, "lora_A")
|
|
||||||
|
|
||||||
|
|
||||||
def test_loralinear_merge_is_idempotent():
|
|
||||||
base = Linear(4, 4)
|
|
||||||
with torch.no_grad():
|
|
||||||
base.weight.zero_()
|
|
||||||
|
|
||||||
lora = LoRALinear(base, r=2, alpha=2)
|
|
||||||
with torch.no_grad():
|
|
||||||
lora.lora_B.fill_(1.0)
|
|
||||||
|
|
||||||
lora.merge()
|
|
||||||
lora.merge()
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_default_target():
|
|
||||||
model = _make_model()
|
|
||||||
n_before = sum(1 for m in model.modules() if isinstance(m, Linear))
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8)
|
|
||||||
|
|
||||||
lora_count = _get_lora_count(model)
|
|
||||||
assert lora_count > 0
|
|
||||||
assert lora_count < n_before
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_ffn():
|
|
||||||
model = _make_model()
|
|
||||||
from astrai.model.components.lora import TARGET_MODULES_FFN
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules=TARGET_MODULES_FFN)
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_returns_config():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=8, alpha=32)
|
|
||||||
assert isinstance(cfg, LoRAConfig)
|
|
||||||
assert cfg.r == 8
|
|
||||||
assert cfg.alpha == 32
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_no_matching_targets_warns(caplog):
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"nonexistent"})
|
|
||||||
assert "No LoRA layers injected" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_preserves_base_output():
|
|
||||||
model = _make_model()
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_before = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_after = model(x)["logits"]
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_does_not_reinject():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
first_count = _get_lora_count(model)
|
|
||||||
|
|
||||||
inject_lora(model, r=2, alpha=4, target_modules={"q_proj"})
|
|
||||||
assert _get_lora_count(model) == first_count
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_adds_new_modules():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
first = _get_lora_count(model)
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"v_proj"})
|
|
||||||
assert _get_lora_count(model) > first
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_on_mla_model():
|
|
||||||
model = _make_model(
|
|
||||||
attn_type="mla", kv_lora_rank=16, qk_nope_head_dim=16, qk_rope_head_dim=16
|
|
||||||
)
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "o_proj"})
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_inject_lora_on_moe_model():
|
|
||||||
model = _make_model(
|
|
||||||
ffn_type="moe",
|
|
||||||
n_routed_experts=4,
|
|
||||||
n_shared_experts=1,
|
|
||||||
n_activated_experts=2,
|
|
||||||
dim_ffn=32,
|
|
||||||
)
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"up", "gate", "down"})
|
|
||||||
assert _get_lora_count(model) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_key_format():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
sd = model.state_dict()
|
|
||||||
assert "layers.0.attention.q_proj.weight" in sd
|
|
||||||
assert "layers.0.attention.q_proj.lora_A" in sd
|
|
||||||
assert "layers.0.attention.q_proj.lora_B" in sd
|
|
||||||
|
|
||||||
|
|
||||||
def test_only_lora_params_trainable():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj", "v_proj"})
|
|
||||||
|
|
||||||
for name, param in model.named_parameters():
|
|
||||||
if isinstance(name.split(".")[-1], str) and "lora" in name:
|
|
||||||
assert param.requires_grad, f"lora param should be trainable: {name}"
|
|
||||||
elif any(name.endswith(f".{t}.weight") for t in ("q_proj", "v_proj")):
|
|
||||||
assert not param.requires_grad, f"injected weight should be frozen: {name}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_state_dict_after_inject_consistent_with_original():
|
|
||||||
model = _make_model()
|
|
||||||
sd_before = {k: v for k, v in model.state_dict().items()}
|
|
||||||
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
sd_after = model.state_dict()
|
|
||||||
|
|
||||||
# original keys unchanged
|
|
||||||
for k in sd_before:
|
|
||||||
assert k in sd_after
|
|
||||||
assert sd_before[k].shape == sd_after[k].shape
|
|
||||||
|
|
||||||
# new lora keys present
|
|
||||||
lora_keys = [k for k in sd_after if "lora" in k]
|
|
||||||
assert len(lora_keys) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_load_roundtrip():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
with torch.no_grad():
|
|
||||||
out_src = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
load_lora(model2, tmpdir)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_dst = model2(x)["logits"]
|
|
||||||
|
|
||||||
torch.testing.assert_close(out_src, out_dst)
|
|
||||||
|
|
||||||
|
|
||||||
def test_save_after_merge_raises():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
merge_lora(model)
|
|
||||||
|
|
||||||
tmpdir2 = tempfile.mkdtemp()
|
|
||||||
with pytest.raises(RuntimeError, match="No LoRA parameters"):
|
|
||||||
save_lora(model, tmpdir2, cfg)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_lora_on_already_injected():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, LoRAConfig(r=4, alpha=8, target_modules=("q_proj",)))
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
# load onto already-injected model
|
|
||||||
load_lora(model2, tmpdir)
|
|
||||||
assert _get_lora_count(model2) > 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_lora_mismatched_r_raises():
|
|
||||||
model = _make_model()
|
|
||||||
cfg = inject_lora(model, r=8, alpha=16, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
tmpdir = tempfile.mkdtemp()
|
|
||||||
save_lora(model, tmpdir, cfg)
|
|
||||||
|
|
||||||
model2 = _make_model()
|
|
||||||
model2.load_state_dict(model.state_dict(), strict=False)
|
|
||||||
inject_lora(model2, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with pytest.raises(RuntimeError, match="size mismatch"):
|
|
||||||
load_lora(model2, tmpdir) # strict=False, only lora keys
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_preserves_output():
|
|
||||||
model = _make_model()
|
|
||||||
inject_lora(model, r=4, alpha=8, target_modules={"q_proj"})
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for m in model.modules():
|
|
||||||
if isinstance(m, LoRALinear):
|
|
||||||
m.lora_B.fill_(0.5)
|
|
||||||
|
|
||||||
x = torch.randint(0, 1000, (2, 16))
|
|
||||||
with torch.no_grad():
|
|
||||||
out_before = model(x)["logits"].clone()
|
|
||||||
|
|
||||||
merge_lora(model)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
out_after = model(x)["logits"]
|
|
||||||
torch.testing.assert_close(out_before, out_after)
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_no_lora_warns(caplog):
|
|
||||||
model = _make_model()
|
|
||||||
merge_lora(model)
|
|
||||||
assert "No LoRA layers to merge" in caplog.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_collect_lora_info():
|
|
||||||
model = _make_model()
|
|
||||||
info = _collect_lora_info(model)
|
|
||||||
assert "q_proj" in info
|
|
||||||
assert "o_proj" in info
|
|
||||||
assert "q_proj" in info # each layer has one
|
|
||||||
Loading…
Reference in New Issue