diff --git a/src/google/adk/tools/google_search_tool.py b/src/google/adk/tools/google_search_tool.py index 8d73ced8d2..f8d9f0491f 100644 --- a/src/google/adk/tools/google_search_tool.py +++ b/src/google/adk/tools/google_search_tool.py @@ -35,17 +35,26 @@ class GoogleSearchTool(BaseTool): local code execution. """ - def __init__(self, *, bypass_multi_tools_limit: bool = False): + def __init__( + self, + *, + bypass_multi_tools_limit: bool = False, + model: str | None = None, + ): """Initializes the Google search tool. Args: bypass_multi_tools_limit: Whether to bypass the multi tools limitation, so that the tool can be used with other tools in the same agent. + model: Optional model name to use for processing the LLM request. If + provided, this model will be used instead of the model from the + incoming llm_request. """ # Name and description are not used because this is a model built-in tool. super().__init__(name='google_search', description='google_search') self.bypass_multi_tools_limit = bypass_multi_tools_limit + self.model = model @override async def process_llm_request( @@ -54,6 +63,10 @@ async def process_llm_request( tool_context: ToolContext, llm_request: LlmRequest, ) -> None: + # If a custom model is specified, use it instead of the original model + if self.model is not None: + llm_request.model = self.model + llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.tools = llm_request.config.tools or [] if is_gemini_1_model(llm_request.model): diff --git a/tests/unittests/tools/test_google_search_tool.py b/tests/unittests/tools/test_google_search_tool.py index 2f090abb17..8266f6e0d6 100644 --- a/tests/unittests/tools/test_google_search_tool.py +++ b/tests/unittests/tools/test_google_search_tool.py @@ -432,3 +432,46 @@ async def test_process_llm_request_gemini_version_specifics(self): assert len(llm_request.config.tools) == 1 assert llm_request.config.tools[0].google_search is not None assert llm_request.config.tools[0].google_search_retrieval is None + + @pytest.mark.asyncio + @pytest.mark.parametrize( + ( + 'tool_model', + 'request_model', + 'expected_model', + ), + [ + ( + 'gemini-2.5-flash-lite', + 'gemini-2.5-flash', + 'gemini-2.5-flash-lite', + ), + ( + None, + 'gemini-2.5-flash', + 'gemini-2.5-flash', + ), + ], + ids=['with_custom_model', 'without_custom_model'], + ) + async def test_process_llm_request_custom_model_behavior( + self, + tool_model, + request_model, + expected_model, + ): + """Tests custom model parameter behavior in process_llm_request.""" + tool = GoogleSearchTool(model=tool_model) + tool_context = await _create_tool_context() + + llm_request = LlmRequest( + model=request_model, config=types.GenerateContentConfig() + ) + + await tool.process_llm_request( + tool_context=tool_context, llm_request=llm_request + ) + + assert llm_request.model == expected_model + assert llm_request.config.tools is not None + assert len(llm_request.config.tools) == 1