diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index f310748dd..ab2a965f3 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -201,7 +201,7 @@ def start_session( ) else: backend = backend_class( - model_id, model_options=model_options, **backend_kwargs + model_id_str, model_options=model_options, **backend_kwargs ) logger.info( diff --git a/test/plugins/test_hook_call_sites.py b/test/plugins/test_hook_call_sites.py index cda0350d8..a3cae0477 100644 --- a/test/plugins/test_hook_call_sites.py +++ b/test/plugins/test_hook_call_sites.py @@ -28,7 +28,7 @@ GenerateType, ModelOutputThunk, ) -from mellea.plugins import PluginResult, hook, register +from mellea.plugins import HookType, PluginResult, hook, register from mellea.stdlib.context import SimpleContext # --------------------------------------------------------------------------- @@ -194,6 +194,47 @@ async def recorder(payload: Any, ctx: Any) -> Any: assert observed[0].latency_ms >= 0 + async def test_generation_pre_call_mutation_is_applied_before_generation( + self, + ) -> None: + """GENERATION_PRE_CALL mutations reach the backend generation call.""" + order: list[str] = [] + captured_kwargs: dict[str, Any] = {} + + class RecordingBackend(_MockBackend): + async def _generate_from_context(self, action, ctx, **kwargs): + order.append("generate") + captured_kwargs.update(kwargs) + return await super()._generate_from_context(action, ctx, **kwargs) + + async def fake_invoke_hook(hook_type, payload, **_kwargs): + assert hook_type is HookType.GENERATION_PRE_CALL + order.append("hook") + modified = payload.model_copy( + update={"model_options": {"temperature": 0.25}, "tool_calls": True} + ) + return ( + PluginResult(continue_processing=True, modified_payload=modified), + modified, + ) + + backend = RecordingBackend() + + with ( + patch("mellea.core.backend.has_plugins", return_value=True), + patch("mellea.core.backend.invoke_hook", side_effect=fake_invoke_hook), + ): + await backend.generate_from_context( + CBlock("hook order"), + MagicMock(spec=Context), + model_options={"temperature": 1.0}, + tool_calls=False, + ) + + assert order == ["hook", "generate"] + assert captured_kwargs["model_options"] == {"temperature": 0.25} + assert captured_kwargs["tool_calls"] is True + # --------------------------------------------------------------------------- # Component hook call sites @@ -748,6 +789,55 @@ async def post_recorder(payload: Any, ctx: Any) -> Any: assert order == ["pre_init", "post_init"] + def test_session_pre_init_mutation_is_applied_before_backend_init(self) -> None: + """SESSION_PRE_INIT mutations reach the backend constructor.""" + from mellea.stdlib.session import start_session + + order: list[str] = [] + captured_backend_args: dict[str, Any] = {} + + class RecordingBackend(_MockBackend): + def __init__(self, model_id, model_options=None, **kwargs): + order.append("backend_init") + captured_backend_args["model_id"] = model_id + captured_backend_args["model_options"] = model_options + captured_backend_args["kwargs"] = kwargs + + async def fake_invoke_hook(hook_type, payload, **_kwargs): + assert hook_type is HookType.SESSION_PRE_INIT + order.append("hook") + modified = payload.model_copy( + update={ + "model_id": "hook-model", + "model_options": {"temperature": 0.25}, + } + ) + return ( + PluginResult(continue_processing=True, modified_payload=modified), + modified, + ) + + def has_session_pre_init(hook_type=None): + return hook_type is HookType.SESSION_PRE_INIT + + with ( + patch( + "mellea.stdlib.session.has_plugins", side_effect=has_session_pre_init + ), + patch("mellea.stdlib.session.invoke_hook", side_effect=fake_invoke_hook), + patch( + "mellea.stdlib.session.backend_name_to_class", + return_value=RecordingBackend, + ), + ): + start_session( + "ollama", model_id="original-model", model_options={"temperature": 1.0} + ) + + assert order == ["hook", "backend_init"] + assert captured_backend_args["model_id"] == "hook-model" + assert captured_backend_args["model_options"] == {"temperature": 0.25} + # --------------------------------------------------------------------------- # Mutation tests — verify that hook-modified payloads are actually applied