diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index a726797396..e3e2776e0b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -815,20 +815,16 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; - ChatCompletionToolMessageParam.Builder builder = ChatCompletionToolMessageParam.builder(); - builder.content(toolMessage.getText() != null ? toolMessage.getText() : ""); - builder.role(JsonValue.from(MessageType.TOOL.getValue())); - - if (toolMessage.getResponses().isEmpty()) { - return List.of(ChatCompletionMessageParam.ofTool(builder.build())); - } - return toolMessage.getResponses().stream().map(response -> { - String callId = response.id(); - String callResponse = response.responseData(); - - return ChatCompletionMessageParam - .ofTool(builder.toolCallId(callId).content(callResponse).build()); - }).toList(); + return toolMessage.getResponses() + .stream() + .filter(response -> StringUtils.hasText(response.id())) + .map(response -> { + ChatCompletionToolMessageParam.Builder builder = ChatCompletionToolMessageParam.builder(); + builder.role(JsonValue.from(MessageType.TOOL.getValue())); + return ChatCompletionMessageParam + .ofTool(builder.toolCallId(response.id()).content(response.responseData()).build()); + }) + .toList(); } else { throw new IllegalArgumentException("Unsupported message type: " + message.getMessageType()); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatModelTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatModelTests.java index 2c1c13718b..3c598c5cb8 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatModelTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatModelTests.java @@ -21,6 +21,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.stream.Collectors; import com.openai.client.OpenAIClient; import com.openai.client.OpenAIClientAsync; @@ -36,6 +37,8 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.ChatResponseMetadata; import org.springframework.ai.chat.metadata.PromptMetadata; import org.springframework.ai.chat.metadata.RateLimit; @@ -193,6 +196,73 @@ void toolChoiceInvalidJson() { .hasMessageContaining("Failed to parse toolChoice JSON"); } + @Test + void toolResponseMessageWithPopulatedResponses_mapsToOneParamPerResponse() { + OpenAiChatOptions options = OpenAiChatOptions.builder().model("test-model").build(); + OpenAiChatModel chatModel = OpenAiChatModel.builder() + .openAiClient(this.openAiClient) + .openAiClientAsync(this.openAiClientAsync) + .options(options) + .build(); + + ToolResponseMessage toolMsg = ToolResponseMessage.builder() + .responses(List.of(new ToolResponseMessage.ToolResponse("call_1", "myTool", "result1"), + new ToolResponseMessage.ToolResponse("call_2", "myTool", "result2"))) + .build(); + + ChatCompletionCreateParams request = chatModel.createRequest(new Prompt(List.of(toolMsg), options), false); + + List toolCallIds = request.messages() + .stream() + .filter(msg -> msg.isTool()) + .map(msg -> msg.asTool().toolCallId()) + .collect(Collectors.toList()); + + assertThat(toolCallIds).containsExactly("call_1", "call_2"); + } + + @Test + void toolResponseMessageWithEmptyResponses_producesNoMessages() { + OpenAiChatOptions options = OpenAiChatOptions.builder().model("test-model").build(); + OpenAiChatModel chatModel = OpenAiChatModel.builder() + .openAiClient(this.openAiClient) + .openAiClientAsync(this.openAiClientAsync) + .options(options) + .build(); + + ToolResponseMessage emptyToolMsg = ToolResponseMessage.builder().build(); + + ChatCompletionCreateParams request = chatModel + .createRequest(new Prompt(List.of(new UserMessage("hello"), emptyToolMsg), options), false); + + assertThat(request.messages().stream().filter(msg -> msg.isTool()).toList()).isEmpty(); + } + + @Test + void toolResponseMessageWithNullId_skipsResponse() { + OpenAiChatOptions options = OpenAiChatOptions.builder().model("test-model").build(); + OpenAiChatModel chatModel = OpenAiChatModel.builder() + .openAiClient(this.openAiClient) + .openAiClientAsync(this.openAiClientAsync) + .options(options) + .build(); + + ToolResponseMessage toolMsg = ToolResponseMessage.builder() + .responses(List.of(new ToolResponseMessage.ToolResponse("call_1", "myTool", "result1"), + new ToolResponseMessage.ToolResponse(null, "badTool", "result2"))) + .build(); + + ChatCompletionCreateParams request = chatModel.createRequest(new Prompt(List.of(toolMsg), options), false); + + List toolCallIds = request.messages() + .stream() + .filter(msg -> msg.isTool()) + .map(msg -> msg.asTool().toolCallId()) + .collect(Collectors.toList()); + + assertThat(toolCallIds).containsExactly("call_1"); + } + @Test void preserveRateLimitAndPromptMetadataInAggregation() throws Exception { RateLimit rateLimit = mock(RateLimit.class);