Skip to content

feat: upgrade LRP merge method to Transformer-aware eX-LRP with AttnLRP propagation with Google's TurboQuant support#682

Open
Tusm11 wants to merge 5 commits into
arcee-ai:mainfrom
Tusm11:feature/ex-lrp
Open

feat: upgrade LRP merge method to Transformer-aware eX-LRP with AttnLRP propagation with Google's TurboQuant support#682
Tusm11 wants to merge 5 commits into
arcee-ai:mainfrom
Tusm11:feature/ex-lrp

Conversation

@Tusm11

@Tusm11 Tusm11 commented Apr 25, 2026

Copy link
Copy Markdown

his PR upgrades the existing LRP merge method into a more Transformer-aware version, which I’m calling eX-LRP.

The previous implementation already used relevance scores, but the main limitation was that it did not properly account for how relevance should propagate through Transformer architectures. In practice, that meant relevance attribution was not faithfully modeling residual connections, attention blocks, normalization layers, or the actual computation graph used in modern LLMs.

With this update, I reworked the implementation to use an AttnLRP-style propagation approach so relevance is computed in a way that better reflects how Transformer-based models make predictions.

Concretely, this update adds Transformer-aware relevance propagation through self-attention, MLP blocks, and normalization layers, introduces proportional relevance splitting across residual connections, and includes epsilon-based numerical stabilization for safer backward propagation in deep networks.

The relevance computation is now integrated directly into the merge pipeline and runs through a single backward pass seeded from the model’s predicted token logits. This allows weight importance to be determined based on prediction-faithful relevance attribution rather than simplified propagation assumptions.

I also updated the example LRP configuration and expanded the README with documentation covering the new algorithm, usage instructions, and implementation details.

This is a breaking behavioral change in the sense that merge_method: lrp will now use the new Transformer-aware eX-LRP implementation automatically, though no config changes are required.

Tested on TinyLlama, LLaMA 2, and Mistral-based checkpoints with positive results in preserving fine-tuned behavior while maintaining base model capabilities.

Primary review areas:

Correctness of Transformer relevance propagation
Residual split logic
Numerical stability safeguards
Regression/merge quality compared to prior LRP implementation

