diff --git a/provider/google/generativeai/generativeai.go b/provider/google/generativeai/generativeai.go index e11a1bf..42d0542 100644 --- a/provider/google/generativeai/generativeai.go +++ b/provider/google/generativeai/generativeai.go @@ -355,9 +355,9 @@ func convertTools(tools []sdk.Tool, toolChoice any) ([]toolGroup, *toolConfig) { decls := make([]functionDeclaration, 0, len(tools)) for _, t := range tools { decls = append(decls, functionDeclaration{ - Name: t.Name, - Description: t.Description, - Parameters: t.Parameters, + Name: t.Name, + Description: t.Description, + ParametersJSONSchema: t.Parameters, }) } diff --git a/provider/google/generativeai/generativeai_test.go b/provider/google/generativeai/generativeai_test.go index 644c0b8..117e381 100644 --- a/provider/google/generativeai/generativeai_test.go +++ b/provider/google/generativeai/generativeai_test.go @@ -85,12 +85,22 @@ func TestDoGenerate(t *testing.T) { func TestDoGenerate_ToolCall(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var rawBody map[string]any + if err := json.NewDecoder(r.Body).Decode(&rawBody); err != nil { + t.Fatalf("decode request body: %v", err) + } + data, err := json.Marshal(rawBody) + if err != nil { + t.Fatalf("marshal request body: %v", err) + } + var body struct { Tools []struct { FunctionDeclarations []struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters any `json:"parameters"` + Name string `json:"name"` + Description string `json:"description"` + Parameters any `json:"parameters"` + ParametersJSONSchema any `json:"parametersJsonSchema"` } `json:"functionDeclarations"` } `json:"tools"` ToolConfig *struct { @@ -99,7 +109,9 @@ func TestDoGenerate_ToolCall(t *testing.T) { } `json:"functionCallingConfig"` } `json:"toolConfig"` } - json.NewDecoder(r.Body).Decode(&body) + if err := json.Unmarshal(data, &body); err != nil { + t.Fatalf("unmarshal request body: %v", err) + } if len(body.Tools) != 1 { t.Fatalf("expected 1 tool group, got %d", len(body.Tools)) @@ -110,6 +122,54 @@ func TestDoGenerate_ToolCall(t *testing.T) { if body.Tools[0].FunctionDeclarations[0].Name != "get_weather" { t.Errorf("tool name: got %q", body.Tools[0].FunctionDeclarations[0].Name) } + if body.Tools[0].FunctionDeclarations[0].Parameters != nil { + t.Errorf("expected parameters to be omitted") + } + if body.Tools[0].FunctionDeclarations[0].ParametersJSONSchema == nil { + t.Errorf("expected parametersJsonSchema to be set") + } + + toolsRaw, ok := rawBody["tools"].([]any) + if !ok || len(toolsRaw) != 1 { + t.Fatalf("raw tools: got %T %v", rawBody["tools"], rawBody["tools"]) + } + toolRaw, ok := toolsRaw[0].(map[string]any) + if !ok { + t.Fatalf("raw tool type: got %T", toolsRaw[0]) + } + declsRaw, ok := toolRaw["functionDeclarations"].([]any) + if !ok || len(declsRaw) != 1 { + t.Fatalf("raw functionDeclarations: got %T %v", toolRaw["functionDeclarations"], toolRaw["functionDeclarations"]) + } + declRaw, ok := declsRaw[0].(map[string]any) + if !ok { + t.Fatalf("raw function declaration type: got %T", declsRaw[0]) + } + if _, ok := declRaw["parameters"]; ok { + t.Errorf("expected raw parameters key to be omitted") + } + schema, ok := declRaw["parametersJsonSchema"].(map[string]any) + if !ok { + t.Fatalf("raw parametersJsonSchema type: got %T", declRaw["parametersJsonSchema"]) + } + if schema["additionalProperties"] != false { + t.Errorf("additionalProperties: got %v, want false", schema["additionalProperties"]) + } + props, ok := schema["properties"].(map[string]any) + if !ok { + t.Fatalf("properties type: got %T", schema["properties"]) + } + location, ok := props["location"].(map[string]any) + if !ok { + t.Fatalf("location property type: got %T", props["location"]) + } + if location["type"] != "string" { + t.Errorf("location type: got %v, want string", location["type"]) + } + required, ok := schema["required"].([]any) + if !ok || len(required) != 1 || required[0] != "location" { + t.Errorf("required: got %v, want [location]", schema["required"]) + } if body.ToolConfig == nil || body.ToolConfig.FunctionCallingConfig.Mode != "AUTO" { t.Errorf("expected AUTO tool config mode") } @@ -153,7 +213,8 @@ func TestDoGenerate_ToolCall(t *testing.T) { Properties: map[string]*jsonschema.Schema{ "location": {Type: "string"}, }, - Required: []string{"location"}, + Required: []string{"location"}, + AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, }, }}, ToolChoice: "auto", @@ -1068,6 +1129,58 @@ func TestIntegration_ToolCall(t *testing.T) { } } +func TestIntegration_ToolCallWithAdditionalPropertiesSchema(t *testing.T) { + p := newIntegrationProvider(t) + model := integrationModel(t) + model.Provider = p + + result, err := p.DoGenerate(context.Background(), sdk.GenerateParams{ + Model: model, + Messages: []sdk.Message{{ + Role: sdk.MessageRoleUser, + Content: []sdk.MessagePart{sdk.TextPart{ + Text: "Call get_weather with location San Francisco.", + }}, + }}, + Tools: []sdk.Tool{{ + Name: "get_weather", + Description: "Get the weather for a location.", + Parameters: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "location": {Type: "string", Description: "City name"}, + }, + Required: []string{"location"}, + AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + }, + }}, + ToolChoice: "required", + }) + if err != nil { + t.Fatalf("DoGenerate with additionalProperties in tool schema: %v", err) + } + t.Logf("finish=%s toolCalls=%d text=%q", result.FinishReason, len(result.ToolCalls), result.Text) + + if result.FinishReason != sdk.FinishReasonToolCalls { + t.Errorf("finish: got %q, want %q", result.FinishReason, sdk.FinishReasonToolCalls) + } + if len(result.ToolCalls) == 0 { + t.Fatal("expected at least one tool call") + } + tc := result.ToolCalls[0] + t.Logf(" tool=%q id=%s input=%v", tc.ToolName, tc.ToolCallID, tc.Input) + if tc.ToolName != "get_weather" { + t.Errorf("tool name: got %q, want get_weather", tc.ToolName) + } + input, ok := tc.Input.(map[string]any) + if !ok { + t.Fatalf("input type: got %T", tc.Input) + } + if location, ok := input["location"].(string); !ok || location == "" { + t.Errorf("location input: got %v", input["location"]) + } +} + // ---------- ListModels / Test / TestModel unit tests ---------- func TestListModels(t *testing.T) { diff --git a/provider/google/generativeai/types.go b/provider/google/generativeai/types.go index 1214410..1c36a61 100644 --- a/provider/google/generativeai/types.go +++ b/provider/google/generativeai/types.go @@ -3,12 +3,12 @@ package generativeai // --- Request types --- type generateRequest struct { - Contents []content `json:"contents"` - SystemInstruction *content `json:"systemInstruction,omitempty"` + Contents []content `json:"contents"` + SystemInstruction *content `json:"systemInstruction,omitempty"` GenerationConfig *generationConfig `json:"generationConfig,omitempty"` - Tools []toolGroup `json:"tools,omitempty"` - ToolConfig *toolConfig `json:"toolConfig,omitempty"` - SafetySettings []safetySetting `json:"safetySettings,omitempty"` + Tools []toolGroup `json:"tools,omitempty"` + ToolConfig *toolConfig `json:"toolConfig,omitempty"` + SafetySettings []safetySetting `json:"safetySettings,omitempty"` } type content struct { @@ -69,9 +69,9 @@ type toolGroup struct { } type functionDeclaration struct { - Name string `json:"name"` - Description string `json:"description"` - Parameters any `json:"parameters,omitempty"` + Name string `json:"name"` + Description string `json:"description"` + ParametersJSONSchema any `json:"parametersJsonSchema,omitempty"` } type toolConfig struct { @@ -91,14 +91,14 @@ type safetySetting struct { // --- Response types --- type generateResponse struct { - Candidates []candidate `json:"candidates"` - UsageMetadata *usageMetadata `json:"usageMetadata,omitempty"` + Candidates []candidate `json:"candidates"` + UsageMetadata *usageMetadata `json:"usageMetadata,omitempty"` PromptFeedback *promptFeedback `json:"promptFeedback,omitempty"` } type candidate struct { - Content *content `json:"content,omitempty"` - FinishReason string `json:"finishReason,omitempty"` + Content *content `json:"content,omitempty"` + FinishReason string `json:"finishReason,omitempty"` SafetyRatings []safetyRating `json:"safetyRatings,omitempty"` }