Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def start_session(
)
else:
backend = backend_class(
model_id, model_options=model_options, **backend_kwargs
model_id_str, model_options=model_options, **backend_kwargs
)

logger.info(
Expand Down
92 changes: 91 additions & 1 deletion test/plugins/test_hook_call_sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
GenerateType,
ModelOutputThunk,
)
from mellea.plugins import PluginResult, hook, register
from mellea.plugins import HookType, PluginResult, hook, register
from mellea.stdlib.context import SimpleContext

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -194,6 +194,47 @@ async def recorder(payload: Any, ctx: Any) -> Any:

assert observed[0].latency_ms >= 0

async def test_generation_pre_call_mutation_is_applied_before_generation(
self,
) -> None:
"""GENERATION_PRE_CALL mutations reach the backend generation call."""
order: list[str] = []
captured_kwargs: dict[str, Any] = {}

class RecordingBackend(_MockBackend):
async def _generate_from_context(self, action, ctx, **kwargs):
order.append("generate")
captured_kwargs.update(kwargs)
return await super()._generate_from_context(action, ctx, **kwargs)

async def fake_invoke_hook(hook_type, payload, **_kwargs):
assert hook_type is HookType.GENERATION_PRE_CALL
order.append("hook")
modified = payload.model_copy(
update={"model_options": {"temperature": 0.25}, "tool_calls": True}
)
return (
PluginResult(continue_processing=True, modified_payload=modified),
modified,
)

backend = RecordingBackend()

with (
patch("mellea.core.backend.has_plugins", return_value=True),
patch("mellea.core.backend.invoke_hook", side_effect=fake_invoke_hook),
):
await backend.generate_from_context(
CBlock("hook order"),
MagicMock(spec=Context),
model_options={"temperature": 1.0},
tool_calls=False,
)

assert order == ["hook", "generate"]
assert captured_kwargs["model_options"] == {"temperature": 0.25}
assert captured_kwargs["tool_calls"] is True


# ---------------------------------------------------------------------------
# Component hook call sites
Expand Down Expand Up @@ -748,6 +789,55 @@ async def post_recorder(payload: Any, ctx: Any) -> Any:

assert order == ["pre_init", "post_init"]

def test_session_pre_init_mutation_is_applied_before_backend_init(self) -> None:
"""SESSION_PRE_INIT mutations reach the backend constructor."""
from mellea.stdlib.session import start_session

order: list[str] = []
captured_backend_args: dict[str, Any] = {}

class RecordingBackend(_MockBackend):
def __init__(self, model_id, model_options=None, **kwargs):
order.append("backend_init")
captured_backend_args["model_id"] = model_id
captured_backend_args["model_options"] = model_options
captured_backend_args["kwargs"] = kwargs

async def fake_invoke_hook(hook_type, payload, **_kwargs):
assert hook_type is HookType.SESSION_PRE_INIT
order.append("hook")
modified = payload.model_copy(
update={
"model_id": "hook-model",
"model_options": {"temperature": 0.25},
}
)
return (
PluginResult(continue_processing=True, modified_payload=modified),
modified,
)

def has_session_pre_init(hook_type=None):
return hook_type is HookType.SESSION_PRE_INIT

with (
patch(
"mellea.stdlib.session.has_plugins", side_effect=has_session_pre_init
),
patch("mellea.stdlib.session.invoke_hook", side_effect=fake_invoke_hook),
patch(
"mellea.stdlib.session.backend_name_to_class",
return_value=RecordingBackend,
),
):
start_session(
"ollama", model_id="original-model", model_options={"temperature": 1.0}
)

assert order == ["hook", "backend_init"]
assert captured_backend_args["model_id"] == "hook-model"
assert captured_backend_args["model_options"] == {"temperature": 0.25}


# ---------------------------------------------------------------------------
# Mutation tests — verify that hook-modified payloads are actually applied
Expand Down
Loading