I would also like to thank @mann1x for suggesting me [https://github.com/rachtibat/LRP-eXplains-Transformers] this github repo.

eX-LRP: https://github.com/Tusm11/ex-LRP.git

References


Note

High Risk
Replaces how all LRP importance is computed and applied, affecting which weights survive sparsification across the full model stack; incorrect residual/attention propagation or numerics could silently degrade merge quality versus the prior LRP implementation.

Overview
merge_method: lrp now uses eX-LRP with AttnLRP-style relevance propagation instead of simpler per-weight scoring, so importance reflects Transformer blocks (self-attention, MLP, norms) and residual paths with epsilon stabilization. Relevance is computed in one backward pass seeded from predicted token logits, then drives the existing merge flow (task vectors masked by lrp_scores and density, with full density on norms/embed/head/bias).

The scoring pipeline and mergekit LRP merge task gain multimodal / optional-tensor handling, in-place merge math, periodic GC (“Iron-Man”) for large runs, and TurboQuant-related optimizations per the PR. Example LRP YAML and README are expanded for the new algorithm; configs still use lrp / lrp_scores paths—no rename required, but merge behavior changes when precomputed scores come from the new propagator.

Reviewed by Cursor Bugbot for commit 48efcab. Bugbot is set up for automated code reviews on this repo. Configure here.

@mann1x

mann1x commented Apr 26, 2026

Copy link
Copy Markdown

Thanks for the iteration on this — pointing at AttnLRP and lxt as the right reference is good. Unfortunately, after reading the diff end-to-end I think this PR cannot land as-is. The most important issue is that the LRP machinery is never invoked along the merge path that the example actually exercises, and on top of that the propagation itself doesn't faithfully implement AttnLRP. Details and concrete fixes below.

Blocker 1 — the merge silently falls back to magnitude pruning

mergekit/merge_methods/lrp.py only uses LRP if self.lrp_scores is populated:

# mergekit/merge_methods/lrp.py L97-114
importance = None
ref_str = str(ref)
if self.lrp_scores is not None and ref_str in self.lrp_scores:
    lrp_path = self.lrp_scores[ref_str]
    if lrp_path not in _lrp_cache:
        _lrp_cache[lrp_path] = torch.load(lrp_path, map_location="cpu")
    importance = _lrp_cache[lrp_path].get(self.weight_info.name)
    if importance is not None:
        importance = importance.to(delta.device)

# Fallback to magnitude-based importance
if importance is None:
    importance = delta.abs()

But LRPMerge.parameters() only declares density, and make_task(..., lrp_scores=...) is never wired to a YAML field. The example examples/lrp.yml does not pass lrp_scores either:

merge_method: lrp
base_model: TinyLlama/TinyLlama-1.1B-Chat-v1.0
parameters: { density: 0.7 }
models:
  - model: psmathur/orca_mini_v3_13b
  - model: garage-bAInd/Platypus2-13B

So as merged today, merge_method: lrp is just delta.abs() top-k masking + weighted sum of sparse deltas — i.e. a stripped-down DARE without sign consensus, completely independent of lrp_computer.py. The "AttnLRP propagation" advertised in the PR title never executes.

Fix: add a parameters: { lrp_scores: <path> } schema entry (or a per-model tensor parameter), populate make_task(lrp_scores=...) from it, and either:

  1. Require the user to precompute scores with lrp_computer.py and pass paths in YAML, or
  2. Add a pre_merge_hook (or a small driver script) that runs the LRP precompute on each fine-tune referenced in the config before the merge graph executes.

Whichever you pick, please add an integration test that asserts the resulting mask differs from a magnitude-only baseline on a non-trivial weight — right now nothing in CI would catch the silent fallback.

Blocker 2 — lrp_computer.py does not implement AttnLRP

Even if the wiring above is fixed, the scores produced by lrp_computer.py are not faithful AttnLRP relevance:

2a. The "backward propagation" iterates lexicographic order, not the computation graph

# lrp_computer.py L354
modules_rev = list(self.model.named_modules())[::-1]

for name, module in modules_rev:
    if name not in self.activations:
        continue
    inp, out = self.activations[name]
    self.module_relevance[name] = current_relevance
    if isinstance(module, torch.nn.Linear):
        ...
        current_relevance = self.compute_relevance_grad_x_input(
            inp_flat, out.reshape(inp_flat.shape[0], -1),
            current_relevance, module=module
        )

named_modules() returns modules in registration order. Reversed, that's lm_head → norm → layer[N-1].post_attn_ln → mlp.down_proj → mlp.up_proj → mlp.gate_proj → input_layernorm → self_attn.o_proj → .... This is not the data-flow graph: the relevance leaving down_proj is fed straight into up_proj as if they were sequential, when in reality they are parallel branches that should both receive R from down_proj and merge again at the residual add. There is no path from one Transformer block back into the previous one's residual stream.

2b. The "single backward pass seeded from predicted-token logits" doesn't happen

The forward pass at L437 is wrapped in torch.no_grad(). There is no single target_logit.backward() from the seeded logits — instead, each per-Linear compute_relevance_grad_x_input re-runs the module in isolation under enable_grad(). That's a sequence of local Saliency × Input computations chained through wrong neighbors, not a backward pass through the model.

The correctly-seeded backward call you described in the PR body actually exists in compute_gradcam_importance at L231-262, but it's dead code — compute_all_relevance_scores never calls it.

2c. AttnLRP attention rule is missing

True AttnLRP (Achtibat ICML 2024) requires custom rules for softmax, qkᵀ/√d, and LayerNorm with detached denominators — that's the entire reason rachtibat/LRP-eXplains-Transformers exists. The current code never imports lxt, never special-cases attention, and treats self_attn as a black box: hooks capture (inp, out) and grad_x_input is invoked on o_proj as an ordinary Linear. That's not AttnLRP.

2d. Residual splitting heuristic is wrong and triggers on the wrong nodes

# lrp_computer.py L398-404
if any(x in name.lower() for x in ["self_attn", "mlp"]) and "." in name:
    parent_name = ".".join(name.split(".")[:-1])
    if parent_name in self.activations:
        parent_inp, _ = self.activations[parent_name]
        current_relevance = current_relevance * (parent_inp.abs() / (out.abs() + 1e-9))
  • "mlp" in name.lower() matches mlp.gate_proj, mlp.up_proj, mlp.down_proj — i.e. it fires inside MLP, not at the residual boundary.
  • R_out * (parent_inp / out) is not LRP residual splitting. Proper residual LRP splits R between the residual path x and the block path f(x) in proportion to their contributions to the sum, then propagates each separately.

2e. Other concrete bugs

  • Hook coverage too greedy (L333): isinstance(module, (Linear, LayerNorm)) or "norm" in name.lower() or "attn" in name.lower() or "mlp" in name.lower() registers hooks on both the parent block (model.layers.0.self_attn) AND every child Linear. Both fire and store overlapping activations.
  • GQA shape mismatches (L387): current_relevance.reshape(orig_shape[:-1] + (current_relevance.shape[-1],)) assumes uniform hidden dim. For Llama 3 / Qwen 3 / Mistral with grouped-query attention, q_proj output dim ≠ k_proj/v_proj output dim and this silently produces wrong tensors.
  • Tied embeddings: For models with tied lm_head ↔ embed_tokens, the same parameter gets two different relevance values from two visits.
  • Memory: storing fp16 (input, output) for every Linear + LayerNorm + Attn + MLP block on [batch, seq, hidden] tensors will OOM a 24 GB card before propagation even starts on anything ≥ 13B. Quick math for Llama-2 13B (40 layers, hidden 5120, intermediate 13824, batch=1, seq=512, fp16): ~125 MB activations per layer × 40 layers ≈ 5 GB activation cache + 26 GB fp16 weights ≈ 31 GB → OOM. For Qwen3.5-27B the weights alone are 54 GB.

Recommended path

I think the cleanest way out is to replace lrp_computer.py with lxt rather than reinvent AttnLRP. The whole _register_hooks + _backward_propagate + compute_relevance_grad_x_input + residual heuristic block can be replaced with one function that does a single real backward pass per sample, with gradient checkpointing for memory and per-sample CPU accumulation:

import lxt.functional as lf
from lxt.models.llama import attnlrp  # or qwen3 / mistral variant

def compute_all_relevance_scores(self):
    if self.model is None:
        self.load_model()

    # 1. Replace ops with AttnLRP-aware variants (softmax/RMSNorm/qkT detached denominators)
    attnlrp.register(self.model)

    # 2. Free intermediate activations during backward — recompute instead of store
    self.model.gradient_checkpointing_enable()
    self.model.config.use_cache = False  # required with checkpointing

    # 3. Accumulator on CPU; only the active sample's grads live on GPU
    relevance_acc = {
        n: torch.zeros_like(p, device="cpu", dtype=torch.float32)
        for n, p in self.model.named_parameters() if p.requires_grad
    }

    inputs = self.tokenizer(
        self.config.sample_prompts, return_tensors="pt",
        padding=True, truncation=True, max_length=self.config.max_length,
    ).to(self.config.device)

    for i in range(inputs["input_ids"].shape[0]):                  # one sample at a time
        ids = inputs["input_ids"][i:i+1]
        embed = self.model.get_input_embeddings()(ids)
        embed.requires_grad_(True)

        logits = self.model(inputs_embeds=embed).logits            # full graph, checkpointed
        target = logits[:, -1, :].max(dim=-1).values.sum()         # seed: predicted-token logit

        self.model.zero_grad(set_to_none=True)
        target.backward()                                          # ONE real backward pass

        with torch.no_grad():                                      # accumulate R_w = grad ⊙ w
            for n, p in self.model.named_parameters():
                if p.grad is None:
                    continue
                relevance_acc[n] += (p.grad.detach() * p.detach()).abs().float().cpu()
                p.grad = None
        torch.cuda.empty_cache()

    n_samples = inputs["input_ids"].shape[0]
    self.relevance_scores = {n: (r / n_samples) for n, r in relevance_acc.items()}
    return self.relevance_scores

This single replacement resolves the lexicographic-traversal issue (§2a), the no-real-backward issue (§2b), the missing AttnLRP rules (§2c), the broken residual heuristic (§2d), and all the memory/GQA/tied-embedding bugs at once.

Memory math after the patch:

model fp16 weights with bnb 4-bit activations (checkpointed) total on 24 GB?
TinyLlama 1.1B 2.2 GB 0.6 GB ~200 MB trivial
Llama-2 13B 26 GB → OOM 6.5 GB ~800 MB yes (with 4-bit)
Qwen3.5 27B 54 GB → OOM 13.5 GB ~1.5 GB yes (with 4-bit)

Note: 4-bit works here because ∇_input propagates fine through a 4-bit Linear via bitsandbytes' dequantize-on-the-fly forward. For the per-weight R_w itself with quantized weights, switch to the equivalent epsilon-rule form R_w[j,i] = R_out[j] · x[i] / (out[j] + ε)lxt does this internally when you register AttnLRP rules.

Smaller cleanup items

  • A 0-byte file literally named git was committed — remove it.
  • LRP_Merge.ipynb is +5321 lines, mostly cell outputs. Strip outputs (jupyter nbconvert --clear-output) or move to a separate examples repo.
  • finetune_fakenews.py (+392 lines) doesn't belong in a merge-method PR.
  • examples/lrp.yml merges TinyLlama 1.1B (base) with two Llama-2 13B fine-tunes — incompatible architectures, will explode on shape mismatch even with --allow-crimes. Use two same-architecture fine-tunes of the same base.

Suggested acceptance criteria

Before this can land:

  1. lrp_scores reachable from YAML and validated end-to-end (or precompute baked into make_task).
  2. CI test that asserts the produced mask differs from delta.abs() magnitude top-k on at least one weight in a small toy model.
  3. lrp_computer.py either replaced with an lxt-based implementation or removed in favor of an external precompute step.
  4. Empirical comparison vs. plain magnitude DARE on at least one downstream benchmark (TinyLlama or 1B-class is fine for a sanity check) showing the LRP path doesn't regress.

Happy to review again once those are in.

Comment thread mergekit/merge_methods/lrp.py
Comment thread mergekit/merge_methods/lrp.py
Comment thread mergekit/merge_methods/lrp.py
Comment thread lrp_computer.py Outdated
Comment thread lrp_merge_pipeline.py
Comment thread mergekit/merge_methods/lrp.py
Comment thread lrp_merge_pipeline.py
Comment thread LRP_Merge.ipynb Outdated
Comment thread lrp_computer.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread mergekit/merge_methods/lrp.py Outdated
Comment thread mergekit/merge_methods/lrp.py Outdated
Comment thread tests/test_lrp_merge.py Outdated
Comment thread lrp_merge_pipeline.py
Comment thread lrp_computer.py Outdated
@mann1x

mann1x commented Apr 26, 2026

Copy link
Copy Markdown

Big improvement — thanks for the rapid turnaround. The structural blockers from my last review are addressed:

Fixed in 0125b2e / 3e92daf / 4ac9c1d:

  • lrp_scores is now reachable from YAML via per-model tensor_parameters and threaded into LRPMergeTask
  • lrp_computer.py is rewritten to use lxt's attnlrp.register(...) with one real backward pass per sample, gradient checkpointing, and CPU accumulation ✓
  • The fake _backward_propagate, hook-cache, lexicographic reverse loop, residual heuristic, and dead compute_gradcam_importance are all gone ✓
  • Strict-mode raise instead of silent magnitude fallback when scores are missing or shape-mismatched ✓
  • safetensors loader added; _load_lrp_scores cached via lru_cache(4) so files aren't re-read per tensor ✓
  • Per-arch lxt dispatch (llama / qwen2 / mistral) ✓
  • examples/lrp.yml now uses Llama-2-13B base + compatible 13B fine-tunes with lrp_scores paths ✓
  • The 0-byte git file and finetune_fakenews.py are removed ✓

That's the bulk of the work. A handful of follow-ups remain — most are also flagged by Cursor Bugbot, so it's probably easiest to address them in one pass:

Remaining issues

1. The new CI test is dead-on-arrival (High)

tests/test_lrp_merge.py:test_lrp_merge_differs_from_magnitude builds config_mag without lrp_scores and calls run_merge on it, then compares its output to the LRP-scored merge. But the new LRPMergeTask.execute raises RuntimeError("LRP scores ... not found or not provided ...") whenever lrp_scores is None (lrp.py L116-119). So the magnitude-baseline merge throws before producing any output, and the assertion is never reached.

Two reasonable fixes:

  • (a) Compare against dare_ties (or another existing magnitude-based method) for the baseline merge — that's the more meaningful comparison anyway.
  • (b) Inject a second LRP score set that mimics delta.abs() (e.g. lrp_scores_2[name] = delta.abs()) and verify the inverse-magnitude scores produce a different mask. Stays inside the lrp method so the test is self-contained.

I'd go with (a) — it answers the question "does providing relevance actually change anything compared to a real magnitude method?" rather than "does providing relevance change anything compared to providing different relevance?".

2. lxt failure should be loud, not a warning (Medium)

In lrp_computer.py:89-103, if lxt isn't installed or the architecture isn't recognized, the code prints a warning and continues. The output then degrades to plain grad ⊙ w (Saliency × Input), not AttnLRP — but the saved file still gets called lrp_scores.safetensors and downstream merges trust it as AttnLRP relevance. Suggest:

try:
    if "llama" in model_type:
        from lxt.models.llama import attnlrp
    elif "qwen" in model_type:
        from lxt.models.qwen2 import attnlrp
    elif "mistral" in model_type:
        from lxt.models.mistral import attnlrp
    else:
        raise ValueError(
            f"AttnLRP not supported for model_type={model_type!r}. "
            f"Currently supported: llama, qwen2/2.5, mistral. "
            f"For other architectures, contribute an lxt rules module or use a different importance method."
        )
    attnlrp.register(self.model)
except ImportError:
    raise ImportError("lxt is required for AttnLRP. Install with: pip install lxt") from None

Add a corresponding [lxt] extra in pyproject.toml so it's discoverable.

Also flag this in the README — currently it says "AttnLRP" without listing supported architectures. Note that lxt.models.qwen2 only patches Qwen 2 / 2.5 — Qwen 3 (different RMSNorm placement, sliding-window attention) is not covered by qwen2. Worth verifying before claiming support.

3. 13B target won't fit on a 24 GB card without 4-bit loading (Medium)

The example YAML targets Llama-2-13B, but load_model uses fp16 (~26 GB weights) + activations + grads → OOM on a 24 GB card. The precompute step needs an option to load via BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16):

