From 4c4ac524e5dd604d6ae4b73c98c28bd6e891adc1 Mon Sep 17 00:00:00 2001 From: akazwz Date: Wed, 17 Jun 2026 01:21:58 -0700 Subject: [PATCH] fix(google): preserve tool call thought signatures --- provider/google/generativeai/generativeai.go | 23 +++++++++++---- .../google/generativeai/generativeai_test.go | 24 +++++++++++++++- sdk/step_helpers.go | 6 +--- sdk/step_helpers_test.go | 28 +++++++++++++++++++ sdk/stream.go | 14 ++++++---- sdk/stream_text.go | 7 +++-- sdk/tool_call.go | 7 +++-- 7 files changed, 85 insertions(+), 24 deletions(-) create mode 100644 sdk/step_helpers_test.go diff --git a/provider/google/generativeai/generativeai.go b/provider/google/generativeai/generativeai.go index 42d0542..6624d8e 100644 --- a/provider/google/generativeai/generativeai.go +++ b/provider/google/generativeai/generativeai.go @@ -415,9 +415,10 @@ func (p *Provider) parseResponse(resp *generateResponse) (*sdk.GenerateResult, e return result, fmt.Errorf("google: unmarshal function call args for %q: %w", part.FunctionCall.Name, err) } result.ToolCalls = append(result.ToolCalls, sdk.ToolCall{ - ToolCallID: id, - ToolName: part.FunctionCall.Name, - Input: input, + ToolCallID: id, + ToolName: part.FunctionCall.Name, + Input: input, + ProviderMetadata: googleThoughtSignatureMetadata(part.ThoughtSignature), }) case part.Text != "": isThought := part.Thought != nil && *part.Thought @@ -572,9 +573,10 @@ func (p *Provider) DoStream(ctx context.Context, params sdk.GenerateParams) (*sd } send(&sdk.StreamToolCallPart{ - ToolCallID: toolCallID, - ToolName: part.FunctionCall.Name, - Input: input, + ToolCallID: toolCallID, + ToolName: part.FunctionCall.Name, + Input: input, + ProviderMetadata: googleThoughtSignatureMetadata(part.ThoughtSignature), }) case part.Text != "": isThought := part.Thought != nil && *part.Thought @@ -738,6 +740,15 @@ func extractGoogleThoughtSignature(meta map[string]any) string { return sig } +func googleThoughtSignatureMetadata(sig string) map[string]any { + if sig == "" { + return nil + } + return map[string]any{ + "google": map[string]any{"thoughtSignature": sig}, + } +} + func classifyError(err error) *sdk.ProviderTestResult { var apiErr *utils.APIError if errors.As(err, &apiErr) { diff --git a/provider/google/generativeai/generativeai_test.go b/provider/google/generativeai/generativeai_test.go index 117e381..cbe0862 100644 --- a/provider/google/generativeai/generativeai_test.go +++ b/provider/google/generativeai/generativeai_test.go @@ -184,6 +184,7 @@ func TestDoGenerate_ToolCall(t *testing.T) { "name": "get_weather", "args": map[string]any{"location": "Beijing"}, }, + "thoughtSignature": "sig-generate", }}, }, "finishReason": "STOP", @@ -240,6 +241,13 @@ func TestDoGenerate_ToolCall(t *testing.T) { if input["location"] != "Beijing" { t.Errorf("location: got %v", input["location"]) } + googleMeta, ok := tc.ProviderMetadata["google"].(map[string]any) + if !ok { + t.Fatalf("tool call provider metadata = %#v, want google map", tc.ProviderMetadata) + } + if googleMeta["thoughtSignature"] != "sig-generate" { + t.Fatalf("thoughtSignature = %#v, want sig-generate", googleMeta["thoughtSignature"]) + } } func TestDoGenerate_ToolCallMultiTurn(t *testing.T) { @@ -261,6 +269,7 @@ func TestDoGenerate_ToolCallMultiTurn(t *testing.T) { Name string `json:"name"` Args any `json:"args"` } `json:"functionCall,omitempty"` + ThoughtSignature string `json:"thoughtSignature,omitempty"` } `json:"parts"` } json.Unmarshal(body.Contents[1], &modelMsg) @@ -270,6 +279,9 @@ func TestDoGenerate_ToolCallMultiTurn(t *testing.T) { if len(modelMsg.Parts) != 1 || modelMsg.Parts[0].FunctionCall == nil { t.Errorf("msg[1] should have functionCall part") } + if modelMsg.Parts[0].ThoughtSignature != "sig-multiturn" { + t.Errorf("msg[1] thoughtSignature: got %q, want sig-multiturn", modelMsg.Parts[0].ThoughtSignature) + } // verify tool result message has functionResponse var toolMsg struct { @@ -323,6 +335,9 @@ func TestDoGenerate_ToolCallMultiTurn(t *testing.T) { ToolCallID: "call_abc", ToolName: "get_weather", Input: map[string]any{"location": "Beijing"}, + ProviderMetadata: map[string]any{ + "google": map[string]any{"thoughtSignature": "sig-multiturn"}, + }, }}, }, { @@ -630,7 +645,7 @@ func TestDoStream_ToolCall(t *testing.T) { flusher := w.(http.Flusher) chunks := []string{ - `{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"location":"Tokyo"}}}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}`, + `{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"location":"Tokyo"}},"thoughtSignature":"sig-tool"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}`, } for _, c := range chunks { fmt.Fprintf(w, "data: %s\n\n", c) @@ -705,6 +720,13 @@ func TestDoStream_ToolCall(t *testing.T) { if !gotFinish { t.Error("missing FinishPart") } + googleMeta, ok := gotToolCall.ProviderMetadata["google"].(map[string]any) + if !ok { + t.Fatalf("tool call provider metadata = %#v, want google map", gotToolCall.ProviderMetadata) + } + if googleMeta["thoughtSignature"] != "sig-tool" { + t.Fatalf("thoughtSignature = %#v, want sig-tool", googleMeta["thoughtSignature"]) + } } func TestDoStream_Reasoning(t *testing.T) { diff --git a/sdk/step_helpers.go b/sdk/step_helpers.go index ccca707..8f3d713 100644 --- a/sdk/step_helpers.go +++ b/sdk/step_helpers.go @@ -57,11 +57,7 @@ func buildStepMessages(text, reasoning string, reasoningMeta map[string]any, too assistantParts = append(assistantParts, TextPart{Text: text}) } for _, tc := range toolCalls { - assistantParts = append(assistantParts, ToolCallPart{ - ToolCallID: tc.ToolCallID, - ToolName: tc.ToolName, - Input: tc.Input, - }) + assistantParts = append(assistantParts, ToolCallPart(tc)) } msgs := []Message{{Role: MessageRoleAssistant, Content: assistantParts, Usage: usage}} diff --git a/sdk/step_helpers_test.go b/sdk/step_helpers_test.go new file mode 100644 index 0000000..dd86f87 --- /dev/null +++ b/sdk/step_helpers_test.go @@ -0,0 +1,28 @@ +package sdk + +import "testing" + +func TestBuildStepMessagesPreservesToolCallProviderMetadata(t *testing.T) { + meta := map[string]any{"google": map[string]any{"thoughtSignature": "sig-1"}} + msgs := buildStepMessages("", "", nil, []ToolCall{{ + ToolCallID: "call-1", + ToolName: "lookup", + Input: map[string]any{"q": "memoh"}, + ProviderMetadata: meta, + }}, nil, nil) + + if len(msgs) != 1 || len(msgs[0].Content) != 1 { + t.Fatalf("unexpected messages: %#v", msgs) + } + part, ok := msgs[0].Content[0].(ToolCallPart) + if !ok { + t.Fatalf("content part = %T, want ToolCallPart", msgs[0].Content[0]) + } + gotGoogle, ok := part.ProviderMetadata["google"].(map[string]any) + if !ok { + t.Fatalf("provider metadata = %#v, want google map", part.ProviderMetadata) + } + if gotGoogle["thoughtSignature"] != "sig-1" { + t.Fatalf("thoughtSignature = %#v, want sig-1", gotGoogle["thoughtSignature"]) + } +} diff --git a/sdk/stream.go b/sdk/stream.go index 37b1195..b292cd4 100644 --- a/sdk/stream.go +++ b/sdk/stream.go @@ -111,9 +111,10 @@ func (p *ToolInputEndPart) Type() StreamPartType { return StreamPartTypeToolInpu // --- Tool Execution --- type StreamToolCallPart struct { - ToolCallID string - ToolName string - Input any + ToolCallID string + ToolName string + Input any + ProviderMetadata map[string]any } func (p *StreamToolCallPart) Type() StreamPartType { return StreamPartTypeToolCall } @@ -264,9 +265,10 @@ func (sr *StreamResult) ToResult() (*GenerateResult, error) { reasoning += p.Text case *StreamToolCallPart: result.ToolCalls = append(result.ToolCalls, ToolCall{ - ToolCallID: p.ToolCallID, - ToolName: p.ToolName, - Input: p.Input, + ToolCallID: p.ToolCallID, + ToolName: p.ToolName, + Input: p.Input, + ProviderMetadata: p.ProviderMetadata, }) case *StreamToolResultPart: result.ToolResults = append(result.ToolResults, ToolResult{ diff --git a/sdk/stream_text.go b/sdk/stream_text.go index 07b4535..7cc38c8 100644 --- a/sdk/stream_text.go +++ b/sdk/stream_text.go @@ -81,9 +81,10 @@ func (c *Client) StreamText(ctx context.Context, options ...GenerateOption) (*St } case *StreamToolCallPart: stepToolCalls = append(stepToolCalls, ToolCall{ - ToolCallID: p.ToolCallID, - ToolName: p.ToolName, - Input: p.Input, + ToolCallID: p.ToolCallID, + ToolName: p.ToolName, + Input: p.Input, + ProviderMetadata: p.ProviderMetadata, }) case *FinishStepPart: stepUsage = p.Usage diff --git a/sdk/tool_call.go b/sdk/tool_call.go index 90a0fde..c8de988 100644 --- a/sdk/tool_call.go +++ b/sdk/tool_call.go @@ -1,9 +1,10 @@ package sdk type ToolCall struct { - ToolCallID string `json:"toolCallId"` - ToolName string `json:"toolName"` - Input any `json:"input"` + ToolCallID string `json:"toolCallId"` + ToolName string `json:"toolName"` + Input any `json:"input"` + ProviderMetadata map[string]any `json:"providerMetadata,omitempty"` } type ToolResult struct {