Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion mergekit/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

import fnmatch
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

import yaml
Expand All @@ -12,6 +13,31 @@

ScalarOrGradient: TypeAlias = Union[float, List[float]]

_GLOB_CHARS = frozenset("*?[")


def _filter_matches(filter_str: str, tensor_name: str) -> bool:
"""Return True if filter_str matches tensor_name.

Matching rules (evaluated in order):
1. ``"*"`` matches everything (handled by caller, kept for clarity).
2. If filter_str contains any glob metacharacter (``*``, ``?``, ``[``),
use :func:`fnmatch.fnmatch` so users can write precise patterns such as
``"*self_attn*"`` to avoid accidentally matching ``"linear_attn"``.
3. Otherwise fall back to the original substring check (``filter in tensor_name``).

**Why this matters for hybrid-attention models** (e.g. FLA / linear-attention
architectures): a filter of ``"attn"`` substring-matches *both* ``self_attn``
and ``linear_attn`` weights, causing unintended parameter blending of SSM state
tensors (``A_log``, ``dt_bias``, etc.) that must remain intact. Users can opt
into exact control by writing ``filter: "*self_attn*"`` in their merge YAML.
"""
if not filter_str or not tensor_name:
return False
if _GLOB_CHARS.intersection(filter_str):
return fnmatch.fnmatch(tensor_name, f"*{filter_str}*")
return filter_str in tensor_name


class ConditionalParameter(BaseModel):
value: ScalarOrGradient
Expand Down Expand Up @@ -43,7 +69,7 @@ def evaluate_setting(
if (
(cond.filter is None)
or (cond.filter == "*")
or (tensor_name and cond.filter in tensor_name)
or (tensor_name and _filter_matches(cond.filter, tensor_name))
):
res = evaluate_setting(tensor_name, cond.value, t)
return res
Expand Down
101 changes: 101 additions & 0 deletions tests/test_config_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: LGPL-3.0-only

"""Tests for ConditionalParameter filter matching in evaluate_setting."""

import pytest

from mergekit.config import _filter_matches, evaluate_setting, ConditionalParameter


class TestFilterMatches:
    """Exercise the ``_filter_matches`` helper directly."""

    def test_plain_substring_backward_compat(self):
        """A filter without glob characters is a plain substring test."""
        matching = (
            ("attn", "model.layers.0.self_attn.q_proj.weight"),
            ("mlp", "model.layers.0.mlp.down_proj.weight"),
            ("embed", "model.embed_tokens.weight"),
        )
        for pattern, name in matching:
            assert _filter_matches(pattern, name)

    def test_plain_substring_false(self):
        """A plain filter that is absent from the name does not match."""
        non_matching = (
            ("mlp", "model.layers.0.self_attn.q_proj.weight"),
            ("attn", "model.layers.0.mlp.down_proj.weight"),
        )
        for pattern, name in non_matching:
            assert not _filter_matches(pattern, name)

    def test_linear_attn_substring_collision(self):
        """Plain 'attn' is a substring of 'linear_attn', so it still matches.

        A narrower filter such as '*self_attn*' avoids the collision.
        """
        linear_tensor = "model.layers.0.linear_attn.A_log"
        assert _filter_matches("attn", linear_tensor)

    def test_glob_self_attn_excludes_linear_attn(self):
        """'*self_attn*' selects self_attn tensors and rejects linear_attn."""
        pattern = "*self_attn*"
        assert _filter_matches(pattern, "model.layers.0.self_attn.q_proj.weight")
        for linear_tensor in (
            "model.layers.0.linear_attn.A_log",
            "model.layers.0.linear_attn.dt_bias",
        ):
            assert not _filter_matches(pattern, linear_tensor)

    def test_glob_star_wildcard(self):
        """'*' in a pattern matches any run of characters."""
        pattern = "*.weight"
        assert _filter_matches(pattern, "model.layers.0.mlp.down_proj.weight")
        assert not _filter_matches(pattern, "model.layers.0.mlp.down_proj.bias")

    def test_glob_question_mark(self):
        """'?' matches exactly one character, so 'layer10' is rejected."""
        pattern = "layer?.mlp"
        assert _filter_matches(pattern, "model.layer0.mlp.weight")
        assert not _filter_matches(pattern, "model.layer10.mlp.weight")

    def test_empty_filter(self):
        """An empty filter never matches."""
        matched = _filter_matches("", "model.layers.0.self_attn.q_proj.weight")
        assert not matched

    def test_empty_tensor_name(self):
        """An empty tensor name never matches."""
        matched = _filter_matches("attn", "")
        assert not matched


class TestEvaluateSetting:
    """End-to-end checks of ``evaluate_setting`` on conditional lists."""

    def _make_setting(self, *filter_value_pairs, default=1.0):
        """Build a conditional list from (filter, value) pairs.

        A final entry with ``filter=None`` carrying ``default`` is appended
        after the given pairs.
        """
        return [
            *(
                ConditionalParameter(filter=pattern, value=value)
                for pattern, value in filter_value_pairs
            ),
            ConditionalParameter(filter=None, value=default),
        ]

    def test_glob_filter_self_attn_only(self):
        """'*self_attn*' applies its value to self_attn tensors only.

        A linear_attn tensor falls through to the default entry.
        """
        setting = self._make_setting(("*self_attn*", 0.03), default=1.0)
        expected = {
            "model.layers.0.self_attn.q_proj.weight": 0.03,
            "model.layers.0.linear_attn.A_log": 1.0,
        }
        for name, value in expected.items():
            assert evaluate_setting(name, setting) == pytest.approx(value)

    def test_plain_attn_matches_both(self):
        """Plain 'attn' is a substring of both self_attn and linear_attn.

        Both tensors therefore receive the filtered value; a narrower filter
        such as '*self_attn*' avoids this.
        """
        setting = self._make_setting(("attn", 0.03), default=1.0)
        for name in (
            "model.layers.0.self_attn.q_proj.weight",
            "model.layers.0.linear_attn.A_log",
        ):
            assert evaluate_setting(name, setting) == pytest.approx(0.03)

    def test_wildcard_star_matches_all(self):
        """A bare '*' filter applies to any tensor name."""
        setting = self._make_setting(("*", 0.5), default=1.0)
        result = evaluate_setting("anything.weight", setting)
        assert result == pytest.approx(0.5)

    def test_none_filter_matches_all(self):
        """An entry with ``filter=None`` applies to any tensor name."""
        setting = [ConditionalParameter(filter=None, value=0.7)]
        result = evaluate_setting("anything.weight", setting)
        assert result == pytest.approx(0.7)

    def test_first_match_wins(self):
        """The value of the first matching conditional is returned."""
        setting = self._make_setting(
            ("*self_attn*", 0.03),
            ("mlp", 0.05),
            default=1.0,
        )
        expected = (
            ("model.layers.0.self_attn.q_proj.weight", 0.03),
            ("model.layers.0.mlp.down_proj.weight", 0.05),
            ("model.embed_tokens.weight", 1.0),
        )
        for name, value in expected:
            assert evaluate_setting(name, setting) == pytest.approx(value)
Loading