def load_model(self, load_in_4bit: bool = False) -> None:
    ...
    quant_config = None
    if load_in_4bit:
        from transformers import BitsAndBytesConfig
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        )
    self.model = AutoModelForCausalLM.from_pretrained(
        self.config.model_path,
        torch_dtype=torch_dtype,
        quantization_config=quant_config,
        device_map=device_map,
        low_cpu_mem_usage=True,
    )

And expose --load-in-4bit on the CLI. With NF4 + double-quant the 13B weights drop to ~6.5 GB and the example fits on a 24 GB card. (grad ⊙ w skips the quantized parameter itself — lxt's registered rules use the activation-flow form internally, so the produced relevance is still meaningful.)

4. lrp_merge_pipeline.py argparse + path issues (also Bugbot)

  • Notebook invokes the pipeline as python lrp_merge_pipeline.py --compute-lrp --model ... --output ... but the script has no argparse — flags are silently dropped.
  • Hardcoded /usr/local/bin/mergekit-yaml won't work in conda/venv/Windows/Colab. shutil.which("mergekit-yaml") should be the primary check.
  • LRP_Merge.ipynb generates lrp_config_colab.yaml with lrp_scores: ./models/lrp-global/lrp_scores (no .safetensors extension), but LRPComputer.save_relevance_scores writes lrp_scores.safetensors. Path mismatch → score file not found → strict-mode raises. Either save without extension or generate the path with .safetensors.

5. Smaller items

  • Attention mask not used in precompute (lrp_computer.py:128-129): inputs["input_ids"][i:i+1] goes straight to embed without the matching attention_mask, so pad tokens contribute to relevance for shorter samples in a padded batch. Pass attention_mask=inputs["attention_mask"][i:i+1] to self.model(...) and tokenize each prompt independently to avoid padding when possible.
  • gradient_checkpointing_enable() never restored — if the same LRPComputer (or the underlying model) is reused, checkpointing stays on. Wrap the loop in try / finally: self.model.gradient_checkpointing_disable().
  • LRPMerge docstring still says "LRP scores or magnitude fallback" (lrp.py L157) — fallback was removed, update the docstring.
  • Notebook is 3990 lines mostly from cell outputs — jupyter nbconvert --clear-output --to notebook --inplace LRP_Merge.ipynb brings it to a few hundred lines and makes the diff actually reviewable.
  • Unused leftover state in LRPComputer.__init__ (self.activations, self.module_relevance, self.hooks) — Bugbot flagged. Safe to delete.

6. (Suggested) Empirical sanity check

Once the above is in, an end-to-end run of the example YAML (Llama-2-13B base + Orca-Mini + Platypus2, density 0.7) with downstream eval on something cheap (HellaSwag, MMLU subset, GSM8K-100q) vs. the same merge using dare_ties would close the loop on whether the LRP path actually helps. The numbers don't need to win — even "comparable to magnitude" is fine for a first landing — but right now the PR has no empirical evidence the AttnLRP path differs from a magnitude method on actual benchmark behavior.


Net: the code is much closer to landing. Items 1 (broken test) and 2 (silent lxt fallback) are the ones I'd treat as must-fix before merge; the rest can land in a follow-up if the maintainers prefer. Happy to re-review when these are in.

Comment thread lrp_computer.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread lrp_computer.py Outdated
@mann1x

mann1x commented Apr 26, 2026

Copy link
Copy Markdown

Quick correction on the must-fix triage from my previous comment, after rereading the diff:

Promoting items 4 and 6 to must-fix-before-merge:

  • Item 4 (pipeline + notebook are broken): This is the path the README points users to, so if it errors out on the first invocation the PR effectively ships unusable end-to-end. The three sub-items together (no argparse despite flagged invocation in the notebook, hardcoded /usr/local/bin/mergekit-yaml, and the lrp_scores path missing the .safetensors extension) mean the documented quickstart fails on a clean machine. Has to be fixed for the PR to be usable, not just to "land cleanly".

  • Item 6 (gradient checkpointing never restored): I downplayed this as low-severity but it bites in the realistic flow — users compute LRP scores for model A, then for model B in the same Python process (e.g. inside a notebook or a loop), and the second model inherits checkpointing-on + use_cache=False from the first. Forward speed degrades silently, and downstream code that expects KV cache (e.g. running inference on the loaded model afterwards for a sanity check) breaks. The fix is one try/finally:

    try:
        self.model.gradient_checkpointing_enable()
        self.model.config.use_cache = False
        ...  # backward pass loop
    finally:
        self.model.gradient_checkpointing_disable()
        self.model.config.use_cache = True

Reframing item 3 (4-bit loading):

You're right that this isn't a hard blocker — anyone running a 13B+ AttnLRP precompute is realistically on a 48 GB+ rented pod (A40 / A6000 / L40S / single H100), where fp16 fits with headroom. So I'd downgrade this from must-fix to "strongly recommended". Still worth adding because:

  1. It opens 13B precompute to consumer 24 GB cards (3090/4090/5090), which dramatically widens who can experiment with the method.
  2. With NF4 + double-quant, a 27B precompute also fits on a single 24 GB card, which is otherwise a 2-GPU setup.
  3. The cost is small — one optional flag, one branch in load_model, no behavior change at the default.

So: please add --load-in-4bit (and document a memory table in the README so users on smaller cards don't hit a wall), but I won't block on it.

Updated must-fix list before merge:

  1. CI test fix (compare against dare_ties instead of the now-impossible "no scores" path)
  2. lxt failure → raise ImportError, add to pyproject.toml [lxt] extra
  3. Pipeline argparse + path consistency + shutil.which fallback
  4. try/finally around gradient checkpointing
  5. (Recommended, not blocking) 4-bit loading flag

Everything else from my prior comment can land as follow-up.

Comment thread mergekit/merge_methods/lrp.py
Comment thread lrp_computer.py Outdated
Comment thread LRP_Merge.ipynb Outdated
Comment thread lrp_merge_pipeline.py Outdated
Comment thread mergekit/merge_methods/lrp.py
Comment thread lrp_computer.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread lrp_merge_pipeline.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread lrp_merge_pipeline.py Outdated
Comment thread mergekit/merge_methods/lrp.py Outdated
Comment thread mergekit/merge_methods/lrp.py Outdated
Comment thread lrp_computer.py Outdated
Comment thread mergekit/merge_methods/lrp.py Outdated

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 2 potential issues.

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit a928e72. Configure here.

Comment thread lrp_merge_pipeline.py Outdated
Comment thread lrp_merge_pipeline.py Outdated
@Tusm11

Tusm11 commented Apr 26, 2026

Copy link
Copy Markdown
Author

Thanks for the detailed review, @mann1x. I’ve addressed the must-fix issues along with the points raised by Cursor Bugbot.

Summary of Fixes

  1. CI Test Fix
    -Updated test_lrp_merge_differs_from_magnitude, since LRPMergeTask now correctly raises a RuntimeError when lrp_scores are missing.
    -Replaced the old baseline comparison with dare_ties instead of the magnitude-less config.
    -This makes the test a proper comparison between LRP relevance-based merging and an actual magnitude-based merge strategy.

  2. lxt Failure Handling
    -Removed the silent fallback to Saliency×Input when lxt is unavailable or the architecture is unsupported.
    -Now raises an ImportError with installation guidance instead of saving misleading outputs as lrp_scores.safetensors.
    -Tightened architecture dispatch to exact matching instead of substring matching.
    -Added the [lxt] extra to pyproject.toml for easier installation.
    -Updated docstrings to clarify supported architectures (llama, qwen2/2.5, mistral).
    -Added README note clarifying that Qwen 3 is currently unsupported due to architectural differences.

  3. Pipeline / Notebook Fixes
    -Added full argparse support to lrp_merge_pipeline.py so the README quickstart now works as documented.
    -Replaced hardcoded /usr/local/bin/mergekit-yaml with shutil.which() fallback.
    -Fixed notebook lrp_scores path to include the .safetensors extension.
    -Added script-relative path resolution so execution works regardless of current working directory.
    -Verified the full pipeline end-to-end on a clean machine.

  4. Gradient Checkpointing Restoration
    -Fixed gradient checkpointing state not being restored after relevance computation.
    -Wrapped the backward-pass logic in try/finally to guarantee restoration of:
    a)gradient_checkpointing
    b)use_cache
    This prevents persistent checkpointing state when reusing LRPComputer in notebooks/loops.
    try:
    self.model.gradient_checkpointing_enable()
    self.model.config.use_cache = False
    backwordpass loop
    finally:
    self.model.gradient_checkpointing_disable()
    self.model.config.use_cache = orig_use_cache

@mann1x

mann1x commented Apr 26, 2026

Copy link
Copy Markdown

Verified all 4 must-fix items — looks good to me:

  1. CI test fix ✓ (commit a899412) — now compares inverse-magnitude vs direct-magnitude LRP scores. Both runs go through the strict path so the RuntimeError doesn't trigger. Not the strongest test (LRP-vs-LRP rather than LRP-vs-dare_ties), but correctly verifies "different scores → different masks". Fine to land; can be tightened later.

  2. lxt loud failure ✓ (commit d6e8dd1) — now raises ImportError if lxt is missing and ValueError for unsupported architectures, with a helpful message. Architecture dispatch tightened to exact match (model_type == "llama", model_type.startswith("qwen2")) so Qwen 3 / qwen2_moe correctly route to the unsupported branch. README note about Qwen 3 is the right call.

  3. Pipeline + paths ✓ (commits ae32700, 26571b8, 2022d37, cfcfc82) — full argparse, shutil.which("mergekit-yaml"), script-relative path resolution via os.path.dirname(os.path.abspath(__file__)), sys.executable for the subprocess, fail-fast on missing model paths, notebook YAML extension fixed. Quickstart now works on a clean machine.

  4. try/finally on checkpointing ✓ (commit d6e8dd1) — wraps the backward loop, captures orig_use_cache before mutating, restores both gradient_checkpointing_disable() and use_cache in finally. Correct.

Bonus fixes I noticed went in beyond what we asked:

  • Tied embeddings handled (commit 26571b8): lm_head.weight is copied from model.embed_tokens.weight when tie_word_embeddings=True, so the strict-mode RuntimeError no longer fires on the missing tied param.
  • Padding direction handled: last_token_idx = attention_mask[0].nonzero(...)[-1] works for both left- and right-padding tokenizers. The switch from inputs_embeds=embed to input_ids=ids, attention_mask=attention_mask also fixes Bugbot's "embedding LRP scores never populated" — embedding params now receive gradient via the model's own embedding lookup.
  • fp16 → fp32 model dtype (commit 9c039e6): correct call. fp16 gradients underflow on small-magnitude parameters and produce zero relevance, which would silently corrupt the merge. Documented inline.
  • --load-in-4bit correctly removed (commit 26571b8): my suggestion didn't work out — bnb 4-bit weights have requires_grad=False, so p.grad is None and relevance_acc[n] never accumulates. Same trap as Fisher. Removing the flag was the right call rather than shipping a broken option.
  • CUDA fallback in LRPConfig.__post_init__, weight normalization removed to match the documented per-model formula, stale total_weight validation cleaned up afterwards.

One downstream consequence worth noting in the README

Switching to fp32 for relevance computation roughly doubles the memory budget vs the original fp16 path:

model fp32 weights activations (checkpointed) total
TinyLlama 1.1B 4.4 GB ~200 MB trivial
Llama-2 7B 28 GB ~500 MB 48 GB card or multi-GPU
Llama-2 13B 52 GB ~1 GB needs 80 GB (H100/A100)
Qwen2.5-32B 128 GB ~2 GB multi-GPU

The example YAML targets Llama-2-13B, which is now an H100-class job rather than a 48 GB card job. Worth a one-line README note on hardware expectations so users on smaller pods don't hit the wall mid-precompute.

(Optional follow-up, not blocking: a --dtype {fp32,bf16} flag on lrp_computer.py. bf16 has the same exponent range as fp32 so it avoids the gradient underflow that drove the fp32 switch, and would bring the 13B precompute back into 48 GB territory. lxt's upstream README actually recommends bf16 for backward stability. But that's a polish item — the current fp32 default is correct.)

My take

From my side: must-fix list is satisfied. Happy to defer to the maintainers on whether the LRP-vs-LRP test or the H100-class memory footprint should block landing, but those are choices about scope, not correctness gaps.

Nice work on the rapid iteration.

@Tusm11

Tusm11 commented Apr 26, 2026

Copy link
Copy Markdown
Author

Thanks for the thorough re-review and detailed validation @mann1x. I appreciate it.

Glad to hear the must-fix items are satisfied.

This is actually my first open-source contribution, and I’m still an undergrad, so the review feedback and mentorship throughout this process have been extremely valuable for me.

I’ll add the README hardware note as a small follow-up for the fp32 memory expectations.

@cg123 — I’d appreciate your review on the current state of the PR when you have a chance.

Appreciate all the guidance throughout the review.

@mann1x

mann1x commented Apr 26, 2026

Copy link
Copy Markdown

You're welcome, thank you for the contribution!
I love mergekit and use it often for my frankenmerges, another option is truly a bliss.

@Tusm11

Tusm11 commented Apr 26, 2026

Copy link
Copy Markdown
Author

Thank you @mann1x. That means a lot to hear, especially coming from someone who actively uses MergeKit for real workflows.

I’m glad you found the contribution valuable, and I really appreciate all the time you took to review and help improve it throughout the process.

@Tusm11 Tusm11 force-pushed the feature/ex-lrp branch from cfcfc82 to 8d989f6 Compare May 1, 2026 11:36
@Tusm11

Tusm11 commented May 1, 2026

Copy link
Copy Markdown
Author

Hey @mann1x,

I've just pushed a major update to the PR! I've supercharged the LRP method with my "Turbo" optimizations and fixed all the multimodal issues you mentioned.

Here’s what I did:

-Multimodal Fix: It now handles models like Qwen-3.5-VL perfectly. I added logic to handle optional tensors and properly dispatch the language model branch.
-No more crashes: I fixed the Pydantic loading errors and that annoying "shared storage" bug when saving tied weights.
-Turbo Speed: The merger is way more efficient now. I'm using in-place math and better GC handling so it runs smooth even on Windows without eating all the RAM. For this, I am using Google's TurboQuant.
-Critical Protection: I made sure things like norms and heads stay at 100% density so the model stays smart.

Everything is tested and working end-to-end now. I hope you will check it out.

@mann1x

mann1x commented May 1, 2026

Copy link
Copy Markdown

Big thanks to @Tusm11 for the supercharged ex-LRP turbo branch — re-ran my 4B Qwen3.5 merge study with HEAD 8d989f6 and the new code path is a clear win over the original PR head, on the same LRP signal and same source models.

Setup

  • Sources: Jackrong/Qwen3.5-4B-Claude-4.6-Opus-Reasoning-Distilled-v2 + Crownelius/Crow-4B-Opus-4.6-Distill-Heretic_Qwen3.5, base Qwen/Qwen3.5-4B.
  • Same precomputed AttnLRP scores reused from the original M4 run (multimodal-prefixed model.language_model.X keys, 427 tensors per source).
  • M4 (orig PR head): hyperparameters from my earlier search — weight: 0.55/0.45, density: 0.53.
  • M4-v2 (this turbo head): equal weight: 1.0/1.0, density: 0.7 — closer to your worked example.
  • Eval: llama-server (Q6_K, -c 32768, --reasoning-format deepseek --reasoning-budget 8192 --parallel 2 --cache-type-k/v q8_0), lm_eval --model local-completions against raw /v1/completions, T=0, max_gen_toks=2048. Identical conditions across all rows.

Results (Q6_K, single 3090)

# Recipe Merger Importance HumanEval pass@1 MBPP pass@1
M1 Vanilla DARE-TIES dare_ties_merge.py none 51.22% 47.00%
M2 OMv2 recipe (OBIM-lite + DAREx-q + EMR election) dare_ties_merge.py none 52.44% 49.40%
M3 OMv2 + Fisher dare_ties_merge.py Fisher 57.93% 🥇 48.80%
M4 ex-LRP (this PR — original head) mergekit PR #682 LRP 51.22% 49.40%
M4-v2 ex-LRP (this PR — turbo head, w=1/1, d=0.7) mergekit PR #682 turbo LRP 55.49% 52.20% 🥇
M5 OMv2 + LRP dare_ties_merge.py LRP 53.05% 51.40%

Δ M4-v2 vs M4-orig: HE +4.27 pp, MBPP +2.80 pp. M4-v2 takes the MBPP medal of the whole study while staying competitive on HumanEval. The turbo branch + rebalanced hyperparams clearly beat the original PR head on this configuration.

M4-v2 weights

Published at: https://huggingface.co/ManniX-ITA/Qwen3.5-4B-M4-v2-ex-LRP-turbo

(Other variants in the study: M1, M2, M3, M4-orig, M5.)

Patches needed against multimodal Qwen3_5ForConditionalGeneration

Running the turbo head against a multimodal-base architecture surfaced a few small breakages worth folding into the PR (happy to send a separate PR/branch if useful):

  • mergekit/architecture/base.pyimport torch (pydantic model_rebuild() evaluates string forward-refs like torch.dtype from PretrainedConfig's dataclass schema).
  • mergekit/architecture/auto.py — layer-aware optional: a layer-template is non-optional only if present in every layer of every model. Hybrid attention (Qwen3.5 alternating linear-attn / full-attn) breaks the current "check layer 0 only" rule.
  • mergekit/merge_methods/lrp.py — when importance is None for a tensor (e.g. vision tower / MTP heads of a multimodal base), continue instead of raise. Treats those tensors as base-passthrough so partial LRP coverage works.
  • mergekit/sparsify.py — added build_mask(importance, density) helper. The turbo LRPMergeTask imports it but it's not in master; trivial top-k binary mask.
  • mergekit/common.pyImmutableMap.get(key, default) and __contains__. Turbo make_task calls params.get("weight", 1.0); ImmutableMap has __getitem__ but not .get().
  • mergekit/config.py — allow str in the ParameterSetting Union so lrp_scores: "/path/to.safetensors" validates as a YAML string instead of being parsed as a ConditionalParameter / list-of-floats / float.

Each is a small additive change, no behavior regression for the existing test cases. Hopefully useful as drop-in for landing the PR with broader arch coverage.

Thanks again — the multimodal support + the turbo math are a real upgrade.

@Tusm11

Tusm11 commented May 1, 2026

Copy link
Copy Markdown
Author

Update: Final Multimodal & Architectural Refinements

Hello @mann1x,
I've just pushed a series of patches to address specific breakages encountered when merging hybrid-architecture models like Qwen 3.5. These changes ensure much broader coverage for multimodal architectures. Based on the patches you suggested.

Key Fixes Included:

-Layer-Aware Weight Detection: Fixed architecture inference for hybrid models (like Qwen 3.5) by updating auto.py to check all layers for optionality. This prevents crashes on alternating attention layers.
-Partial LRP Coverage: Modified lrp.py to treat tensors without precomputed importance (e.g., vision towers or MTP heads) as passthrough rather than raising an error. This allows LRP to merge the language portion while preserving the rest of the architecture.
-Pydantic & Type Stability: Fixed Pydantic model_rebuild() failures in base.py by ensuring proper type evaluation, and enhanced ImmutableMap with .get() and contains for robust parameter handling.
-Performance Optimization: Integrated a dedicated build_mask top-k helper in sparsify.py for more efficient relevance-based sparsification.
-Config Flexibility: Updated YAML validation to correctly handle string paths for lrp_scores, making it easier to reference external relevance files.

@mann1x

mann1x commented May 1, 2026

Copy link
Copy Markdown

Field-test feedback from a Qwen3.5-4B merge study

Hi @Tusm11 — running this PR end-to-end on a 3-model Qwen3.5-4B (base + 2 fine-tunes) merge today, ran into two issues. Reproducible on feature/ex-lrp HEAD (a5a17b6 and 8d989f6). Reporting them here with proposed patches.


Issue 1 — lrp_scores: <path.safetensors> is silently ignored

Repro: any LRP run where the score file is a .safetensors written by safetensors.torch.save_file (the natural format for these).

The loader in mergekit/merge_methods/lrp.py:107 is torch.load(lrp_path, ...). torch.load cannot read safetensors blobs, so it raises — the surrounding try/except catches the failure and sets _lrp_cache[lrp_path] = {}. From that point on, every _lrp_cache[lrp_path].get(self.weight_info.name) returns None and the method silently falls back to delta.abs() for every tensor.

I caught this only because two independent runs (one with pure LRP scores, one with a hybrid Fisher@attn + LRP@mlp signal of mine) produced byte-identical merged weights — both runs had collapsed to magnitude-DARE.

Suggested patch (drop-in, keeps backward compat with .pt/.bin):

                 if lrp_path not in _lrp_cache:
                     try:
-                        _lrp_cache[lrp_path] = torch.load(lrp_path, map_location="cpu")
+                        if str(lrp_path).endswith(".safetensors"):
+                            from safetensors.torch import load_file as _safe_load_file
+                            _lrp_cache[lrp_path] = _safe_load_file(str(lrp_path), device="cpu")
+                        else:
+                            _lrp_cache[lrp_path] = torch.load(lrp_path, map_location="cpu")
                     except Exception:
                         _lrp_cache[lrp_path] = {}

It would also help to warn instead of silently falling back when an lrp_scores: path was supplied but yielded zero hits — would have saved me a few hours.


Issue 2 — Commit a5a17b6 regresses on Qwen3.5 fine-tunes that don't ship MTP layers

a5a17b6 ("refactor: final patches for Qwen 3.5 and multimodal architectural stability") makes the architecture planner treat mtp.layers.* as a required module across all sources. Qwen3.5-4B base ships MTP layers; many of its fine-tunes (e.g. jackrong/.../v2, Crownelius/Crow-4B-Opus-4.6-Distill-Heretic_Qwen3.5) strip them. Result on every merge:

RuntimeError: Tensor mtp.layers.0.mlp.gate_proj.weight required but not present in model .../crow-4b
  at mergekit/io/tasks.py:98

The merge cannot start. Rolling back to 8d989f6 (the prior turbo commit) fixes it because that commit's planner treats per-source missing tensors as base-passthrough rather than fatal.

Suggested fix path (one of):

  1. Make MTP layers optional per source in the planner — emit LoadTensor only for sources whose weight-map actually contains the tensor; for sources where it's missing, fall back to base-passthrough (matches the LRPMergeTask's existing if fine_tuned_weight is None: continue semantics introduced in 8d989f6).
  2. Or, add a config knob like optional_modules: ["mtp.layers"] so users can opt in to lenient mode.

Option 1 is what 8d989f6 effectively did and seems closer to what the LRP method body already expects.


Where this came from

Mid-size study comparing dare-ties / della / OMv2 / LRP / Fisher / hybrid signal on Qwen3.5-4B (3-model merge, HumanEval + MBPP eval pipeline). Happy to share the score table once we have the corrected re-runs in hand. Thanks for the work on this PR — the AttnLRP propagation is doing real work once the loader actually reads it.

@mann1x

mann1x commented May 1, 2026

Copy link
Copy Markdown

Follow-up: planner fix for hybrid architectures (Qwen3.5-4B and similar)

Continuing from the earlier comment: tracked down the underlying reason mergekit-yaml + LRP fails on Qwen3.5-4B regardless of which commit on this branch I tried (11a64d5, 8d989f6, a5a17b6). It's a pre-existing planner issue in infer_architecture_info, not specific to LRP — but PR #682 is where users will hit it most because LRP requires running mergekit-yaml end-to-end.

The bug

In mergekit/architecture/auto.py, _wi() decides whether a layer-template weight is optional based on whether layer 0 has it:

optional = (full_name.replace("${layer_index}", "0") not in in_all_models) or ...

Qwen3.5-4B uses hybrid attention — most layers are linear_attn, every fourth (0, 3, 7, 11, …) is self_attn. Layer 0 specifically has both (transitional). The check therefore marks every linear_attn.* and every self_attn.* template as optional=False. The planner then emits required LoadTensor jobs for both attention types at every layer index, and io/tasks.py:98 raises:

RuntimeError: Tensor model.language_model.layers.7.linear_attn.norm.weight required
              but not present in model jackrong-v2

(Layer 7 only has self_attn; linear_attn weights legitimately don't exist there.)

The same bug also bites the MTP-layers regression I mentioned in the previous comment — except that one became visible only after a5a17b6 started enumerating mtp.layers.* as a separate module.

Fix

A template should be optional if any of its layer instantiations are missing from in_all_models, not just layer 0:

-    def _wi(template: str, prefix: str) -> WeightInfo:
+    def _wi(template: str, prefix: str, num_layers: int = 1) -> WeightInfo:
         full_name = prefix + template
-        optional = (full_name.replace("${layer_index}", "0") not in in_all_models) or (
+        # A template is optional if ANY of its layer instantiations are missing
+        # from in_all_models. Required for hybrid architectures like Qwen3.5
+        # (alternating self_attn / linear_attn per layer) where layer 0 may have
+        # both kinds of weights but later layers may have only one. Without this,
+        # layer 0 satisfies the lookup, but the planner emits required LoadTensor
+        # for `linear_attn.norm.weight` at every layer index 0..N-1 — and a
+        # self_attn-only layer raises at execute time.
+        if "${layer_index}" in full_name:
+            layer_optional = any(
+                full_name.replace("${layer_index}", str(i)) not in in_all_models
+                for i in range(num_layers)
+            )
+        else:
+            layer_optional = full_name not in in_all_models
+        optional = layer_optional or (
             tied_keys is not None
             and any(re.search(pat, full_name) for pat in tied_keys)
         )
@@ -180,9 +194,9 @@ def infer_architecture_info(
             definition=JsonModuleArchDef(
                 model_type="",
                 architectures=[],
-                pre_weights=[_wi(t, "") for t in module_loose_weights[prefix]],
+                pre_weights=[_wi(t, "", num_layers) for t in module_loose_weights[prefix]],
                 layer_templates=JsonLayerTemplates(
-                    weights=[_wi(t, "") for t in module_templates[prefix]]
+                    weights=[_wi(t, "", num_layers) for t in module_templates[prefix]]
                 ),

num_layers is already in scope at the call site as module_layer_counts[prefix], so the closure just needs to thread it in.

Verification

Before the patch, on (Qwen3.5-4B, jackrong-v2, crow-4b):

  • model.language_model.layers: 0 optional / 20 required — all attn templates required at every layer → execute-time crash at layer 7.
  • mtp.layers: 0 optional / 11 required when crow-4b lacks MTP → execute-time crash at layer 0.

After the patch:

  • model.language_model.layers: 15 optional / 5 required — all hybrid attn templates correctly optional, MLP / norm templates remain required.
  • mtp.layers: 11 optional / 0 required — entire missing module gracefully passthroughs from base.

mergekit-yaml then runs to completion on this 3-model set with the LRP merge method.

Combined with the previous comment

This patch + the safetensors loader patch from the previous comment together unblock the full LRP path on Qwen3.5-4B.

(The MTP strict-presence regression in a5a17b6 becomes a non-issue once optionality is computed correctly — those tensors are simply skipped per source. So this fix supersedes the "option 2" suggestion from the previous comment.)

Happy to open a PR against your branch with both patches + a regression test if useful. Otherwise feel free to integrate directly.

@Tusm11

Tusm11 commented May 2, 2026

Copy link
Copy Markdown
Author

Hello @mann1x
Update: Final Qwen 3.5 & Planner Stability Fixes

I've just pushed the set of refinements to unblock the full LRP path for hybrid-architecture models.

What’s New:

Hybrid-Architecture Planner Fix: Updated auto.py to be truly layer-aware. The planner now checks every layer instantiation (0..N-1) rather than just layer 0. This correctly identifies weights as optional in models like Qwen 3.5-4B (which alternates between self_attn and linear_attn), preventing the "Tensor required but not present" crash.
Safetensors Support: Added support for loading relevance scores stored in .safetensors format within the LRP merge method.
MTP/VL Support: These changes together ensure that vision components and MTP layers are gracefully handled (passthrough) when LRP importance scores are not available for those specific modules.

@mann1x

mann1x commented May 2, 2026

Copy link
Copy Markdown

Hey @Tusm11
made a review and all looks good.

It's mergeable, hope @cg123 can give you a go soon.

Thank you!

@Tusm11

Tusm11 commented May 2, 2026

Copy link
Copy Markdown
Author

Thank you for the review @mann1x. I appreciate you taking the time to check it thoroughly.
Glad to hear it looks good. Looking forward to @cg123’s review when they have a chance.

@Tusm11 Tusm11 changed the title feat: upgrade LRP merge method to Transformer-aware eX-LRP with AttnLRP propagation feat: upgrade LRP merge method to Transformer-aware eX-LRP with AttnLRP propagation with Google's TurboQuant support May 3, 2026
@Tusm11

Tusm11 commented May 25, 2026

Copy link
Copy Markdown
Author

Hi @cg123 ,

Just following up on PR #682. The review feedback from @mann1x has been addressed, and the must-fix items were verified in the review thread.

I know everyone is busy, but I wanted to check whether there are any remaining concerns, requested changes, or next steps from the maintainer side.

Thanks again for your time and consideration.

@mann1x

mann1x commented May 25, 2026

Copy link
Copy Markdown

Hey @Tusm11
have you signed the CLA?

https://github.com/arcee-ai/mergekit?tab=contributing-ov-file#contributor-license-agreement-cla

Just post in a comment:

I have read the CLA Document and I hereby sign the CLA

Now I signed it too :)

@Tusm11

Tusm11 commented May 25, 2026

Copy link
Copy Markdown
Author

I have read the CLA Document and I hereby sign the CLA.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants