diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5da4ea4b..0436f34e 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -53,7 +53,7 @@ jobs: uses: goreleaser/goreleaser-action@v7 with: # renovate: datasource=github-releases depName=goreleaser/goreleaser - version: "v2.14.3" + version: "v2.15.2" args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 2fc26655..329a036a 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -104,13 +104,7 @@ winget: # in case there is an indicator for prerelease in the tag e.g. v1.0.0-rc1 skip_upload: auto - # Release notes. - # - # If you want to use the release notes generated by GoReleaser, use - # `{{.Changelog}}` as the value. - # - # Templates: allowed - release_notes: "{{.Changelog}}" + release_notes_url: "https://github.com/overmindtech/cli/releases/tag/{{ .Tag }}" # Repository to push the generated files to. repository: diff --git a/.terraform.lock.hcl b/.terraform.lock.hcl index 4d8e919b..f40171f6 100644 --- a/.terraform.lock.hcl +++ b/.terraform.lock.hcl @@ -2,37 +2,37 @@ # Manual edits may be lost in future updates. provider "registry.terraform.io/hashicorp/aws" { - version = "6.38.0" + version = "6.39.0" constraints = ">= 4.56.0" hashes = [ - "h1:2NhckHRVSF36dyFYd15myz4AjReQbOYgNQ/7Li29wcc=", - "h1:3MP3AOAntDTXnrWry8XlWL+6M8rMRhlTp0ZUxNyrH/A=", - "h1:7EtWzjLeg1qmbf7mSOvy3T/alXehJRCZxFu1Et/IHkw=", - "h1:7F3W4qGLTbr4aploSI8eIqE4AueoNe/Tq5Osuo0IgJ4=", - "h1:7al1E+/zrmxL7qgGRoR7b//8d2oJxpguLEuFnBmRCVY=", - "h1:GOlIkFIhuKpVeBG+m77rXhiT7/r1AKFfhokm9WdZj1o=", - "h1:IMf41BcW9huOeVcrt6XjQqadYR2xD8zkUpGLLERJ4NM=", - "h1:RDoKIzXmt7H1mNFcNIyRT+nA/gTJyO3+iW9QGN5I2eQ=", - "h1:Yx7kDYSjFBAuvq3nbmgy+N9+ilJB8NNsIrBusm0nm5w=", - "h1:fJlpWm8M4RGM5TZGeblUiP3WqhUI0zV0hM+NAJ9lVlY=", - "h1:kQD9Ehy9Iy+11jp2JkKE3I62DxshXeIMhA/su43Md9k=", - "h1:t/tF5CzmNbAxPVVVYjojbK2T60b4X/5RINmY8yKbu1E=", - "h1:thWyDCjV9CmSOSWBTCrG/P3bNlYBzRl6QVj0WcSisLM=", - "h1:wBWvJVZJBEQT6ty8LYk/QoS7xM2E+zAEwKTqFlvayGw=", - "zh:143f118ae71059a7a7026c6b950da23fef04a06e2362ffa688bef75e43e869ed", - "zh:29ee220a017306effd877e1280f8b2934dc957e16e0e72ca0222e5514d0db522", - "zh:3a31baabf7aea7aa7669f5a3d76f3445e0e6cce5e9aea0279992765c0df12aee", - "zh:4c1908e62040dbc9901d4426ffb253f53e5dae9e3e1a9125311291ee265c8d8c", - "zh:550f4789f5f5b00e16118d4c17770be3ef4535d6b6928af1cf91ebd30f2c263b", - "zh:6537b7b70bf2c127771b0b84e4b726c834d10666b6104f017edae50c67ebae37", + "h1:2VV7im+l1aTGOdEtkbojBTxma4yG2ctsjpIt6mC4mXU=", + "h1:35OBeunWJLXG9xQj0p8YtzF6QHSGJ1U89rNuAeIJHUk=", + "h1:AGpJ/Nr+fI9mMAK3XAZ/nve1CBWjPSlRVW5oZrljPDo=", + "h1:C6npjQ4Tv3Zuz0yOGhnUzbOMavOz84hkTw+dS/u++sc=", + "h1:Gc5A9vbYp3jEqKiO/tBAs8pssPq2n4e+7UGFJWOFKTU=", + "h1:H3uZqogXDrbDVCrdX5BH7e8dTO2PIIJHhc2u+BbU8MY=", + "h1:QwtlSBjpxlw3Almi8cMWWlTkRj1Od5jBBEFYR+8XXcs=", + "h1:TpotTvnGakFtGjCpAVpDkmG2sNQxkl3Mz7+j3x0TB7w=", + "h1:XfyHGbKbpOJYnmn1ztPe2GDIImS7OdUCIeI+QPIFhb8=", + "h1:bRfrZGx++Bbyo9HfUCNqMghxI+M+Vg5W1AfQqtgQrnY=", + "h1:f59AV0JXoocRC3K1XIvArY0OBmPinn5NOYITPmnxzwQ=", + "h1:iPaoWnLfDWDBFgBcCCct8TLrK2MZ6skDApETYxYdATc=", + "h1:jweey4Iefm/DuuBg84saQ8vz5IO3vC6hDFTU/eGdmBI=", + "h1:l6eMEVbdjjbgQcezuY3Iyk1FpdwYjCiEPAk/CidhqEI=", + "zh:00c3e3c38063ff629d6fdbce04e9ac2e241566e0f5ad5399c335f0abdefd7bff", + "zh:148f95b62791080537d926b9d2f5d8457cca45921d9b1019d03ceb3ab93bf9db", + "zh:203da629ed5191dd5d7aa3427a5d1d1a83eed5c1b0114166897206973f0d0fd0", + "zh:21923eedbc60b4f68c8d717b951d16b0b1bbf31d66330c7be228869bec18f7ce", + "zh:26226f02e3661b3d071c01601b654a308b29d21758b75692bec66f70c6f6b945", + "zh:271c7c6fadcd8ac7ed37c11e61c0f374773eaaa5293703499f8a0f75830060e0", + "zh:46e319a8888dc50ed8d26a1cbee9637f529112a88f5d44decc8f1d10ef968ffe", "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", - "zh:af2f9cea0c8bdf5b2a2391f2d179a946c117196f7c829b919673cae3b71d2943", - "zh:c53ffa685381aa4e73158fd9f529239f95938dea330e7aca0b32e7b2a1210432", - "zh:d0995e1d64a7ec8bbc79fc3fbec3749f989e07f211a318705c37cd6a7c7d19e4", - "zh:d2348ffcffc1282983d7a5838dd5d61f372152fe6c0d10868cd6473352318750", - "zh:e449312efb73e4747165e689302a68a1df8ba5755e7f59097069acf82c94f011", - "zh:ec3a538d264ef79380e56fdf107ffb6c0446814f07fc5890c36855fe1e03196b", - "zh:f441e69699b22e32c96a8cdd3bbe694ed302c0dcfe867cd9bd683a16df362714", - "zh:f6f8eaa605ff902234d7e9bdab4fda977185fce14f8576f7b622c914c7d98008", + "zh:a3c3ca09cdbf3b9a3f892a23c000ff04772bdf19f626959ea83d0803c8fd2350", + "zh:a5fa6515ffc3c815e0d2204d67e838f5bad8635009dab85211d166c7ae729d2c", + "zh:c0807566b4ddde8390f50c5475464103f066bc7f511a6c0be762d75cb6d1a078", + "zh:da754a529fd0e06ac372f62d88566f85a8c4bcec7ee9a231b65e0a0148165e63", + "zh:dcb768e48363a9f4dffaf2dc7d01f1877285528925ec50de6335286298e37e1d", + "zh:eac9de9d123c679ea3035199fb9c588a08cda281cbabf948dc696e2a1a1b9063", + "zh:fef276b6331c663ca0e60dc7f637b2b8244825b8c9bc721481957e58f74ffb4f", ] } diff --git a/aws-source/build/package/Dockerfile b/aws-source/build/package/Dockerfile index 6f9aba9d..4865ad14 100644 --- a/aws-source/build/package/Dockerfile +++ b/aws-source/build/package/Dockerfile @@ -10,8 +10,12 @@ RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY aws-source/ aws-source/ # Build RUN --mount=type=cache,target=/go/pkg \ diff --git a/aws-source/module/provider/.github/workflows/release.yml b/aws-source/module/provider/.github/workflows/release.yml index 465023d4..d937b6b4 100644 --- a/aws-source/module/provider/.github/workflows/release.yml +++ b/aws-source/module/provider/.github/workflows/release.yml @@ -46,7 +46,7 @@ jobs: uses: goreleaser/goreleaser-action@v7 with: # renovate: datasource=github-releases depName=goreleaser/goreleaser - version: "v2.14.3" + version: "v2.15.2" args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/go.mod b/go.mod index 14ad431b..0ed57270 100644 --- a/go.mod +++ b/go.mod @@ -13,34 +13,34 @@ require ( buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.11-20260209202127-80ab13bee0bf.1 buf.build/go/protovalidate v1.1.3 charm.land/lipgloss/v2 v2.0.2 - cloud.google.com/go/aiplatform v1.121.0 + cloud.google.com/go/aiplatform v1.122.0 cloud.google.com/go/auth v0.19.0 cloud.google.com/go/auth/oauth2adapt v0.2.8 - cloud.google.com/go/bigquery v1.74.0 - cloud.google.com/go/bigtable v1.43.0 - cloud.google.com/go/certificatemanager v1.9.6 - cloud.google.com/go/compute v1.57.0 + cloud.google.com/go/bigquery v1.75.0 + cloud.google.com/go/bigtable v1.45.0 + cloud.google.com/go/certificatemanager v1.10.0 + cloud.google.com/go/compute v1.58.0 cloud.google.com/go/compute/metadata v0.9.0 // indirect - cloud.google.com/go/container v1.46.0 - cloud.google.com/go/dataplex v1.29.0 - cloud.google.com/go/dataproc/v2 v2.16.0 - cloud.google.com/go/eventarc v1.18.0 - cloud.google.com/go/filestore v1.10.3 - cloud.google.com/go/functions v1.19.7 - cloud.google.com/go/iam v1.6.0 - cloud.google.com/go/kms v1.26.0 - cloud.google.com/go/logging v1.13.2 - cloud.google.com/go/monitoring v1.24.3 - cloud.google.com/go/networksecurity v0.11.0 - cloud.google.com/go/orgpolicy v1.15.1 - cloud.google.com/go/redis v1.18.3 - cloud.google.com/go/resourcemanager v1.10.7 - cloud.google.com/go/run v1.16.0 - cloud.google.com/go/secretmanager v1.16.0 - cloud.google.com/go/securitycentermanagement v1.1.6 + cloud.google.com/go/container v1.47.0 + cloud.google.com/go/dataplex v1.30.0 + cloud.google.com/go/dataproc/v2 v2.17.0 + cloud.google.com/go/eventarc v1.19.0 + cloud.google.com/go/filestore v1.11.0 + cloud.google.com/go/functions v1.20.0 + cloud.google.com/go/iam v1.7.0 + cloud.google.com/go/kms v1.27.0 + cloud.google.com/go/logging v1.14.0 + cloud.google.com/go/monitoring v1.25.0 + cloud.google.com/go/networksecurity v0.12.0 + cloud.google.com/go/orgpolicy v1.16.0 + cloud.google.com/go/redis v1.19.0 + cloud.google.com/go/resourcemanager v1.11.0 + cloud.google.com/go/run v1.17.0 + cloud.google.com/go/secretmanager v1.17.0 + cloud.google.com/go/securitycentermanagement v1.2.0 cloud.google.com/go/spanner v1.89.0 cloud.google.com/go/storage v1.61.3 - cloud.google.com/go/storagetransfer v1.13.1 + cloud.google.com/go/storagetransfer v1.14.0 connectrpc.com/connect v1.18.1 // v1.19.0 was faulty, wait until it is above this version github.com/Azure/azure-sdk-for-go/sdk/azcore v1.21.0 github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 @@ -53,6 +53,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault/v2 v2.0.2 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.3.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9 v9.0.0 + github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5 v5.0.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.3.0 github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2 v2.1.0 @@ -62,17 +63,17 @@ require ( github.com/MrAlias/otel-schema-utils v0.4.0-alpha github.com/auth0/go-jwt-middleware/v3 v3.0.0 github.com/aws/aws-sdk-go-v2 v1.41.5 - github.com/aws/aws-sdk-go-v2/config v1.32.13 - github.com/aws/aws-sdk-go-v2/credentials v1.19.13 + github.com/aws/aws-sdk-go-v2/config v1.32.14 + github.com/aws/aws-sdk-go-v2/credentials v1.19.14 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 github.com/aws/aws-sdk-go-v2/service/apigateway v1.39.1 - github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.4 - github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.4 - github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.55.3 + github.com/aws/aws-sdk-go-v2/service/autoscaling v1.65.0 + github.com/aws/aws-sdk-go-v2/service/cloudfront v1.61.0 + github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.15 github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1 - github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.1 - github.com/aws/aws-sdk-go-v2/service/ecs v1.74.1 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 + github.com/aws/aws-sdk-go-v2/service/ecs v1.76.0 github.com/aws/aws-sdk-go-v2/service/efs v1.41.14 github.com/aws/aws-sdk-go-v2/service/eks v1.81.2 github.com/aws/aws-sdk-go-v2/service/elasticloadbalancing v1.33.23 @@ -84,21 +85,21 @@ require ( github.com/aws/aws-sdk-go-v2/service/networkmanager v1.41.8 github.com/aws/aws-sdk-go-v2/service/rds v1.117.1 github.com/aws/aws-sdk-go-v2/service/route53 v1.62.5 - github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 + github.com/aws/aws-sdk-go-v2/service/s3 v1.98.0 github.com/aws/aws-sdk-go-v2/service/sns v1.39.15 github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25 github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4 github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 - github.com/aws/smithy-go v1.24.2 + github.com/aws/smithy-go v1.24.3 github.com/cenkalti/backoff/v5 v5.0.3 github.com/charmbracelet/glamour v0.10.0 github.com/coder/websocket v1.8.14 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/getsentry/sentry-go v0.44.1 - github.com/go-jose/go-jose/v4 v4.1.3 + github.com/go-jose/go-jose/v4 v4.1.4 github.com/google/btree v1.1.3 github.com/google/uuid v1.6.0 - github.com/googleapis/gax-go/v2 v2.20.0 + github.com/googleapis/gax-go/v2 v2.21.0 github.com/goombaio/namegenerator v0.0.0-20181006234301-989e774b106e github.com/hashicorp/go-retryablehttp v0.7.8 github.com/hashicorp/hcl/v2 v2.24.0 @@ -106,7 +107,7 @@ require ( github.com/hashicorp/terraform-plugin-framework v1.19.0 github.com/hashicorp/terraform-plugin-go v0.31.0 github.com/hashicorp/terraform-plugin-testing v1.15.0 - github.com/jedib0t/go-pretty/v6 v6.7.8 + github.com/jedib0t/go-pretty/v6 v6.7.9 github.com/lithammer/fuzzysearch v1.1.8 // indirect github.com/micahhausler/aws-iam-policy v0.4.4 github.com/miekg/dns v1.1.72 @@ -134,12 +135,12 @@ require ( go.etcd.io/bbolt v1.4.3 go.opentelemetry.io/contrib/detectors/aws/ec2/v2 v2.0.0-20250901115419-474a7992e57c go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 - go.opentelemetry.io/otel v1.42.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 - go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 - go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 - go.opentelemetry.io/otel/sdk v1.42.0 - go.opentelemetry.io/otel/trace v1.42.0 + go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 + go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 + go.opentelemetry.io/otel/sdk v1.43.0 + go.opentelemetry.io/otel/trace v1.43.0 go.uber.org/automaxprocs v1.6.0 go.uber.org/goleak v1.3.0 go.uber.org/mock v0.6.0 @@ -149,9 +150,9 @@ require ( golang.org/x/sync v0.20.0 golang.org/x/text v0.35.0 gonum.org/v1/gonum v0.17.0 - google.golang.org/api v0.273.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 - google.golang.org/grpc v1.79.3 + google.golang.org/api v0.274.0 + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 + google.golang.org/grpc v1.80.0 google.golang.org/protobuf v1.36.11 gopkg.in/ini.v1 v1.67.1 k8s.io/api v0.35.3 @@ -173,7 +174,7 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v3 v3.0.1 // indirect github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 // indirect github.com/BurntSushi/toml v1.4.0 // indirect - github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 // indirect + github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 // indirect github.com/ProtonMail/go-crypto v1.3.0 // indirect @@ -196,8 +197,8 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.21 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.21 // indirect github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 // indirect - github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 // indirect - github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/aymerick/douceur v0.2.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect @@ -325,10 +326,10 @@ require ( go.opentelemetry.io/contrib/detectors/gcp v1.39.0 // indirect go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0 // indirect go.opentelemetry.io/otel/log v0.11.0 // indirect - go.opentelemetry.io/otel/metric v1.42.0 // indirect + go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/otel/schema v0.0.12 // indirect - go.opentelemetry.io/otel/sdk/metric v1.42.0 // indirect - go.opentelemetry.io/proto/otlp v1.9.0 // indirect + go.opentelemetry.io/otel/sdk/metric v1.43.0 // indirect + go.opentelemetry.io/proto/otlp v1.10.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/crypto v0.49.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect @@ -341,7 +342,7 @@ require ( golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/appengine v1.6.8 // indirect google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 3933647b..733b65fd 100644 --- a/go.sum +++ b/go.sum @@ -18,66 +18,66 @@ charm.land/lipgloss/v2 v2.0.2 h1:xFolbF8JdpNkM2cEPTfXEcW1p6NRzOWTSamRfYEw8cs= charm.land/lipgloss/v2 v2.0.2/go.mod h1:KjPle2Qd3YmvP1KL5OMHiHysGcNwq6u83MUjYkFvEkM= cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE= cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU= -cloud.google.com/go/aiplatform v1.121.0 h1:8y8sNfVAW1DVhFbSbI7d8rrqBGGJFk6EoV6atidlyQc= -cloud.google.com/go/aiplatform v1.121.0/go.mod h1:juMdDWeNphHV40KhWdN+563zNCOKNmLJjk5D2TA43ls= +cloud.google.com/go/aiplatform v1.122.0 h1:l9euMxde3/yL/txA1ngCM6zyTSXbjmos/OtUXDAeqhw= +cloud.google.com/go/aiplatform v1.122.0/go.mod h1:Z0QhICmtYNxf0fbl/axHrVXCGcfdYuedH5NK+GwVu1I= cloud.google.com/go/auth v0.19.0 h1:DGYwtbcsGsT1ywuxsIoWi1u/vlks0moIblQHgSDgQkQ= cloud.google.com/go/auth v0.19.0/go.mod h1:2Aph7BT2KnaSFOM0JDPyiYgNh6PL9vGMiP8CUIXZ+IY= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= -cloud.google.com/go/bigquery v1.74.0 h1:Q6bAMv+eyvufOpIrfrYxhM46qq1D3ZQTdgUDQqKS+n8= -cloud.google.com/go/bigquery v1.74.0/go.mod h1:iViO7Cx3A/cRKcHNRsHB3yqGAMInFBswrE9Pxazsc90= -cloud.google.com/go/bigtable v1.43.0 h1:ysbiG+AZElMELOKDrTkHF2N9xPzj5Dl7tmPr92/FUiQ= -cloud.google.com/go/bigtable v1.43.0/go.mod h1:uH4AQlKpsGMZvicLU+RJs263UiTQhQKzXFmANPjOSUk= -cloud.google.com/go/certificatemanager v1.9.6 h1:v5X8X+THKrS9OFZb6k0GRDP1WQxLXTdMko7OInBliw4= -cloud.google.com/go/certificatemanager v1.9.6/go.mod h1:vWogV874jKZkSRDFCMM3r7wqybv8WXs3XhyNff6o/Zo= -cloud.google.com/go/compute v1.57.0 h1:uACoYJCUftJxxoI7si8u1S9szRDalftrWSjo1Dizfx4= -cloud.google.com/go/compute v1.57.0/go.mod h1:3shEe5By6FSIqBbZJBuqC0InvJKBKUiWZjrwGd1wkyA= +cloud.google.com/go/bigquery v1.75.0 h1:gI4AgIhXNZ8hxvPDOp4hLGUnpNBjoBor6POSLcrdWkY= +cloud.google.com/go/bigquery v1.75.0/go.mod h1:zNCHWok+hfTgKCwNqT+V7GH/YmFFgZqjzljKCZBJTWc= +cloud.google.com/go/bigtable v1.45.0 h1:OWl6kq6Ju8m4fZkV3bve7n9882W4HrlXPqtW5zT5ttA= +cloud.google.com/go/bigtable v1.45.0/go.mod h1:Ztklmotutn5zkAYzsn2w8ye8wvy+azwyGwYmujW5JHg= +cloud.google.com/go/certificatemanager v1.10.0 h1:3cyEmZs+sK+dXsO9R04ZprBInovTVpMz0DfFidiwY4Y= +cloud.google.com/go/certificatemanager v1.10.0/go.mod h1:dfJSJnWK2wgw2BTQ5oaAalombcajsviqsA6liniLL/U= +cloud.google.com/go/compute v1.58.0 h1:WQamD8d9Vxu+m5IJBxgUzqsbHxXwx/T7Ol4V5GyVCZE= +cloud.google.com/go/compute v1.58.0/go.mod h1:5/1KiFDIdFPcx/Fw6pFQpcQaBGc6MGtgSfczeLFBj7o= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -cloud.google.com/go/container v1.46.0 h1:xX94Lo3xrS5OkdMWKvpEVAbBwjN9uleVv6vOi02fL4s= -cloud.google.com/go/container v1.46.0/go.mod h1:A7gMqdQduTk46+zssWDTKbGS2z46UsJNXfKqvMI1ZO4= +cloud.google.com/go/container v1.47.0 h1:bQFpaGHEnNOv3FM8tYj3efzi/omInP/YO/kzZrLjTRU= +cloud.google.com/go/container v1.47.0/go.mod h1:ySZKwTx9fTEqo1YiWN09SrM96a/dgpxXsi1tEzWGQd0= cloud.google.com/go/datacatalog v1.26.1 h1:bCRKA8uSQN8wGW3Tw0gwko4E9a64GRmbW1nCblhgC2k= cloud.google.com/go/datacatalog v1.26.1/go.mod h1:2Qcq8vsHNxMDgjgadRFmFG47Y+uuIVsyEGUrlrKEdrg= -cloud.google.com/go/dataplex v1.29.0 h1:g1RsvpxELtGdVwmuOiktBM6BPfFy8TyNzmWvf+6yDgc= -cloud.google.com/go/dataplex v1.29.0/go.mod h1:32rAjJhxo1tY5KivJ33872X5ZqR6ZjlE5ng5Uz7+hH0= -cloud.google.com/go/dataproc/v2 v2.16.0 h1:0g2hnjlQ8SQTnNeu+Bqqa61QPssfSZF3t+9ldRmx+VQ= -cloud.google.com/go/dataproc/v2 v2.16.0/go.mod h1:HlzFg8k1SK+bJN3Zsy2z5g6OZS1D4DYiDUgJtF0gJnE= -cloud.google.com/go/eventarc v1.18.0 h1:8WWG1/ogInYur1NQjML6EMHQ0ZBzAdMDGlUVpLD56cI= -cloud.google.com/go/eventarc v1.18.0/go.mod h1:/6SDoqh5+9QNUqCX4/oQcJVK16fG/snHBSXu7lrJtO8= -cloud.google.com/go/filestore v1.10.3 h1:3KZifUVTqGhNNv6MLeONYth1HjlVM4vDhaH+xrdPljU= -cloud.google.com/go/filestore v1.10.3/go.mod h1:94ZGyLTx9j+aWKozPQ6Wbq1DuImie/L/HIdGMshtwac= -cloud.google.com/go/functions v1.19.7 h1:7LcOD18euIVGRUPaeCmgO6vfWSLNIsi6STWRQcdANG8= -cloud.google.com/go/functions v1.19.7/go.mod h1:xbcKfS7GoIcaXr2FSwmtn9NXal1JR4TV6iYZlgXffwA= -cloud.google.com/go/iam v1.6.0 h1:JiSIcEi38dWBKhB3BtfKCW+dMvCZJEhBA2BsaGJgoxs= -cloud.google.com/go/iam v1.6.0/go.mod h1:ZS6zEy7QHmcNO18mjO2viYv/n+wOUkhJqGNkPPGueGU= -cloud.google.com/go/kms v1.26.0 h1:cK9mN2cf+9V63D3H1f6koxTatWy39aTI/hCjz1I+adU= -cloud.google.com/go/kms v1.26.0/go.mod h1:pHKOdFJm63hxBsiPkYtowZPltu9dW0MWvBa6IA4HM58= -cloud.google.com/go/logging v1.13.2 h1:qqlHCBvieJT9Cdq4QqYx1KPadCQ2noD4FK02eNqHAjA= -cloud.google.com/go/logging v1.13.2/go.mod h1:zaybliM3yun1J8mU2dVQ1/qDzjbOqEijZCn6hSBtKak= +cloud.google.com/go/dataplex v1.30.0 h1:VeGEANl3ywJ7txZ79BN1BlRluzfxxyv/CbOsh2u4cVQ= +cloud.google.com/go/dataplex v1.30.0/go.mod h1:GgV6b+1viq2nMtr+AUzKNUbaR+tKGxdhVaMN8TPPu0w= +cloud.google.com/go/dataproc/v2 v2.17.0 h1:jqH6LpQaMytLb7xW6zu6GoL9v/lYWcRXqXqndgT9mXQ= +cloud.google.com/go/dataproc/v2 v2.17.0/go.mod h1:lUY58QBxs6IIScAo9ZZKSOxx3imkHBxz6dow9f4fSRM= +cloud.google.com/go/eventarc v1.19.0 h1:K5LrgI6FR1lM+YRpa3s2LSG9DxHGmShdP+ULOm90LnQ= +cloud.google.com/go/eventarc v1.19.0/go.mod h1:xO4c0cMGNC47wZPJVe7gpABvOLJF+Wt6ze2n6KM70ak= +cloud.google.com/go/filestore v1.11.0 h1:gRdVpDwzWo98WMQLLEaMBx8b6S79DH3ixD5Lz0JsNCA= +cloud.google.com/go/filestore v1.11.0/go.mod h1:SOM0F8N/VIQaI8/KLbc1evVqX8S9KQmGdLYuj9HaQns= +cloud.google.com/go/functions v1.20.0 h1:32Njh/dOxmhPclof7thv9UadyHtP+koeF8GzfrizRsY= +cloud.google.com/go/functions v1.20.0/go.mod h1:TW6jT0+yQsnI9ICkhJfz7HDFzdwtTaTOlKi1c9wQTqA= +cloud.google.com/go/iam v1.7.0 h1:JD3zh0C6LHl16aCn5Akff0+GELdp1+4hmh6ndoFLl8U= +cloud.google.com/go/iam v1.7.0/go.mod h1:tetWZW1PD/m6vcuY2Zj/aU0eCHNPuxedbnbRTyKXvdY= +cloud.google.com/go/kms v1.27.0 h1:iYYgoD0HJIqz35A+He1G0dS5qTQzQsDXFsyXwzkUCXM= +cloud.google.com/go/kms v1.27.0/go.mod h1:KPxrdf61iYEOZ86uPwR86muBpSik2y4Ion6e83fVl1Q= +cloud.google.com/go/logging v1.14.0 h1:xpPpY8cVT6n9DgIRgrWyE+YEsGlO/994pWnbc7o5Eh4= +cloud.google.com/go/logging v1.14.0/go.mod h1:jmI+Try/fZeOTOAer3wVYOuPf9WX9PyzhlSDoBAi4HM= cloud.google.com/go/longrunning v0.8.0 h1:LiKK77J3bx5gDLi4SMViHixjD2ohlkwBi+mKA7EhfW8= cloud.google.com/go/longrunning v0.8.0/go.mod h1:UmErU2Onzi+fKDg2gR7dusz11Pe26aknR4kHmJJqIfk= -cloud.google.com/go/monitoring v1.24.3 h1:dde+gMNc0UhPZD1Azu6at2e79bfdztVDS5lvhOdsgaE= -cloud.google.com/go/monitoring v1.24.3/go.mod h1:nYP6W0tm3N9H/bOw8am7t62YTzZY+zUeQ+Bi6+2eonI= -cloud.google.com/go/networksecurity v0.11.0 h1:+ahtCqEqwHw3a3UIeG21vT817xt9kkDDAO6k9+LCc18= -cloud.google.com/go/networksecurity v0.11.0/go.mod h1:JLgDsg4tOyJ3eMO8lypjqMftbfd60SJ+P7T+DUmWBsM= -cloud.google.com/go/orgpolicy v1.15.1 h1:0hq12wxNwcfUMojr5j3EjWECSInIuyYDhkAWXTomRhc= -cloud.google.com/go/orgpolicy v1.15.1/go.mod h1:bpvi9YIyU7wCW9WiXL/ZKT7pd2Ovegyr2xENIeRX5q0= -cloud.google.com/go/redis v1.18.3 h1:6LI8zSt+vmE3WQ7hE5GsJ13CbJBLV1qUw6B7CY31Wcw= -cloud.google.com/go/redis v1.18.3/go.mod h1:x8HtXZbvMBDNT6hMHaQ022Pos5d7SP7YsUH8fCJ2Wm4= -cloud.google.com/go/resourcemanager v1.10.7 h1:oPZKIdjyVTuag+D4HF7HO0mnSqcqgjcuA18xblwA0V0= -cloud.google.com/go/resourcemanager v1.10.7/go.mod h1:rScGkr6j2eFwxAjctvOP/8sqnEpDbQ9r5CKwKfomqjs= -cloud.google.com/go/run v1.16.0 h1:dPkx5oS81AC/ly4TSpRr3AYcMushvFrl8lR7jnQjzdk= -cloud.google.com/go/run v1.16.0/go.mod h1:ydUU2MjfZ64kWfzy8+GKVqXmCxMS+Ik61VQx8/FwUyY= -cloud.google.com/go/secretmanager v1.16.0 h1:19QT7ZsLJ8FSP1k+4esQvuCD7npMJml6hYzilxVyT+k= -cloud.google.com/go/secretmanager v1.16.0/go.mod h1://C/e4I8D26SDTz1f3TQcddhcmiC3rMEl0S1Cakvs3Q= -cloud.google.com/go/securitycentermanagement v1.1.6 h1:XFqjKq4ZpKTj8xCXWs/mTmh/UMWDiV25iCOUd9xaGWI= -cloud.google.com/go/securitycentermanagement v1.1.6/go.mod h1:nt5Z6rU4s2/j8R/EQxG5K7OfVAfAfwo89j0Nx2Srzaw= +cloud.google.com/go/monitoring v1.25.0 h1:HnsTIOxTN6BCSkt1P/Im23r1m7MHTTpmSYCzPkW7NK4= +cloud.google.com/go/monitoring v1.25.0/go.mod h1:wlj6rX+JGyusw/8+2duW4cJ6kmDHGmde3zMTJuG3Jpc= +cloud.google.com/go/networksecurity v0.12.0 h1:jClw7eryNWLBDlkZpo32iLt1a9XqTUsmOOu0ZxoQpkw= +cloud.google.com/go/networksecurity v0.12.0/go.mod h1:/xkjQdYYGHZGlpE+kF43RMJFx2Ak5FcUXEYXkBzuWHE= +cloud.google.com/go/orgpolicy v1.16.0 h1:MdF2ebs6H4CoxgrtPQ7XDvE1ACIeSQDY2a1/QwVgqvA= +cloud.google.com/go/orgpolicy v1.16.0/go.mod h1:l88jhhDM/KYN0FJF87m5Gub1qcOFkN94efA2uhy+U0s= +cloud.google.com/go/redis v1.19.0 h1:2ZoK/yOVAHFuRduQQR931BccLw1jd5CvjdDAI+06QTc= +cloud.google.com/go/redis v1.19.0/go.mod h1:vdL4FIUmY5DFGORAb/iBrpJWoWxygEpT9Ls8wK+8bsk= +cloud.google.com/go/resourcemanager v1.11.0 h1:9SJ/sfpUnxchKOthw4p7sXycLPjkWh2PaHHOekoMvs8= +cloud.google.com/go/resourcemanager v1.11.0/go.mod h1:jBdKCDtskipgmN1BC9wJnBzlGG7HpVIQ60cT9LEMgeQ= +cloud.google.com/go/run v1.17.0 h1:Ffdtcg+qjKmZFDuRwGOkgmd1Bkt2AHiMx4wSIxZYrtM= +cloud.google.com/go/run v1.17.0/go.mod h1:24NIu/ueXZFL3OyKJfkhp1j3VJDAEukuclpBSiL9pa0= +cloud.google.com/go/secretmanager v1.17.0 h1:rji2m9dikfOxUvYxgJ5XpSvDtwqjouqKFAPp4Hgfyto= +cloud.google.com/go/secretmanager v1.17.0/go.mod h1:ojzpR7KA2il9qcmBYaysgHsclj8nMcCL/Hc+WYxUsGA= +cloud.google.com/go/securitycentermanagement v1.2.0 h1:PEPs1pnL78b9E2Ijm9CyF9siiF8p4t9xYUSnipD4yBc= +cloud.google.com/go/securitycentermanagement v1.2.0/go.mod h1:BfW7vy7ZU3gQe+JRohYoPNCxfvdrXWXqfy3+LScSGbI= cloud.google.com/go/spanner v1.89.0 h1:r3h5Z5RA8JRPf3HCvA6ujNhREIMhPY+MrDL9mkY8jS0= cloud.google.com/go/spanner v1.89.0/go.mod h1:okNuxnp1wdPaVoM5M28Al2irKZLkHhZ2Z+DW6/ZJWGw= cloud.google.com/go/storage v1.61.3 h1:VS//ZfBuPGDvakfD9xyPW1RGF1Vy3BWUoVZXgW1KMOg= cloud.google.com/go/storage v1.61.3/go.mod h1:JtqK8BBB7TWv0HVGHubtUdzYYrakOQIsMLffZ2Z/HWk= -cloud.google.com/go/storagetransfer v1.13.1 h1:Sjukr1LtUt7vLTHNvGc2gaAqlXNFeDFRIRmWGrFaJlY= -cloud.google.com/go/storagetransfer v1.13.1/go.mod h1:S858w5l383ffkdqAqrAA+BC7KlhCqeNieK3sFf5Bj4Y= +cloud.google.com/go/storagetransfer v1.14.0 h1:03hl9UR63xarS3+oq+uXRdCzHzxYttsv8df5tTAXbwk= +cloud.google.com/go/storagetransfer v1.14.0/go.mod h1:AXV3Zq+8CDDYa5VlFB4xCnve/LnuHRnCKh2ZP8bVFm4= cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U= cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s= connectrpc.com/connect v1.18.1 h1:PAg7CjSAGvscaf6YZKUefjoih5Z/qYkyaTrBW8xvYPw= @@ -104,6 +104,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.2.0 h1:lpOxw github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns v1.2.0/go.mod h1:fSvRkb8d26z9dbL40Uf/OO6Vo9iExtZK3D0ulRV+8M0= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan v1.2.0 h1:8xYBtaMs3Msy1bFYTVrVFBh05JUGNMMP/v3z3x5hoIw= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan v1.2.0/go.mod h1:bXxc3uCnIUCh68pl4njcH45qUgRuR0kZfR6v06k18/A= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0 h1:PTFGRSlMKCQelWwxUyYVEUqseBJVemLyqWJjvMyt0do= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v2 v2.0.0/go.mod h1:LRr2FzBTQlONPPa5HREE5+RjSCTXl7BwOvYOaWTqCaI= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.1 h1:1kpY4qe+BGAH2ykv4baVSqyx+AY5VjXeJ15SldlU6hs= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/internal/v3 v3.1.1/go.mod h1:nT6cWpWdUt+g81yuKmjeYPUtI73Ak3yQIT4PVVsCEEQ= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault/v2 v2.0.2 h1:O2iuZYGa1nIMDk2uAFR0F7hDALVXMvz8Zwarz6itQ3E= @@ -114,6 +116,8 @@ github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.3.0 h1:L7G3d github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi v1.3.0/go.mod h1:Ms6gYEy0+A2knfKrwdatsggTXYA2+ICKug8w7STorFw= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9 v9.0.0 h1:CbHDMVJhcJSmXenq+UDWyIjumzVkZIb5pVUGzsCok5M= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9 v9.0.0/go.mod h1:raqbEXrok4aycS74XoU6p9Hne1dliAFpHLizlp+qJoM= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0 h1:4FlNvfcPu7tTvOgOzXxIbZLvwvmZq1OdhQUdIa9g2N4= +github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights v1.2.0/go.mod h1:A4nzEXwVd5pAyneR6KOvUAo72svUc5rmCzRHhAbP6lA= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5 v5.0.0 h1:S7K+MLPEYe+g9AX9dLKldBpYV03bPl7zeDaWhiNDqqs= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5 v5.0.0/go.mod h1:EHRrmrnS2Q8fB3+DE30TTk04JLqjui5ZJEF7eMVQ2/M= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns v1.3.0 h1:yzrctSl9GMIQ5lHu7jc8olOsGjWDCsBpJhWqfGa/YIM= @@ -136,8 +140,8 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgv github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/BurntSushi/toml v1.4.0 h1:kuoIxZQy2WRRk1pttg9asf+WVv6tWQuBNVmK8+nqPr0= github.com/BurntSushi/toml v1.4.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= -github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0 h1:sBEjpZlNHzK1voKq9695PJSX2o5NEXl7/OL3coiIY0c= -github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8= github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ= @@ -189,10 +193,10 @@ github.com/aws/aws-sdk-go-v2 v1.41.5 h1:dj5kopbwUsVUVFgO4Fi5BIT3t4WyqIDjGKCangnV github.com/aws/aws-sdk-go-v2 v1.41.5/go.mod h1:mwsPRE8ceUUpiTgF7QmQIJ7lgsKUPQOUl3o72QBrE1o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8 h1:eBMB84YGghSocM7PsjmmPffTa+1FBUeNvGvFou6V/4o= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.8/go.mod h1:lyw7GFp3qENLh7kwzf7iMzAxDn+NzjXEAGjKS2UOKqI= -github.com/aws/aws-sdk-go-v2/config v1.32.13 h1:5KgbxMaS2coSWRrx9TX/QtWbqzgQkOdEa3sZPhBhCSg= -github.com/aws/aws-sdk-go-v2/config v1.32.13/go.mod h1:8zz7wedqtCbw5e9Mi2doEwDyEgHcEE9YOJp6a8jdSMY= -github.com/aws/aws-sdk-go-v2/credentials v1.19.13 h1:mA59E3fokBvyEGHKFdnpNNrvaR351cqiHgRg+JzOSRI= -github.com/aws/aws-sdk-go-v2/credentials v1.19.13/go.mod h1:yoTXOQKea18nrM69wGF9jBdG4WocSZA1h38A+t/MAsk= +github.com/aws/aws-sdk-go-v2/config v1.32.14 h1:opVIRo/ZbbI8OIqSOKmpFaY7IwfFUOCCXBsUpJOwDdI= +github.com/aws/aws-sdk-go-v2/config v1.32.14/go.mod h1:U4/V0uKxh0Tl5sxmCBZ3AecYny4UNlVmObYjKuuaiOo= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14 h1:n+UcGWAIZHkXzYt87uMFBv/l8THYELoX6gVcUvgl6fI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.14/go.mod h1:cJKuyWB59Mqi0jM3nFYQRmnHVQIcgoxjEMAbLkpr62w= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21 h1:NUS3K4BTDArQqNu2ih7yeDLaS3bmHD0YndtA6UP884g= github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.21/go.mod h1:YWNWJQNjKigKY1RHVJCuupeWDrrHjRqHm0N9rdrWzYI= github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.21 h1:Rgg6wvjjtX8bNHcvi9OnXWwcE0a2vGpbwmtICOsvcf4= @@ -205,20 +209,20 @@ github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22 h1:rWyie/PxDRIdhNf4DzRk0lvjVOq github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.22/go.mod h1:zd/JsJ4P7oGfUhXn1VyLqaRZwPmZwg44Jf2dS84Dm3Y= github.com/aws/aws-sdk-go-v2/service/apigateway v1.39.1 h1:r3dXvi6tMfv4D48pyantOgDL48ifV6Ibj1eU1ca0C3k= github.com/aws/aws-sdk-go-v2/service/apigateway v1.39.1/go.mod h1:nhYOLBwQu7P3ckR+L4gZkY0DT0nAhrQuZkI51jR1vTE= -github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.4 h1:9ytLDWrppFYTtWVVx80nefvaf/v02yG5pT+8HGk0vv8= -github.com/aws/aws-sdk-go-v2/service/autoscaling v1.64.4/go.mod h1:Lg8BJb1TOzVTJ6RFfkJ9zyI/XFcjcfZem+Iu4PeQxPE= -github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.4 h1:IvmTOyh1CZB0Gq6fUqVwmGqy8L9GApUr+cAK/Wq4oPs= -github.com/aws/aws-sdk-go-v2/service/cloudfront v1.60.4/go.mod h1:4/Vk7LHrr16Zkvy71Th2BJPNmCMPJFP91TaGcEqywWs= -github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.55.3 h1:mymqCkKEbqQIFkhh2xPAJ8jS0rmZqegQOF7bw48b0iw= -github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.55.3/go.mod h1:+bNfizG/fpRGctZuVeH8uWht/0BLD9wUyXOKM4VaCVA= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.65.0 h1:zl3sbszfDnh6nCRJVZ6bLzBD91zeExyTAEewPRyn89k= +github.com/aws/aws-sdk-go-v2/service/autoscaling v1.65.0/go.mod h1:Lg8BJb1TOzVTJ6RFfkJ9zyI/XFcjcfZem+Iu4PeQxPE= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.61.0 h1:Yx6+Np7TIPx2/j15dWnuGadv+w11ysw5KHgKpaiZsYM= +github.com/aws/aws-sdk-go-v2/service/cloudfront v1.61.0/go.mod h1:4/Vk7LHrr16Zkvy71Th2BJPNmCMPJFP91TaGcEqywWs= +github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0 h1:ud2A364lLBkhGAC7oYw/1xg9BF4acwJC+qdLykxy83o= +github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.56.0/go.mod h1:+bNfizG/fpRGctZuVeH8uWht/0BLD9wUyXOKM4VaCVA= github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.15 h1:UtMubbp/0sQ+mM8fLpsarNlAvzYOYP7BTAMaGPfaV0I= github.com/aws/aws-sdk-go-v2/service/directconnect v1.38.15/go.mod h1:NSqhUsoeEhxJxyhtfG65YZs8WJ208MDslRU+lWfTSJc= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1 h1:Vk+a1j2pXZHkkYqHmEdpwe8eX6NDtFSBGfzuauMEWYQ= github.com/aws/aws-sdk-go-v2/service/dynamodb v1.57.1/go.mod h1:wHrWCwhXZrl2PuCP5t36UTacy9fCHDJ+vw1r3qxTL5M= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.1 h1:AsKDVqIbQox9NykcAm14xUiuzAKbarnC5+PZkrB2010= -github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.1/go.mod h1:R+2BNtUfTfhPY0RH18oL02q116bakeBWjanrbnVBqkM= -github.com/aws/aws-sdk-go-v2/service/ecs v1.74.1 h1:O0hhTSsxp24mIjaNUaZ0zST98SZojDluj/Zh4RkYss4= -github.com/aws/aws-sdk-go-v2/service/ecs v1.74.1/go.mod h1:QkWmubOYmjj3cHn7A4CoUU7BKJhVeo39Gp6NH7IyhZw= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2 h1:Ytu50ChAxCiDsOlBcBq8jbczXy6+QLb07T65DBJASRs= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.296.2/go.mod h1:R+2BNtUfTfhPY0RH18oL02q116bakeBWjanrbnVBqkM= +github.com/aws/aws-sdk-go-v2/service/ecs v1.76.0 h1:a5G/TgJNrpuCjZBTf8/PTN0C2B0do/ylaYVynxPSbUQ= +github.com/aws/aws-sdk-go-v2/service/ecs v1.76.0/go.mod h1:QkWmubOYmjj3cHn7A4CoUU7BKJhVeo39Gp6NH7IyhZw= github.com/aws/aws-sdk-go-v2/service/efs v1.41.14 h1:Ql2FayQK0PspATQ7DETibPMutuLn16xecUqRkT09kyM= github.com/aws/aws-sdk-go-v2/service/efs v1.41.14/go.mod h1:4qKY0MLGqCjoOY3Wvb/J/soeJN5Tlc6uo85UuoKXqlI= github.com/aws/aws-sdk-go-v2/service/eks v1.81.2 h1:6c/Jkyx1gYLiZGl6VPjApViaoPiYo7TDWXCMk/ZBq6c= @@ -251,8 +255,8 @@ github.com/aws/aws-sdk-go-v2/service/rds v1.117.1 h1:LwcVYTKHBsQPhD0evNWtHIH8+xQ github.com/aws/aws-sdk-go-v2/service/rds v1.117.1/go.mod h1:EbQarE9odk5+EEhP2Yr6NjDEhms3PU3k9/qZ2GRpOuc= github.com/aws/aws-sdk-go-v2/service/route53 v1.62.5 h1:Z+/OLsb85Kpq7TVLCspskqePaf68Tdv6GfmJP4kH6i0= github.com/aws/aws-sdk-go-v2/service/route53 v1.62.5/go.mod h1:TmxGowuBYwjmHFOsEDxaZdsQE62JJzOmtiWafTi/czg= -github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3 h1:HwxWTbTrIHm5qY+CAEur0s/figc3qwvLWsNkF4RPToo= -github.com/aws/aws-sdk-go-v2/service/s3 v1.97.3/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.98.0 h1:foqo/ocQ7WqKwy3FojGtZQJo0FR4vto9qnz9VaumbCo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.98.0/go.mod h1:uoA43SdFwacedBfSgfFSjjCvYe8aYBS7EnU5GZ/YKMM= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9 h1:QKZH0S178gCmFEgst8hN0mCX1KxLgHBKKY/CLqwP8lg= github.com/aws/aws-sdk-go-v2/service/signin v1.0.9/go.mod h1:7yuQJoT+OoH8aqIxw9vwF+8KpvLZ8AWmvmUWHsGQZvI= github.com/aws/aws-sdk-go-v2/service/sns v1.39.15 h1:rOWMUrXJPcTXnk75ja6Bxv1P+j83dPhIWjfJ2cujj34= @@ -261,14 +265,14 @@ github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25 h1:8Bv3TQ1Cob6HLlpUbAnWxeHhAkY github.com/aws/aws-sdk-go-v2/service/sqs v1.42.25/go.mod h1:eDstEbM0OEnBUnNQxIA7j74Jy61cCU1S4EMlCtdMwzs= github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4 h1:5Wg8AAAnIWM2LE/0KFGqllZff96bm4dBs+uerYFfReE= github.com/aws/aws-sdk-go-v2/service/ssm v1.68.4/go.mod h1:nph0ypDLWm9D9iA9zOX39W/N+A4GqwzlxA13jzXVD4k= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.14 h1:GcLE9ba5ehAQma6wlopUesYg/hbcOhFNWTjELkiWkh4= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.14/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18 h1:mP49nTpfKtpXLt5SLn8Uv8z6W+03jYVoOSAl/c02nog= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.18/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15 h1:lFd1+ZSEYJZYvv9d6kXzhkZu07si3f+GQ1AaYwa2LUM= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.15/go.mod h1:WSvS1NLr7JaPunCXqpJnWk1Bjo7IxzZXrZi1QQCkuqM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19 h1:dzztQ1YmfPrxdrOiuZRMF6fuOwWlWpD2StNLTceKpys= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.19/go.mod h1:YO8TrYtFdl5w/4vmjL8zaBSsiNp3w0L1FfKVKenZT7w= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10 h1:p8ogvvLugcR/zLBXTXrTkj0RYBUdErbMnAFFp12Lm/U= github.com/aws/aws-sdk-go-v2/service/sts v1.41.10/go.mod h1:60dv0eZJfeVXfbT1tFJinbHrDfSJ2GZl4Q//OSSNAVw= -github.com/aws/smithy-go v1.24.2 h1:FzA3bu/nt/vDvmnkg+R8Xl46gmzEDam6mZ1hzmwXFng= -github.com/aws/smithy-go v1.24.2/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= +github.com/aws/smithy-go v1.24.3 h1:XgOAaUgx+HhVBoP4v8n6HCQoTRDhoMghKqw4LNHsDNg= +github.com/aws/smithy-go v1.24.3/go.mod h1:YE2RhdIuDbA5E5bTdciG9KrW3+TiEONeUWCqxX9i1Fc= github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k= github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8= github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o= @@ -364,8 +368,8 @@ github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UN github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= github.com/go-git/go-git/v5 v5.16.5 h1:mdkuqblwr57kVfXri5TTH+nMFLNUxIj9Z7F5ykFbw5s= github.com/go-git/go-git/v5 v5.16.5/go.mod h1:QOMLpNf1qxuSY4StA/ArOdfFR2TrKEjJiye2kel2m+M= -github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= -github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-jose/go-jose/v4 v4.1.4 h1:moDMcTHmvE6Groj34emNPLs/qtYXRVcd6S7NHbHz3kA= +github.com/go-jose/go-jose/v4 v4.1.4/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= @@ -421,8 +425,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= -github.com/googleapis/gax-go/v2 v2.20.0 h1:NIKVuLhDlIV74muWlsMM4CcQZqN6JJ20Qcxd9YMuYcs= -github.com/googleapis/gax-go/v2 v2.20.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= +github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= +github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= @@ -494,8 +498,8 @@ github.com/jarcoal/httpmock v1.3.0 h1:2RJ8GP0IIaWwcC9Fp2BmVi8Kog3v2Hn7VXM3fTd+nu github.com/jarcoal/httpmock v1.3.0/go.mod h1:3yb8rc4BI7TCBhFY8ng0gjuLKJNquuDNiPaZjnENuYg= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= -github.com/jedib0t/go-pretty/v6 v6.7.8 h1:BVYrDy5DPBA3Qn9ICT+PokP9cvCv1KaHv2i+Hc8sr5o= -github.com/jedib0t/go-pretty/v6 v6.7.8/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyqczZk+U6BDALU= +github.com/jedib0t/go-pretty/v6 v6.7.9 h1:frarzQWmkZd97syT81+TH8INKPpzoxQnk+Mk5EIHSrM= +github.com/jedib0t/go-pretty/v6 v6.7.9/go.mod h1:YwC5CE4fJ1HFUDeivSV1r//AmANFHyqczZk+U6BDALU= github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= github.com/jhump/protoreflect v1.17.0/go.mod h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -748,30 +752,30 @@ go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.6 go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.63.0/go.mod h1:fvPi2qXDqFs8M4B4fmJhE92TyQs9Ydjlg3RvfUp+NbQ= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= -go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= -go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0 h1:THuZiwpQZuHPul65w4WcwEnkX2QIuMT+UFoOrygtoJw= -go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.42.0/go.mod h1:J2pvYM5NGHofZ2/Ru6zw/TNWnEQp5crgyDeSrYpXkAw= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0 h1:uLXP+3mghfMf7XmV4PkGfFhFKuNWoCvvx5wP/wOXo0o= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.42.0/go.mod h1:v0Tj04armyT59mnURNUJf7RCKcKzq+lgJs6QSjHjaTc= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0 h1:88Y4s2C8oTui1LGM6bTWkw0ICGcOLCAI5l6zsD1j20k= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.43.0/go.mod h1:Vl1/iaggsuRlrHf/hfPJPvVag77kKyvrLeD10kpMl+A= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0 h1:3iZJKlCZufyRzPzlQhUIWVmfltrXuGyfjREgGP3UUjc= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.43.0/go.mod h1:/G+nUPfhq2e+qiXMGxMwumDrP5jtzU+mWN7/sjT2rak= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0 h1:ZrPRak/kS4xI3AVXy8F7pipuDXmDsrO8Lg+yQjBLjw0= go.opentelemetry.io/otel/exporters/stdout/stdoutmetric v1.40.0/go.mod h1:3y6kQCWztq6hyW8Z9YxQDDm0Je9AJoFar2G0yDcmhRk= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0 h1:s/1iRkCKDfhlh1JF26knRneorus8aOwVIDhvYx9WoDw= -go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.42.0/go.mod h1:UI3wi0FXg1Pofb8ZBiBLhtMzgoTm1TYkMvn71fAqDzs= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0 h1:mS47AX77OtFfKG4vtp+84kuGSFZHTyxtXIN269vChY0= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.43.0/go.mod h1:PJnsC41lAGncJlPUniSwM81gc80GkgWJWr3cu2nKEtU= go.opentelemetry.io/otel/log v0.11.0 h1:c24Hrlk5WJ8JWcwbQxdBqxZdOK7PcP/LFtOtwpDTe3Y= go.opentelemetry.io/otel/log v0.11.0/go.mod h1:U/sxQ83FPmT29trrifhQg+Zj2lo1/IPN1PF6RTFqdwc= -go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= -go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= go.opentelemetry.io/otel/schema v0.0.12 h1:X8NKrwH07Oe9SJruY/D1XmwHrb6D2+qrLs2POlZX7F4= go.opentelemetry.io/otel/schema v0.0.12/go.mod h1:+w+Q7DdGfykSNi+UU9GAQz5/rtYND6FkBJUWUXzZb0M= -go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= -go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= -go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= -go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= -go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= -go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= -go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= -go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= @@ -858,19 +862,19 @@ golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhS golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= -google.golang.org/api v0.273.0 h1:r/Bcv36Xa/te1ugaN1kdJ5LoA5Wj/cL+a4gj6FiPBjQ= -google.golang.org/api v0.273.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew= +google.golang.org/api v0.274.0 h1:aYhycS5QQCwxHLwfEHRRLf9yNsfvp1JadKKWBE54RFA= +google.golang.org/api v0.274.0/go.mod h1:JbAt7mF+XVmWu6xNP8/+CTiGH30ofmCmk9nM8d8fHew= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= -google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI= -google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7 h1:ndE4FoJqsIceKP2oYSnUZqhTdYufCYYkqwtFzfrhI7w= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= -google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= diff --git a/go/audit/main.go b/go/audit/main.go new file mode 100644 index 00000000..fba76862 --- /dev/null +++ b/go/audit/main.go @@ -0,0 +1,132 @@ +package audit + +import ( + "bufio" + "context" + "errors" + "net" + "net/http" + + log "github.com/sirupsen/logrus" +) + +type contextKey struct{} + +// AuditData holds identity fields populated by auth middleware for +// post-request audit logging. The audit middleware places a mutable +// *AuditData in the request context before calling inner handlers; +// auth fills it after token validation so the log emitted after +// the response contains the correct identity. +type AuditData struct { + Subject string + AccountName string + Scopes string +} + +// AuditDataFromContext returns the AuditData pointer placed in context +// by the audit middleware. Returns nil when called outside the chain. +func AuditDataFromContext(ctx context.Context) *AuditData { + ad, _ := ctx.Value(contextKey{}).(*AuditData) + return ad +} + +// Option configures the audit middleware. +type Option func(*auditConfig) + +type auditConfig struct { + excludePaths map[string]bool +} + +// WithExcludePaths skips audit logging for the given exact request +// paths (e.g. "/healthz"). +func WithExcludePaths(paths ...string) Option { + return func(c *auditConfig) { + for _, p := range paths { + c.excludePaths[p] = true + } + } +} + +// statusRecorder wraps http.ResponseWriter to capture the status code. +type statusRecorder struct { + http.ResponseWriter + status int + wroteHeader bool +} + +func (sr *statusRecorder) WriteHeader(code int) { + if !sr.wroteHeader { + sr.status = code + sr.wroteHeader = true + } + sr.ResponseWriter.WriteHeader(code) +} + +func (sr *statusRecorder) Write(b []byte) (int, error) { + if !sr.wroteHeader { + sr.WriteHeader(http.StatusOK) + } + return sr.ResponseWriter.Write(b) +} + +// Unwrap returns the underlying ResponseWriter, preserving optional +// interfaces (Flusher, Hijacker, etc.) for http.ResponseController. +func (sr *statusRecorder) Unwrap() http.ResponseWriter { + return sr.ResponseWriter +} + +// Hijack implements http.Hijacker by delegating to the underlying +// ResponseWriter. This is required for WebSocket upgrade handshakes +// which do direct type assertions on the writer. +func (sr *statusRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if h, ok := sr.ResponseWriter.(http.Hijacker); ok { + return h.Hijack() + } + return nil, nil, errors.New("underlying ResponseWriter does not support hijacking") +} + +// Flush implements http.Flusher by delegating to the underlying +// ResponseWriter. This is needed for streaming responses (SSE, etc.). +func (sr *statusRecorder) Flush() { + if f, ok := sr.ResponseWriter.(http.Flusher); ok { + f.Flush() + } +} + +// NewAuditMiddleware returns middleware that emits a structured audit +// log entry after each request completes. Identity fields (sub, account, +// scopes) are populated by auth middleware via [AuditDataFromContext]. +// +// The middleware must wrap the handler chain from outside otelhttp so +// that audit logs are not exported to the tracing backend. +func NewAuditMiddleware(logger *log.Logger, opts ...Option) func(next http.Handler) http.Handler { + cfg := &auditConfig{excludePaths: make(map[string]bool)} + for _, o := range opts { + o(cfg) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cfg.excludePaths[r.URL.Path] { + next.ServeHTTP(w, r) + return + } + + ad := &AuditData{} + ctx := context.WithValue(r.Context(), contextKey{}, ad) + + rec := &statusRecorder{ResponseWriter: w, status: http.StatusOK} + next.ServeHTTP(rec, r.WithContext(ctx)) + + logger.WithContext(ctx). + WithField("method", r.Method). + WithField("url", r.URL.String()). + WithField("status", rec.status). + WithField("sub", ad.Subject). + WithField("account", ad.AccountName). + WithField("ovm.audit", true). + WithField("scopes", ad.Scopes). + Info("audit") + }) + } +} diff --git a/go/audit/main_test.go b/go/audit/main_test.go new file mode 100644 index 00000000..b83bf9a2 --- /dev/null +++ b/go/audit/main_test.go @@ -0,0 +1,264 @@ +package audit + +import ( + "bufio" + "bytes" + "encoding/json" + "net" + "net/http" + "net/http/httptest" + "testing" + + log "github.com/sirupsen/logrus" +) + +func TestAuditMiddleware_AuthenticatedRequest(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if ad := AuditDataFromContext(r.Context()); ad != nil { + ad.Subject = "auth0|user123" + ad.AccountName = "acme-corp" + ad.Scopes = "read:items write:items" + } + w.WriteHeader(http.StatusOK) + }) + + mw := NewAuditMiddleware(testLogger) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/items", nil) + mw(inner).ServeHTTP(rec, req) + + var entry map[string]any + if err := json.Unmarshal(buf.Bytes(), &entry); err != nil { + t.Fatalf("failed to unmarshal log entry: %v", err) + } + if entry["method"] != "GET" { + t.Errorf("expected method GET, got %q", entry["method"]) + } + if entry["url"] != "/api/items" { + t.Errorf("expected url /api/items, got %q", entry["url"]) + } + if entry["sub"] != "auth0|user123" { + t.Errorf("expected sub auth0|user123, got %q", entry["sub"]) + } + if entry["account"] != "acme-corp" { + t.Errorf("expected account acme-corp, got %q", entry["account"]) + } + if entry["scopes"] != "read:items write:items" { + t.Errorf("expected scopes 'read:items write:items', got %q", entry["scopes"]) + } + if entry["ovm.audit"] != true { + t.Errorf("expected ovm.audit true, got %v", entry["ovm.audit"]) + } + if entry["status"] != float64(http.StatusOK) { + t.Errorf("expected status 200, got %v", entry["status"]) + } +} + +func TestAuditMiddleware_UnauthenticatedRequest(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }) + + mw := NewAuditMiddleware(testLogger) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/secret", nil) + mw(inner).ServeHTTP(rec, req) + + var entry map[string]any + if err := json.Unmarshal(buf.Bytes(), &entry); err != nil { + t.Fatalf("failed to unmarshal log entry: %v", err) + } + if entry["sub"] != "" { + t.Errorf("expected empty sub for unauthenticated request, got %q", entry["sub"]) + } + if entry["account"] != "" { + t.Errorf("expected empty account for unauthenticated request, got %q", entry["account"]) + } + if entry["status"] != float64(http.StatusUnauthorized) { + t.Errorf("expected status 401, got %v", entry["status"]) + } +} + +func TestAuditMiddleware_ExcludedPath(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + called := false + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + }) + + mw := NewAuditMiddleware(testLogger, WithExcludePaths("/healthz")) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/healthz", nil) + mw(inner).ServeHTTP(rec, req) + + if !called { + t.Error("inner handler was not called for excluded path") + } + if buf.Len() > 0 { + t.Errorf("expected no audit log for excluded path, got: %s", buf.String()) + } +} + +func TestAuditMiddleware_NonExcludedPathStillLogged(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + mw := NewAuditMiddleware(testLogger, WithExcludePaths("/healthz")) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodPost, "/api/changes", nil) + mw(inner).ServeHTTP(rec, req) + + if buf.Len() == 0 { + t.Error("expected audit log for non-excluded path") + } +} + +func TestAuditMiddleware_CapturesStatusCode(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + }) + + mw := NewAuditMiddleware(testLogger) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodDelete, "/api/admin/user", nil) + mw(inner).ServeHTTP(rec, req) + + var entry map[string]any + if err := json.Unmarshal(buf.Bytes(), &entry); err != nil { + t.Fatalf("failed to unmarshal log entry: %v", err) + } + if entry["status"] != float64(http.StatusForbidden) { + t.Errorf("expected status 403, got %v", entry["status"]) + } + if entry["method"] != "DELETE" { + t.Errorf("expected method DELETE, got %q", entry["method"]) + } +} + +func TestAuditMiddleware_DefaultStatusIs200(t *testing.T) { + var buf bytes.Buffer + testLogger := log.New() + testLogger.SetOutput(&buf) + testLogger.SetFormatter(&log.JSONFormatter{}) + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write([]byte("ok")) + }) + + mw := NewAuditMiddleware(testLogger) + rec := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/items", nil) + mw(inner).ServeHTTP(rec, req) + + var entry map[string]any + if err := json.Unmarshal(buf.Bytes(), &entry); err != nil { + t.Fatalf("failed to unmarshal log entry: %v", err) + } + if entry["status"] != float64(http.StatusOK) { + t.Errorf("expected status 200 when handler writes body without explicit WriteHeader, got %v", entry["status"]) + } +} + +func TestAuditDataFromContext_NilOutsideMiddleware(t *testing.T) { + if ad := AuditDataFromContext(t.Context()); ad != nil { + t.Error("expected nil AuditData outside audit middleware chain") + } +} + +func TestStatusRecorder_Hijack(t *testing.T) { + hijacked := false + mock := &mockHijackWriter{ + ResponseWriter: httptest.NewRecorder(), + hijackFunc: func() (net.Conn, *bufio.ReadWriter, error) { + hijacked = true + return nil, nil, nil + }, + } + + var w http.ResponseWriter = &statusRecorder{ResponseWriter: mock, status: http.StatusOK} + + h, ok := w.(http.Hijacker) + if !ok { + t.Fatal("statusRecorder should implement http.Hijacker") + } + + _, _, err := h.Hijack() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !hijacked { + t.Error("expected Hijack to be delegated to underlying writer") + } +} + +func TestStatusRecorder_HijackNotSupported(t *testing.T) { + var w http.ResponseWriter = &statusRecorder{ResponseWriter: httptest.NewRecorder(), status: http.StatusOK} + + _, _, err := w.(http.Hijacker).Hijack() + if err == nil { + t.Error("expected error when underlying writer doesn't support Hijack") + } +} + +func TestStatusRecorder_Flush(t *testing.T) { + flushed := false + mock := &mockFlushWriter{ + ResponseWriter: httptest.NewRecorder(), + flushFunc: func() { flushed = true }, + } + + var w http.ResponseWriter = &statusRecorder{ResponseWriter: mock, status: http.StatusOK} + + f, ok := w.(http.Flusher) + if !ok { + t.Fatal("statusRecorder should implement http.Flusher") + } + + f.Flush() + if !flushed { + t.Error("expected Flush to be delegated to underlying writer") + } +} + +type mockHijackWriter struct { + http.ResponseWriter + hijackFunc func() (net.Conn, *bufio.ReadWriter, error) +} + +func (m *mockHijackWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return m.hijackFunc() +} + +type mockFlushWriter struct { + http.ResponseWriter + flushFunc func() +} + +func (m *mockFlushWriter) Flush() { + m.flushFunc() +} diff --git a/go/auth/auth.go b/go/auth/auth.go index 9b7e2ce2..b8732409 100644 --- a/go/auth/auth.go +++ b/go/auth/auth.go @@ -130,6 +130,10 @@ type Auth0Config struct { ClientID string ClientSecret string Audience string + // ManagementAudience is the Auth0 tenant hostname for the Management API. + // Token endpoint: https://{ManagementAudience}/oauth/token + // API audience: https://{ManagementAudience}/api/v2/ + ManagementAudience string } // ImpersonationHTTPClient creates an HTTP client that can impersonate the specified account. diff --git a/go/auth/middleware.go b/go/auth/middleware.go index f4509360..d7e8a319 100644 --- a/go/auth/middleware.go +++ b/go/auth/middleware.go @@ -15,6 +15,7 @@ import ( "github.com/auth0/go-jwt-middleware/v3/jwks" "github.com/auth0/go-jwt-middleware/v3/validator" "github.com/getsentry/sentry-go" + "github.com/overmindtech/cli/go/audit" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" @@ -192,6 +193,18 @@ func NewAuthMiddleware(config MiddlewareConfig, next http.Handler) http.Handler ctx = OverrideAuth(r.Context(), options...) } + if ad := audit.AuditDataFromContext(ctx); ad != nil { + if sub, ok := ctx.Value(CurrentSubjectContextKey{}).(string); ok { + ad.Subject = sub + } + if account, ok := ctx.Value(AccountNameContextKey{}).(string); ok { + ad.AccountName = account + } + if claims, ok := ctx.Value(CustomClaimsContextKey{}).(*CustomClaims); ok { + ad.Scopes = claims.Scope + } + } + r = r.Clone(ctx) next.ServeHTTP(w, r) @@ -218,6 +231,14 @@ func WithAccount(account string) OverrideAuthOptionFunc { }) } +// Sets the subject (typically the Auth0 user_id from the token's sub claim) +// in the context. +func WithSubject(subject string) OverrideAuthOptionFunc { + return func(ctx context.Context) context.Context { + return context.WithValue(ctx, CurrentSubjectContextKey{}, subject) + } +} + // Sets the auth info in the context directly from the validated claims produced // by the `github.com/auth0/go-jwt-middleware/v3/validator` package. This is // essentially what the middleware already does when receiving a request, and @@ -296,7 +317,7 @@ func ensureValidTokenHandler(config MiddlewareConfig, next http.Handler) http.Ha return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.SpanFromContext(r.Context()) span.SetAttributes(attribute.Bool("ovm.auth.bypass", true)) - ctx := OverrideAuth(r.Context(), WithBypassScopeCheck()) + ctx := OverrideAuth(r.Context(), WithBypassScopeCheck(), WithSubject("auth-bypass")) next.ServeHTTP(w, r.Clone(ctx)) }) } @@ -487,7 +508,7 @@ func ensureValidTokenHandler(config MiddlewareConfig, next http.Handler) http.Ha span.SetAttributes(attribute.Bool("ovm.auth.bypass", shouldBypass)) if shouldBypass { - ctx = OverrideAuth(ctx, WithBypassScopeCheck()) + ctx = OverrideAuth(ctx, WithBypassScopeCheck(), WithSubject("auth-bypass")) r = r.Clone(ctx) diff --git a/go/auth/middleware_test.go b/go/auth/middleware_test.go index 5b00928c..3ab6e19e 100644 --- a/go/auth/middleware_test.go +++ b/go/auth/middleware_test.go @@ -6,6 +6,7 @@ import ( "crypto/rsa" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "regexp" @@ -15,6 +16,7 @@ import ( "github.com/auth0/go-jwt-middleware/v3/validator" "github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4/jwt" + "github.com/overmindtech/cli/go/audit" log "github.com/sirupsen/logrus" ) @@ -448,6 +450,124 @@ func TestNewAuthMiddleware(t *testing.T) { } } +// TestBypassAuthInjectsSubject verifies the BypassAuth code path (local/dev +// environments only — never runs in production where real JWTs provide the +// subject). It ensures a synthetic "auth-bypass" subject is injected into +// CurrentSubjectContextKey so handlers like Area51 job scheduling and feature +// flags work without a JWT. +func TestBypassAuthInjectsSubject(t *testing.T) { + t.Parallel() + + bypassConfig := MiddlewareConfig{ + BypassAuth: true, + } + + var capturedSubject string + handler := NewAuthMiddleware(bypassConfig, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if subj, ok := r.Context().Value(CurrentSubjectContextKey{}).(string); ok { + capturedSubject = subj + } + w.WriteHeader(http.StatusOK) + })) + + t.Run("injects default subject", func(t *testing.T) { + capturedSubject = "" + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200, got %d", rr.Code) + } + if capturedSubject != "auth-bypass" { + t.Errorf("expected subject %q, got %q", "auth-bypass", capturedSubject) + } + }) + + t.Run("scope check is bypassed", func(t *testing.T) { + rr := httptest.NewRecorder() + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) + if err != nil { + t.Fatal(err) + } + + scopeHandler := NewAuthMiddleware(bypassConfig, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !HasAllScopes(r.Context(), "any:scope") { + w.WriteHeader(http.StatusForbidden) + return + } + w.WriteHeader(http.StatusOK) + })) + scopeHandler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Errorf("expected 200 (scope check bypassed), got %d", rr.Code) + } + }) +} + +func TestWithSubject(t *testing.T) { + t.Parallel() + + t.Run("sets subject in context", func(t *testing.T) { + ctx := OverrideAuth(context.Background(), WithSubject("auth0|user-123")) + + subject, ok := ctx.Value(CurrentSubjectContextKey{}).(string) + if !ok { + t.Fatal("expected CurrentSubjectContextKey to be set") + } + if subject != "auth0|user-123" { + t.Errorf("expected subject %q, got %q", "auth0|user-123", subject) + } + }) + + t.Run("last WithSubject wins", func(t *testing.T) { + ctx := OverrideAuth(context.Background(), + WithSubject("first"), + WithSubject("second"), + ) + + subject, ok := ctx.Value(CurrentSubjectContextKey{}).(string) + if !ok { + t.Fatal("expected CurrentSubjectContextKey to be set") + } + if subject != "second" { + t.Errorf("expected subject %q, got %q", "second", subject) + } + }) + + t.Run("composes with other options", func(t *testing.T) { + ctx := OverrideAuth(context.Background(), + WithScope("api:read"), + WithAccount("test-account"), + WithSubject("auth0|user-456"), + ) + + subject, ok := ctx.Value(CurrentSubjectContextKey{}).(string) + if !ok { + t.Fatal("expected CurrentSubjectContextKey to be set") + } + if subject != "auth0|user-456" { + t.Errorf("expected subject %q, got %q", "auth0|user-456", subject) + } + + accountName, err := ExtractAccount(ctx) + if err != nil { + t.Fatal(err) + } + if accountName != "test-account" { + t.Errorf("expected account %q, got %q", "test-account", accountName) + } + + if !HasAllScopes(ctx, "api:read") { + t.Error("expected api:read scope to be present") + } + }) +} + func TestOverrideAuth(t *testing.T) { tests := []struct { Name string @@ -553,7 +673,6 @@ func BenchmarkAuthMiddleware(b *testing.B) { // Create a request to pass to our handler. We don't have any query parameters for now, so we'll // pass 'nil' as the third parameter. req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/", nil) - if err != nil { b.Fatal(err) } @@ -595,7 +714,6 @@ func NewTestJWTServer() (*TestJWTServer, error) { Key: jwk, } signer, err := jose.NewSigner(signingKey, &jose.SignerOptions{}) - if err != nil { return nil, err } @@ -877,3 +995,141 @@ func TestConnectErrorHandling(t *testing.T) { }) } } + +func TestAuthMiddleware_PopulatesAuditData(t *testing.T) { + server, err := NewTestJWTServer() + if err != nil { + t.Fatal(err) + } + + jwksURL := server.Start(t.Context()) + + discardLogger := log.New() + discardLogger.SetOutput(io.Discard) + + t.Run("populates audit data from JWT", func(t *testing.T) { + var capturedAD *audit.AuditData + + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAD = audit.AuditDataFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + handler := audit.NewAuditMiddleware(discardLogger)( + NewAuthMiddleware(MiddlewareConfig{ + IssuerURL: jwksURL, + Auth0Audience: "https://api.overmind.tech", + }, inner), + ) + + token, err := server.GenerateJWT(&TestTokenOptions{ + Audience: []string{"https://api.overmind.tech"}, + Expiry: time.Now().Add(time.Hour), + CustomClaims: CustomClaims{ + AccountName: "acme-corp", + Scope: "read:items write:items", + }, + }) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + if capturedAD == nil { + t.Fatal("expected audit data to be present in context") + } + if capturedAD.Subject != "test" { + t.Errorf("expected subject 'test', got %q", capturedAD.Subject) + } + if capturedAD.AccountName != "acme-corp" { + t.Errorf("expected account 'acme-corp', got %q", capturedAD.AccountName) + } + if capturedAD.Scopes != "read:items write:items" { + t.Errorf("expected scopes 'read:items write:items', got %q", capturedAD.Scopes) + } + }) + + t.Run("populates audit data with account override", func(t *testing.T) { + var capturedAD *audit.AuditData + + override := "override-acme" + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAD = audit.AuditDataFromContext(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + handler := audit.NewAuditMiddleware(discardLogger)( + NewAuthMiddleware(MiddlewareConfig{ + IssuerURL: jwksURL, + Auth0Audience: "https://api.overmind.tech", + AccountOverride: &override, + }, inner), + ) + + token, err := server.GenerateJWT(&TestTokenOptions{ + Audience: []string{"https://api.overmind.tech"}, + Expiry: time.Now().Add(time.Hour), + CustomClaims: CustomClaims{ + AccountName: "original-acme", + Scope: "read:items", + }, + }) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200, got %d", rr.Code) + } + if capturedAD == nil { + t.Fatal("expected audit data to be present in context") + } + if capturedAD.AccountName != "override-acme" { + t.Errorf("expected overridden account 'override-acme', got %q", capturedAD.AccountName) + } + }) + + t.Run("works without audit context", func(t *testing.T) { + inner := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + }) + + handler := NewAuthMiddleware(MiddlewareConfig{ + IssuerURL: jwksURL, + Auth0Audience: "https://api.overmind.tech", + }, inner) + + token, err := server.GenerateJWT(&TestTokenOptions{ + Audience: []string{"https://api.overmind.tech"}, + Expiry: time.Now().Add(time.Hour), + CustomClaims: CustomClaims{ + AccountName: "acme-corp", + Scope: "read:items", + }, + }) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/api/test", nil) + req.Header.Set("Authorization", "Bearer "+token) + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("expected 200 (no panic without audit context), got %d", rr.Code) + } + }) +} diff --git a/go/discovery/cmd.go b/go/discovery/cmd.go index 7ad794a7..af848416 100644 --- a/go/discovery/cmd.go +++ b/go/discovery/cmd.go @@ -234,7 +234,7 @@ func (ec *EngineConfig) CreateClients() error { if ec.Unauthenticated { log.Warn("Using unauthenticated NATS as ALLOW_UNAUTHENTICATED is set") if ec.NATSOptions != nil { - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") + log.WithField("config", fmt.Sprintf("%v", MapFromEngineConfig(ec))).Info("Engine config") } return nil } @@ -277,7 +277,7 @@ func (ec *EngineConfig) CreateClients() error { } if ec.NATSOptions != nil { - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") + log.WithField("config", fmt.Sprintf("%v", MapFromEngineConfig(ec))).Info("Engine config") } return nil case sdp.SourceManaged_MANAGED: @@ -316,7 +316,7 @@ func (ec *EngineConfig) CreateClients() error { } if ec.NATSOptions != nil { - log.WithFields(MapFromEngineConfig(ec)).Info("Engine config") + log.WithField("config", fmt.Sprintf("%v", MapFromEngineConfig(ec))).Info("Engine config") } return nil } diff --git a/go/discovery/engine.go b/go/discovery/engine.go index eeef3dbe..05d79414 100644 --- a/go/discovery/engine.go +++ b/go/discovery/engine.go @@ -854,7 +854,7 @@ func (e *Engine) InitialiseAdapters(ctx context.Context, initFn func(ctx context // This checks only engine initialization (NATS connection, heartbeats) and does NOT check adapter-specific health. func (e *Engine) LivenessProbeHandlerFunc() func(http.ResponseWriter, *http.Request) { return func(rw http.ResponseWriter, r *http.Request) { - ctx, span := tracing.HealthCheckTracer().Start(r.Context(), "healthcheck.liveness") + ctx, span := tracing.Tracer().Start(r.Context(), "healthcheck.liveness") defer span.End() err := e.LivenessHealthCheck(ctx) @@ -880,7 +880,7 @@ func (e *Engine) SetReadinessCheck(check func(context.Context) error) { // This checks adapter-specific health only (not engine/liveness). func (e *Engine) ReadinessProbeHandlerFunc() func(http.ResponseWriter, *http.Request) { return func(rw http.ResponseWriter, r *http.Request) { - ctx, span := tracing.HealthCheckTracer().Start(r.Context(), "healthcheck.readiness") + ctx, span := tracing.Tracer().Start(r.Context(), "healthcheck.readiness") defer span.End() err := e.ReadinessHealthCheck(ctx) diff --git a/go/discovery/heartbeat.go b/go/discovery/heartbeat.go index 0a9ab3dd..2addd4d1 100644 --- a/go/discovery/heartbeat.go +++ b/go/discovery/heartbeat.go @@ -11,7 +11,6 @@ import ( "github.com/overmindtech/cli/go/tracing" log "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" "google.golang.org/protobuf/types/known/durationpb" ) @@ -26,8 +25,8 @@ var ErrNoHealthcheckDefined = errors.New("no healthcheck defined") // to indicate that the engine is in an error state, this will be sent to the // management API and will be displayed in the UI. func (e *Engine) SendHeartbeat(ctx context.Context, customErr error) error { - // Get span from context - span := trace.SpanFromContext(ctx) + ctx, span := tracer.Start(ctx, "SendHeartbeat") + defer span.End() // Read memory stats and add them to the span memStats := tracing.ReadMemoryStats() diff --git a/go/sdp-go/changetimeline.go b/go/sdp-go/changetimeline.go index d76c5b4f..281dd7ff 100644 --- a/go/sdp-go/changetimeline.go +++ b/go/sdp-go/changetimeline.go @@ -67,11 +67,6 @@ var ( Label: "calculated_labels", Name: "Apply auto labels", } - // Tracks the application of auto tags for a change - ChangeTimelineEntryV2IDAutoTagging = ChangeTimelineEntryV2ID{ - Label: "auto_tagging", - Name: "Auto Tagging", - } // Tracks the validation of a change. This happens after the change is // complete and at time of writing is not generally available ChangeTimelineEntryV2IDChangeValidation = ChangeTimelineEntryV2ID{ @@ -124,7 +119,6 @@ var allChangeTimelineEntryV2IDs = []ChangeTimelineEntryV2ID{ ChangeTimelineEntryV2IDAnalyzedSignals, ChangeTimelineEntryV2IDCalculatedRisks, ChangeTimelineEntryV2IDCalculatedLabels, - ChangeTimelineEntryV2IDAutoTagging, ChangeTimelineEntryV2IDChangeValidation, ChangeTimelineEntryV2IDRecordObservations, ChangeTimelineEntryV2IDFormHypotheses, diff --git a/go/sdp-go/changetimeline_test.go b/go/sdp-go/changetimeline_test.go index 391d4cb4..b06a7460 100644 --- a/go/sdp-go/changetimeline_test.go +++ b/go/sdp-go/changetimeline_test.go @@ -52,11 +52,6 @@ func TestChangeTimelineEntryNameConversion(t *testing.T) { entryID: ChangeTimelineEntryV2IDCalculatedRisks, hasInProgressVariant: false, }, - { - name: "Auto Tagging (no in-progress variant)", - entryID: ChangeTimelineEntryV2IDAutoTagging, - hasInProgressVariant: false, - }, { name: "Change Validation (no in-progress variant)", entryID: ChangeTimelineEntryV2IDChangeValidation, diff --git a/go/sdp-go/config.pb.go b/go/sdp-go/config.pb.go index 79669c5a..373fe118 100644 --- a/go/sdp-go/config.pb.go +++ b/go/sdp-go/config.pb.go @@ -1154,9 +1154,12 @@ type SignalConfig struct { // Config for Github app profile, such as primary branch name GithubOrganisationProfile *GithubOrganisationProfile `protobuf:"bytes,3,opt,name=githubOrganisationProfile,proto3,oneof" json:"githubOrganisationProfile,omitempty"` // Controls the GitHub Check Run pass/fail conclusion criteria - CheckRunMode CheckRunMode `protobuf:"varint,4,opt,name=check_run_mode,json=checkRunMode,proto3,enum=config.CheckRunMode" json:"check_run_mode,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + CheckRunMode CheckRunMode `protobuf:"varint,4,opt,name=check_run_mode,json=checkRunMode,proto3,enum=config.CheckRunMode" json:"check_run_mode,omitempty"` + // Whether GitHub Check Runs are enabled for this account. + // Defaults to false (disabled). Customer must explicitly enable via settings. + CheckRunsEnabled bool `protobuf:"varint,5,opt,name=check_runs_enabled,json=checkRunsEnabled,proto3" json:"check_runs_enabled,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *SignalConfig) Reset() { @@ -1217,6 +1220,13 @@ func (x *SignalConfig) GetCheckRunMode() CheckRunMode { return CheckRunMode_CHECK_RUN_MODE_REPORT_ONLY } +func (x *SignalConfig) GetCheckRunsEnabled() bool { + if x != nil { + return x.CheckRunsEnabled + } + return false +} + type AggregationConfig struct { state protoimpl.MessageState `protogen:"open.v1"` // Alpha parameter for aggregation: controls the weighting of recent data versus older data @@ -1402,9 +1412,13 @@ type GithubAppInformation struct { RequestedAt *timestamppb.Timestamp `protobuf:"bytes,11,opt,name=requestedAt,proto3,oneof" json:"requestedAt,omitempty"` RequestedBy *string `protobuf:"bytes,12,opt,name=requestedBy,proto3,oneof" json:"requestedBy,omitempty"` // Suspended status (true when GitHub org admin has suspended the installation) - Suspended *bool `protobuf:"varint,13,opt,name=suspended,proto3,oneof" json:"suspended,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Suspended *bool `protobuf:"varint,13,opt,name=suspended,proto3,oneof" json:"suspended,omitempty"` + // Whether the installation has checks:write permission. + // Set by GetGithubAppInformation; used by the frontend to show + // the check runs section vs a permission prompt. + CanCreateChecks *bool `protobuf:"varint,14,opt,name=can_create_checks,json=canCreateChecks,proto3,oneof" json:"can_create_checks,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GithubAppInformation) Reset() { @@ -1528,6 +1542,13 @@ func (x *GithubAppInformation) GetSuspended() bool { return false } +func (x *GithubAppInformation) GetCanCreateChecks() bool { + if x != nil && x.CanCreateChecks != nil { + return *x.CanCreateChecks + } + return false +} + // this is all the information required to display the github app information type GetGithubAppInformationResponse struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1708,80 +1729,6 @@ func (x *RegenerateGithubAppProfileResponse) GetGithubOrganisationProfile() *Git return nil } -// no parameters required, we will look up the installation ID from the account ID -type DeleteGithubAppProfileAndGithubInstallationIDRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDRequest) Reset() { - *x = DeleteGithubAppProfileAndGithubInstallationIDRequest{} - mi := &file_config_proto_msgTypes[28] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DeleteGithubAppProfileAndGithubInstallationIDRequest) ProtoMessage() {} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDRequest) ProtoReflect() protoreflect.Message { - mi := &file_config_proto_msgTypes[28] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DeleteGithubAppProfileAndGithubInstallationIDRequest.ProtoReflect.Descriptor instead. -func (*DeleteGithubAppProfileAndGithubInstallationIDRequest) Descriptor() ([]byte, []int) { - return file_config_proto_rawDescGZIP(), []int{28} -} - -// status codes to indicate if the deletion was successful -type DeleteGithubAppProfileAndGithubInstallationIDResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache -} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDResponse) Reset() { - *x = DeleteGithubAppProfileAndGithubInstallationIDResponse{} - mi := &file_config_proto_msgTypes[29] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) -} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DeleteGithubAppProfileAndGithubInstallationIDResponse) ProtoMessage() {} - -func (x *DeleteGithubAppProfileAndGithubInstallationIDResponse) ProtoReflect() protoreflect.Message { - mi := &file_config_proto_msgTypes[29] - if x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DeleteGithubAppProfileAndGithubInstallationIDResponse.ProtoReflect.Descriptor instead. -func (*DeleteGithubAppProfileAndGithubInstallationIDResponse) Descriptor() ([]byte, []int) { - return file_config_proto_rawDescGZIP(), []int{29} -} - // No parameters required — the account is determined from the caller's auth context. type CreateGithubInstallURLRequest struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1791,7 +1738,7 @@ type CreateGithubInstallURLRequest struct { func (x *CreateGithubInstallURLRequest) Reset() { *x = CreateGithubInstallURLRequest{} - mi := &file_config_proto_msgTypes[30] + mi := &file_config_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1803,7 +1750,7 @@ func (x *CreateGithubInstallURLRequest) String() string { func (*CreateGithubInstallURLRequest) ProtoMessage() {} func (x *CreateGithubInstallURLRequest) ProtoReflect() protoreflect.Message { - mi := &file_config_proto_msgTypes[30] + mi := &file_config_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1816,7 +1763,7 @@ func (x *CreateGithubInstallURLRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateGithubInstallURLRequest.ProtoReflect.Descriptor instead. func (*CreateGithubInstallURLRequest) Descriptor() ([]byte, []int) { - return file_config_proto_rawDescGZIP(), []int{30} + return file_config_proto_rawDescGZIP(), []int{28} } type CreateGithubInstallURLResponse struct { @@ -1833,7 +1780,7 @@ type CreateGithubInstallURLResponse struct { func (x *CreateGithubInstallURLResponse) Reset() { *x = CreateGithubInstallURLResponse{} - mi := &file_config_proto_msgTypes[31] + mi := &file_config_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1845,7 +1792,7 @@ func (x *CreateGithubInstallURLResponse) String() string { func (*CreateGithubInstallURLResponse) ProtoMessage() {} func (x *CreateGithubInstallURLResponse) ProtoReflect() protoreflect.Message { - mi := &file_config_proto_msgTypes[31] + mi := &file_config_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1858,7 +1805,7 @@ func (x *CreateGithubInstallURLResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateGithubInstallURLResponse.ProtoReflect.Descriptor instead. func (*CreateGithubInstallURLResponse) Descriptor() ([]byte, []int) { - return file_config_proto_rawDescGZIP(), []int{31} + return file_config_proto_rawDescGZIP(), []int{29} } func (x *CreateGithubInstallURLResponse) GetInstallUrl() string { @@ -1924,12 +1871,13 @@ const file_config_proto_rawDesc = "" + "\x19UpdateSignalConfigRequest\x12,\n" + "\x06config\x18\x01 \x01(\v2\x14.config.SignalConfigR\x06config\"J\n" + "\x1aUpdateSignalConfigResponse\x12,\n" + - "\x06config\x18\x01 \x01(\v2\x14.config.SignalConfigR\x06config\"\xe9\x02\n" + + "\x06config\x18\x01 \x01(\v2\x14.config.SignalConfigR\x06config\"\x97\x03\n" + "\fSignalConfig\x12G\n" + "\x11aggregationConfig\x18\x01 \x01(\v2\x19.config.AggregationConfigR\x11aggregationConfig\x12P\n" + "\x14routineChangesConfig\x18\x02 \x01(\v2\x1c.config.RoutineChangesConfigR\x14routineChangesConfig\x12d\n" + "\x19githubOrganisationProfile\x18\x03 \x01(\v2!.config.GithubOrganisationProfileH\x00R\x19githubOrganisationProfile\x88\x01\x01\x12:\n" + - "\x0echeck_run_mode\x18\x04 \x01(\x0e2\x14.config.CheckRunModeR\fcheckRunModeB\x1c\n" + + "\x0echeck_run_mode\x18\x04 \x01(\x0e2\x14.config.CheckRunModeR\fcheckRunMode\x12,\n" + + "\x12check_runs_enabled\x18\x05 \x01(\bR\x10checkRunsEnabledB\x1c\n" + "\x1a_githubOrganisationProfile\"5\n" + "\x11AggregationConfig\x12 \n" + "\x05alpha\x18\x01 \x01(\x02B\n" + @@ -1946,7 +1894,7 @@ const file_config_proto_rawDesc = "" + "\x05WEEKS\x10\x01\x12\n" + "\n" + "\x06MONTHS\x10\x02\" \n" + - "\x1eGetGithubAppInformationRequest\"\xcb\x05\n" + + "\x1eGetGithubAppInformationRequest\"\x92\x06\n" + "\x14GithubAppInformation\x12&\n" + "\x0einstallationID\x18\x01 \x01(\x03R\x0einstallationID\x12 \n" + "\vinstalledBy\x18\x02 \x01(\tR\vinstalledBy\x12<\n" + @@ -1961,12 +1909,14 @@ const file_config_proto_rawDesc = "" + " \x01(\tH\x00R\x10requestedOrgName\x88\x01\x01\x12A\n" + "\vrequestedAt\x18\v \x01(\v2\x1a.google.protobuf.TimestampH\x01R\vrequestedAt\x88\x01\x01\x12%\n" + "\vrequestedBy\x18\f \x01(\tH\x02R\vrequestedBy\x88\x01\x01\x12!\n" + - "\tsuspended\x18\r \x01(\bH\x03R\tsuspended\x88\x01\x01B\x13\n" + + "\tsuspended\x18\r \x01(\bH\x03R\tsuspended\x88\x01\x01\x12/\n" + + "\x11can_create_checks\x18\x0e \x01(\bH\x04R\x0fcanCreateChecks\x88\x01\x01B\x13\n" + "\x11_requestedOrgNameB\x0e\n" + "\f_requestedAtB\x0e\n" + "\f_requestedByB\f\n" + "\n" + - "_suspended\"s\n" + + "_suspendedB\x14\n" + + "\x12_can_create_checks\"s\n" + "\x1fGetGithubAppInformationResponse\x12P\n" + "\x14githubAppInformation\x18\x01 \x01(\v2\x1c.config.GithubAppInformationR\x14githubAppInformation\"#\n" + "!RegenerateGithubAppProfileRequest\"\x8f\x01\n" + @@ -1974,9 +1924,7 @@ const file_config_proto_rawDesc = "" + "\x11primaryBranchName\x18\x01 \x01(\tR\x11primaryBranchName\x12D\n" + "\fhourlyScores\x18\x02 \x03(\x01B \xbaH\x1d\x92\x01\x1a\b\x18\x10\x18\"\x14\x12\x12\x19\x00\x00\x00\x00\x00\x00\x14@)\x00\x00\x00\x00\x00\x00\x14\xc0R\fhourlyScores\"\x85\x01\n" + "\"RegenerateGithubAppProfileResponse\x12_\n" + - "\x19githubOrganisationProfile\x18\x01 \x01(\v2!.config.GithubOrganisationProfileR\x19githubOrganisationProfile\"6\n" + - "4DeleteGithubAppProfileAndGithubInstallationIDRequest\"7\n" + - "5DeleteGithubAppProfileAndGithubInstallationIDResponse\"\x1f\n" + + "\x19githubOrganisationProfile\x18\x01 \x01(\v2!.config.GithubOrganisationProfileR\x19githubOrganisationProfile\"\x1f\n" + "\x1dCreateGithubInstallURLRequest\"A\n" + "\x1eCreateGithubInstallURLResponse\x12\x1f\n" + "\vinstall_url\x18\x01 \x01(\tR\n" + @@ -1984,7 +1932,7 @@ const file_config_proto_rawDesc = "" + "\fCheckRunMode\x12\x1e\n" + "\x1aCHECK_RUN_MODE_REPORT_ONLY\x10\x00\x12%\n" + "!CHECK_RUN_MODE_FAIL_HIGH_SEVERITY\x10\x01\x12 \n" + - "\x1cCHECK_RUN_MODE_FAIL_ANY_RISK\x10\x022\xc1\t\n" + + "\x1cCHECK_RUN_MODE_FAIL_ANY_RISK\x10\x022\x92\b\n" + "\x14ConfigurationService\x12U\n" + "\x10GetAccountConfig\x12\x1f.config.GetAccountConfigRequest\x1a .config.GetAccountConfigResponse\x12^\n" + "\x13UpdateAccountConfig\x12\".config.UpdateAccountConfigRequest\x1a#.config.UpdateAccountConfigResponse\x12R\n" + @@ -1995,8 +1943,7 @@ const file_config_proto_rawDesc = "" + "\x0fGetSignalConfig\x12\x1e.config.GetSignalConfigRequest\x1a\x1f.config.GetSignalConfigResponse\x12[\n" + "\x12UpdateSignalConfig\x12!.config.UpdateSignalConfigRequest\x1a\".config.UpdateSignalConfigResponse\x12j\n" + "\x17GetGithubAppInformation\x12&.config.GetGithubAppInformationRequest\x1a'.config.GetGithubAppInformationResponse\x12s\n" + - "\x1aRegenerateGithubAppProfile\x12).config.RegenerateGithubAppProfileRequest\x1a*.config.RegenerateGithubAppProfileResponse\x12\xac\x01\n" + - "-DeleteGithubAppProfileAndGithubInstallationID\x12<.config.DeleteGithubAppProfileAndGithubInstallationIDRequest\x1a=.config.DeleteGithubAppProfileAndGithubInstallationIDResponse\x12g\n" + + "\x1aRegenerateGithubAppProfile\x12).config.RegenerateGithubAppProfileRequest\x1a*.config.RegenerateGithubAppProfileResponse\x12g\n" + "\x16CreateGithubInstallURL\x12%.config.CreateGithubInstallURLRequest\x1a&.config.CreateGithubInstallURLResponseB1Z/github.com/overmindtech/workspace/go/sdp-go;sdpb\x06proto3" var ( @@ -2012,61 +1959,59 @@ func file_config_proto_rawDescGZIP() []byte { } var file_config_proto_enumTypes = make([]protoimpl.EnumInfo, 4) -var file_config_proto_msgTypes = make([]protoimpl.MessageInfo, 32) +var file_config_proto_msgTypes = make([]protoimpl.MessageInfo, 30) var file_config_proto_goTypes = []any{ - (CheckRunMode)(0), // 0: config.CheckRunMode - (AccountConfig_BlastRadiusPreset)(0), // 1: config.AccountConfig.BlastRadiusPreset - (GetHcpConfigResponse_Status)(0), // 2: config.GetHcpConfigResponse.Status - (RoutineChangesConfig_DurationUnit)(0), // 3: config.RoutineChangesConfig.DurationUnit - (*BlastRadiusConfig)(nil), // 4: config.BlastRadiusConfig - (*AccountConfig)(nil), // 5: config.AccountConfig - (*GetAccountConfigRequest)(nil), // 6: config.GetAccountConfigRequest - (*GetAccountConfigResponse)(nil), // 7: config.GetAccountConfigResponse - (*UpdateAccountConfigRequest)(nil), // 8: config.UpdateAccountConfigRequest - (*UpdateAccountConfigResponse)(nil), // 9: config.UpdateAccountConfigResponse - (*CreateHcpConfigRequest)(nil), // 10: config.CreateHcpConfigRequest - (*CreateHcpConfigResponse)(nil), // 11: config.CreateHcpConfigResponse - (*HcpConfig)(nil), // 12: config.HcpConfig - (*GetHcpConfigRequest)(nil), // 13: config.GetHcpConfigRequest - (*GetHcpConfigResponse)(nil), // 14: config.GetHcpConfigResponse - (*DeleteHcpConfigRequest)(nil), // 15: config.DeleteHcpConfigRequest - (*DeleteHcpConfigResponse)(nil), // 16: config.DeleteHcpConfigResponse - (*ReplaceHcpApiKeyRequest)(nil), // 17: config.ReplaceHcpApiKeyRequest - (*ReplaceHcpApiKeyResponse)(nil), // 18: config.ReplaceHcpApiKeyResponse - (*GetSignalConfigRequest)(nil), // 19: config.GetSignalConfigRequest - (*GetSignalConfigResponse)(nil), // 20: config.GetSignalConfigResponse - (*UpdateSignalConfigRequest)(nil), // 21: config.UpdateSignalConfigRequest - (*UpdateSignalConfigResponse)(nil), // 22: config.UpdateSignalConfigResponse - (*SignalConfig)(nil), // 23: config.SignalConfig - (*AggregationConfig)(nil), // 24: config.AggregationConfig - (*RoutineChangesConfig)(nil), // 25: config.RoutineChangesConfig - (*GetGithubAppInformationRequest)(nil), // 26: config.GetGithubAppInformationRequest - (*GithubAppInformation)(nil), // 27: config.GithubAppInformation - (*GetGithubAppInformationResponse)(nil), // 28: config.GetGithubAppInformationResponse - (*RegenerateGithubAppProfileRequest)(nil), // 29: config.RegenerateGithubAppProfileRequest - (*GithubOrganisationProfile)(nil), // 30: config.GithubOrganisationProfile - (*RegenerateGithubAppProfileResponse)(nil), // 31: config.RegenerateGithubAppProfileResponse - (*DeleteGithubAppProfileAndGithubInstallationIDRequest)(nil), // 32: config.DeleteGithubAppProfileAndGithubInstallationIDRequest - (*DeleteGithubAppProfileAndGithubInstallationIDResponse)(nil), // 33: config.DeleteGithubAppProfileAndGithubInstallationIDResponse - (*CreateGithubInstallURLRequest)(nil), // 34: config.CreateGithubInstallURLRequest - (*CreateGithubInstallURLResponse)(nil), // 35: config.CreateGithubInstallURLResponse - (*durationpb.Duration)(nil), // 36: google.protobuf.Duration - (*CreateAPIKeyResponse)(nil), // 37: apikeys.CreateAPIKeyResponse - (*timestamppb.Timestamp)(nil), // 38: google.protobuf.Timestamp + (CheckRunMode)(0), // 0: config.CheckRunMode + (AccountConfig_BlastRadiusPreset)(0), // 1: config.AccountConfig.BlastRadiusPreset + (GetHcpConfigResponse_Status)(0), // 2: config.GetHcpConfigResponse.Status + (RoutineChangesConfig_DurationUnit)(0), // 3: config.RoutineChangesConfig.DurationUnit + (*BlastRadiusConfig)(nil), // 4: config.BlastRadiusConfig + (*AccountConfig)(nil), // 5: config.AccountConfig + (*GetAccountConfigRequest)(nil), // 6: config.GetAccountConfigRequest + (*GetAccountConfigResponse)(nil), // 7: config.GetAccountConfigResponse + (*UpdateAccountConfigRequest)(nil), // 8: config.UpdateAccountConfigRequest + (*UpdateAccountConfigResponse)(nil), // 9: config.UpdateAccountConfigResponse + (*CreateHcpConfigRequest)(nil), // 10: config.CreateHcpConfigRequest + (*CreateHcpConfigResponse)(nil), // 11: config.CreateHcpConfigResponse + (*HcpConfig)(nil), // 12: config.HcpConfig + (*GetHcpConfigRequest)(nil), // 13: config.GetHcpConfigRequest + (*GetHcpConfigResponse)(nil), // 14: config.GetHcpConfigResponse + (*DeleteHcpConfigRequest)(nil), // 15: config.DeleteHcpConfigRequest + (*DeleteHcpConfigResponse)(nil), // 16: config.DeleteHcpConfigResponse + (*ReplaceHcpApiKeyRequest)(nil), // 17: config.ReplaceHcpApiKeyRequest + (*ReplaceHcpApiKeyResponse)(nil), // 18: config.ReplaceHcpApiKeyResponse + (*GetSignalConfigRequest)(nil), // 19: config.GetSignalConfigRequest + (*GetSignalConfigResponse)(nil), // 20: config.GetSignalConfigResponse + (*UpdateSignalConfigRequest)(nil), // 21: config.UpdateSignalConfigRequest + (*UpdateSignalConfigResponse)(nil), // 22: config.UpdateSignalConfigResponse + (*SignalConfig)(nil), // 23: config.SignalConfig + (*AggregationConfig)(nil), // 24: config.AggregationConfig + (*RoutineChangesConfig)(nil), // 25: config.RoutineChangesConfig + (*GetGithubAppInformationRequest)(nil), // 26: config.GetGithubAppInformationRequest + (*GithubAppInformation)(nil), // 27: config.GithubAppInformation + (*GetGithubAppInformationResponse)(nil), // 28: config.GetGithubAppInformationResponse + (*RegenerateGithubAppProfileRequest)(nil), // 29: config.RegenerateGithubAppProfileRequest + (*GithubOrganisationProfile)(nil), // 30: config.GithubOrganisationProfile + (*RegenerateGithubAppProfileResponse)(nil), // 31: config.RegenerateGithubAppProfileResponse + (*CreateGithubInstallURLRequest)(nil), // 32: config.CreateGithubInstallURLRequest + (*CreateGithubInstallURLResponse)(nil), // 33: config.CreateGithubInstallURLResponse + (*durationpb.Duration)(nil), // 34: google.protobuf.Duration + (*CreateAPIKeyResponse)(nil), // 35: apikeys.CreateAPIKeyResponse + (*timestamppb.Timestamp)(nil), // 36: google.protobuf.Timestamp } var file_config_proto_depIdxs = []int32{ - 36, // 0: config.BlastRadiusConfig.changeAnalysisTargetDuration:type_name -> google.protobuf.Duration + 34, // 0: config.BlastRadiusConfig.changeAnalysisTargetDuration:type_name -> google.protobuf.Duration 1, // 1: config.AccountConfig.blastRadiusPreset:type_name -> config.AccountConfig.BlastRadiusPreset 4, // 2: config.AccountConfig.blastRadius:type_name -> config.BlastRadiusConfig 5, // 3: config.GetAccountConfigResponse.config:type_name -> config.AccountConfig 5, // 4: config.UpdateAccountConfigRequest.config:type_name -> config.AccountConfig 5, // 5: config.UpdateAccountConfigResponse.config:type_name -> config.AccountConfig 12, // 6: config.CreateHcpConfigResponse.config:type_name -> config.HcpConfig - 37, // 7: config.CreateHcpConfigResponse.apiKey:type_name -> apikeys.CreateAPIKeyResponse + 35, // 7: config.CreateHcpConfigResponse.apiKey:type_name -> apikeys.CreateAPIKeyResponse 12, // 8: config.GetHcpConfigResponse.config:type_name -> config.HcpConfig 2, // 9: config.GetHcpConfigResponse.status:type_name -> config.GetHcpConfigResponse.Status 12, // 10: config.ReplaceHcpApiKeyResponse.config:type_name -> config.HcpConfig - 37, // 11: config.ReplaceHcpApiKeyResponse.apiKey:type_name -> apikeys.CreateAPIKeyResponse + 35, // 11: config.ReplaceHcpApiKeyResponse.apiKey:type_name -> apikeys.CreateAPIKeyResponse 23, // 12: config.GetSignalConfigResponse.config:type_name -> config.SignalConfig 23, // 13: config.UpdateSignalConfigRequest.config:type_name -> config.SignalConfig 23, // 14: config.UpdateSignalConfigResponse.config:type_name -> config.SignalConfig @@ -2076,8 +2021,8 @@ var file_config_proto_depIdxs = []int32{ 0, // 18: config.SignalConfig.check_run_mode:type_name -> config.CheckRunMode 3, // 19: config.RoutineChangesConfig.eventsPerUnit:type_name -> config.RoutineChangesConfig.DurationUnit 3, // 20: config.RoutineChangesConfig.durationUnit:type_name -> config.RoutineChangesConfig.DurationUnit - 38, // 21: config.GithubAppInformation.installedAt:type_name -> google.protobuf.Timestamp - 38, // 22: config.GithubAppInformation.requestedAt:type_name -> google.protobuf.Timestamp + 36, // 21: config.GithubAppInformation.installedAt:type_name -> google.protobuf.Timestamp + 36, // 22: config.GithubAppInformation.requestedAt:type_name -> google.protobuf.Timestamp 27, // 23: config.GetGithubAppInformationResponse.githubAppInformation:type_name -> config.GithubAppInformation 30, // 24: config.RegenerateGithubAppProfileResponse.githubOrganisationProfile:type_name -> config.GithubOrganisationProfile 6, // 25: config.ConfigurationService.GetAccountConfig:input_type -> config.GetAccountConfigRequest @@ -2090,22 +2035,20 @@ var file_config_proto_depIdxs = []int32{ 21, // 32: config.ConfigurationService.UpdateSignalConfig:input_type -> config.UpdateSignalConfigRequest 26, // 33: config.ConfigurationService.GetGithubAppInformation:input_type -> config.GetGithubAppInformationRequest 29, // 34: config.ConfigurationService.RegenerateGithubAppProfile:input_type -> config.RegenerateGithubAppProfileRequest - 32, // 35: config.ConfigurationService.DeleteGithubAppProfileAndGithubInstallationID:input_type -> config.DeleteGithubAppProfileAndGithubInstallationIDRequest - 34, // 36: config.ConfigurationService.CreateGithubInstallURL:input_type -> config.CreateGithubInstallURLRequest - 7, // 37: config.ConfigurationService.GetAccountConfig:output_type -> config.GetAccountConfigResponse - 9, // 38: config.ConfigurationService.UpdateAccountConfig:output_type -> config.UpdateAccountConfigResponse - 11, // 39: config.ConfigurationService.CreateHcpConfig:output_type -> config.CreateHcpConfigResponse - 14, // 40: config.ConfigurationService.GetHcpConfig:output_type -> config.GetHcpConfigResponse - 16, // 41: config.ConfigurationService.DeleteHcpConfig:output_type -> config.DeleteHcpConfigResponse - 18, // 42: config.ConfigurationService.ReplaceHcpApiKey:output_type -> config.ReplaceHcpApiKeyResponse - 20, // 43: config.ConfigurationService.GetSignalConfig:output_type -> config.GetSignalConfigResponse - 22, // 44: config.ConfigurationService.UpdateSignalConfig:output_type -> config.UpdateSignalConfigResponse - 28, // 45: config.ConfigurationService.GetGithubAppInformation:output_type -> config.GetGithubAppInformationResponse - 31, // 46: config.ConfigurationService.RegenerateGithubAppProfile:output_type -> config.RegenerateGithubAppProfileResponse - 33, // 47: config.ConfigurationService.DeleteGithubAppProfileAndGithubInstallationID:output_type -> config.DeleteGithubAppProfileAndGithubInstallationIDResponse - 35, // 48: config.ConfigurationService.CreateGithubInstallURL:output_type -> config.CreateGithubInstallURLResponse - 37, // [37:49] is the sub-list for method output_type - 25, // [25:37] is the sub-list for method input_type + 32, // 35: config.ConfigurationService.CreateGithubInstallURL:input_type -> config.CreateGithubInstallURLRequest + 7, // 36: config.ConfigurationService.GetAccountConfig:output_type -> config.GetAccountConfigResponse + 9, // 37: config.ConfigurationService.UpdateAccountConfig:output_type -> config.UpdateAccountConfigResponse + 11, // 38: config.ConfigurationService.CreateHcpConfig:output_type -> config.CreateHcpConfigResponse + 14, // 39: config.ConfigurationService.GetHcpConfig:output_type -> config.GetHcpConfigResponse + 16, // 40: config.ConfigurationService.DeleteHcpConfig:output_type -> config.DeleteHcpConfigResponse + 18, // 41: config.ConfigurationService.ReplaceHcpApiKey:output_type -> config.ReplaceHcpApiKeyResponse + 20, // 42: config.ConfigurationService.GetSignalConfig:output_type -> config.GetSignalConfigResponse + 22, // 43: config.ConfigurationService.UpdateSignalConfig:output_type -> config.UpdateSignalConfigResponse + 28, // 44: config.ConfigurationService.GetGithubAppInformation:output_type -> config.GetGithubAppInformationResponse + 31, // 45: config.ConfigurationService.RegenerateGithubAppProfile:output_type -> config.RegenerateGithubAppProfileResponse + 33, // 46: config.ConfigurationService.CreateGithubInstallURL:output_type -> config.CreateGithubInstallURLResponse + 36, // [36:47] is the sub-list for method output_type + 25, // [25:36] is the sub-list for method input_type 25, // [25:25] is the sub-list for extension type_name 25, // [25:25] is the sub-list for extension extendee 0, // [0:25] is the sub-list for field type_name @@ -2127,7 +2070,7 @@ func file_config_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_config_proto_rawDesc), len(file_config_proto_rawDesc)), NumEnums: 4, - NumMessages: 32, + NumMessages: 30, NumExtensions: 0, NumServices: 1, }, diff --git a/go/sdp-go/sdpconnect/config.connect.go b/go/sdp-go/sdpconnect/config.connect.go index bb758bc5..6993594b 100644 --- a/go/sdp-go/sdpconnect/config.connect.go +++ b/go/sdp-go/sdpconnect/config.connect.go @@ -63,9 +63,6 @@ const ( // ConfigurationServiceRegenerateGithubAppProfileProcedure is the fully-qualified name of the // ConfigurationService's RegenerateGithubAppProfile RPC. ConfigurationServiceRegenerateGithubAppProfileProcedure = "/config.ConfigurationService/RegenerateGithubAppProfile" - // ConfigurationServiceDeleteGithubAppProfileAndGithubInstallationIDProcedure is the fully-qualified - // name of the ConfigurationService's DeleteGithubAppProfileAndGithubInstallationID RPC. - ConfigurationServiceDeleteGithubAppProfileAndGithubInstallationIDProcedure = "/config.ConfigurationService/DeleteGithubAppProfileAndGithubInstallationID" // ConfigurationServiceCreateGithubInstallURLProcedure is the fully-qualified name of the // ConfigurationService's CreateGithubInstallURL RPC. ConfigurationServiceCreateGithubInstallURLProcedure = "/config.ConfigurationService/CreateGithubInstallURL" @@ -98,9 +95,6 @@ type ConfigurationServiceClient interface { GetGithubAppInformation(context.Context, *connect.Request[sdp_go.GetGithubAppInformationRequest]) (*connect.Response[sdp_go.GetGithubAppInformationResponse], error) // regenerate the github app profile, this information is used for signal processing RegenerateGithubAppProfile(context.Context, *connect.Request[sdp_go.RegenerateGithubAppProfileRequest]) (*connect.Response[sdp_go.RegenerateGithubAppProfileResponse], error) - // remove the github app installation id and github organisation profile from the signal config - // this will not uninstall the app from github, that must be done manually by the user - DeleteGithubAppProfileAndGithubInstallationID(context.Context, *connect.Request[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest]) (*connect.Response[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse], error) // Create a GitHub App install URL with a DB-backed state parameter for CSRF // protection. The frontend calls this RPC, then redirects the user to the // returned URL. GitHub will redirect back with the state UUID, which the @@ -179,12 +173,6 @@ func NewConfigurationServiceClient(httpClient connect.HTTPClient, baseURL string connect.WithSchema(configurationServiceMethods.ByName("RegenerateGithubAppProfile")), connect.WithClientOptions(opts...), ), - deleteGithubAppProfileAndGithubInstallationID: connect.NewClient[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest, sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse]( - httpClient, - baseURL+ConfigurationServiceDeleteGithubAppProfileAndGithubInstallationIDProcedure, - connect.WithSchema(configurationServiceMethods.ByName("DeleteGithubAppProfileAndGithubInstallationID")), - connect.WithClientOptions(opts...), - ), createGithubInstallURL: connect.NewClient[sdp_go.CreateGithubInstallURLRequest, sdp_go.CreateGithubInstallURLResponse]( httpClient, baseURL+ConfigurationServiceCreateGithubInstallURLProcedure, @@ -196,18 +184,17 @@ func NewConfigurationServiceClient(httpClient connect.HTTPClient, baseURL string // configurationServiceClient implements ConfigurationServiceClient. type configurationServiceClient struct { - getAccountConfig *connect.Client[sdp_go.GetAccountConfigRequest, sdp_go.GetAccountConfigResponse] - updateAccountConfig *connect.Client[sdp_go.UpdateAccountConfigRequest, sdp_go.UpdateAccountConfigResponse] - createHcpConfig *connect.Client[sdp_go.CreateHcpConfigRequest, sdp_go.CreateHcpConfigResponse] - getHcpConfig *connect.Client[sdp_go.GetHcpConfigRequest, sdp_go.GetHcpConfigResponse] - deleteHcpConfig *connect.Client[sdp_go.DeleteHcpConfigRequest, sdp_go.DeleteHcpConfigResponse] - replaceHcpApiKey *connect.Client[sdp_go.ReplaceHcpApiKeyRequest, sdp_go.ReplaceHcpApiKeyResponse] - getSignalConfig *connect.Client[sdp_go.GetSignalConfigRequest, sdp_go.GetSignalConfigResponse] - updateSignalConfig *connect.Client[sdp_go.UpdateSignalConfigRequest, sdp_go.UpdateSignalConfigResponse] - getGithubAppInformation *connect.Client[sdp_go.GetGithubAppInformationRequest, sdp_go.GetGithubAppInformationResponse] - regenerateGithubAppProfile *connect.Client[sdp_go.RegenerateGithubAppProfileRequest, sdp_go.RegenerateGithubAppProfileResponse] - deleteGithubAppProfileAndGithubInstallationID *connect.Client[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest, sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse] - createGithubInstallURL *connect.Client[sdp_go.CreateGithubInstallURLRequest, sdp_go.CreateGithubInstallURLResponse] + getAccountConfig *connect.Client[sdp_go.GetAccountConfigRequest, sdp_go.GetAccountConfigResponse] + updateAccountConfig *connect.Client[sdp_go.UpdateAccountConfigRequest, sdp_go.UpdateAccountConfigResponse] + createHcpConfig *connect.Client[sdp_go.CreateHcpConfigRequest, sdp_go.CreateHcpConfigResponse] + getHcpConfig *connect.Client[sdp_go.GetHcpConfigRequest, sdp_go.GetHcpConfigResponse] + deleteHcpConfig *connect.Client[sdp_go.DeleteHcpConfigRequest, sdp_go.DeleteHcpConfigResponse] + replaceHcpApiKey *connect.Client[sdp_go.ReplaceHcpApiKeyRequest, sdp_go.ReplaceHcpApiKeyResponse] + getSignalConfig *connect.Client[sdp_go.GetSignalConfigRequest, sdp_go.GetSignalConfigResponse] + updateSignalConfig *connect.Client[sdp_go.UpdateSignalConfigRequest, sdp_go.UpdateSignalConfigResponse] + getGithubAppInformation *connect.Client[sdp_go.GetGithubAppInformationRequest, sdp_go.GetGithubAppInformationResponse] + regenerateGithubAppProfile *connect.Client[sdp_go.RegenerateGithubAppProfileRequest, sdp_go.RegenerateGithubAppProfileResponse] + createGithubInstallURL *connect.Client[sdp_go.CreateGithubInstallURLRequest, sdp_go.CreateGithubInstallURLResponse] } // GetAccountConfig calls config.ConfigurationService.GetAccountConfig. @@ -260,12 +247,6 @@ func (c *configurationServiceClient) RegenerateGithubAppProfile(ctx context.Cont return c.regenerateGithubAppProfile.CallUnary(ctx, req) } -// DeleteGithubAppProfileAndGithubInstallationID calls -// config.ConfigurationService.DeleteGithubAppProfileAndGithubInstallationID. -func (c *configurationServiceClient) DeleteGithubAppProfileAndGithubInstallationID(ctx context.Context, req *connect.Request[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest]) (*connect.Response[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse], error) { - return c.deleteGithubAppProfileAndGithubInstallationID.CallUnary(ctx, req) -} - // CreateGithubInstallURL calls config.ConfigurationService.CreateGithubInstallURL. func (c *configurationServiceClient) CreateGithubInstallURL(ctx context.Context, req *connect.Request[sdp_go.CreateGithubInstallURLRequest]) (*connect.Response[sdp_go.CreateGithubInstallURLResponse], error) { return c.createGithubInstallURL.CallUnary(ctx, req) @@ -298,9 +279,6 @@ type ConfigurationServiceHandler interface { GetGithubAppInformation(context.Context, *connect.Request[sdp_go.GetGithubAppInformationRequest]) (*connect.Response[sdp_go.GetGithubAppInformationResponse], error) // regenerate the github app profile, this information is used for signal processing RegenerateGithubAppProfile(context.Context, *connect.Request[sdp_go.RegenerateGithubAppProfileRequest]) (*connect.Response[sdp_go.RegenerateGithubAppProfileResponse], error) - // remove the github app installation id and github organisation profile from the signal config - // this will not uninstall the app from github, that must be done manually by the user - DeleteGithubAppProfileAndGithubInstallationID(context.Context, *connect.Request[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest]) (*connect.Response[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse], error) // Create a GitHub App install URL with a DB-backed state parameter for CSRF // protection. The frontend calls this RPC, then redirects the user to the // returned URL. GitHub will redirect back with the state UUID, which the @@ -375,12 +353,6 @@ func NewConfigurationServiceHandler(svc ConfigurationServiceHandler, opts ...con connect.WithSchema(configurationServiceMethods.ByName("RegenerateGithubAppProfile")), connect.WithHandlerOptions(opts...), ) - configurationServiceDeleteGithubAppProfileAndGithubInstallationIDHandler := connect.NewUnaryHandler( - ConfigurationServiceDeleteGithubAppProfileAndGithubInstallationIDProcedure, - svc.DeleteGithubAppProfileAndGithubInstallationID, - connect.WithSchema(configurationServiceMethods.ByName("DeleteGithubAppProfileAndGithubInstallationID")), - connect.WithHandlerOptions(opts...), - ) configurationServiceCreateGithubInstallURLHandler := connect.NewUnaryHandler( ConfigurationServiceCreateGithubInstallURLProcedure, svc.CreateGithubInstallURL, @@ -409,8 +381,6 @@ func NewConfigurationServiceHandler(svc ConfigurationServiceHandler, opts ...con configurationServiceGetGithubAppInformationHandler.ServeHTTP(w, r) case ConfigurationServiceRegenerateGithubAppProfileProcedure: configurationServiceRegenerateGithubAppProfileHandler.ServeHTTP(w, r) - case ConfigurationServiceDeleteGithubAppProfileAndGithubInstallationIDProcedure: - configurationServiceDeleteGithubAppProfileAndGithubInstallationIDHandler.ServeHTTP(w, r) case ConfigurationServiceCreateGithubInstallURLProcedure: configurationServiceCreateGithubInstallURLHandler.ServeHTTP(w, r) default: @@ -462,10 +432,6 @@ func (UnimplementedConfigurationServiceHandler) RegenerateGithubAppProfile(conte return nil, connect.NewError(connect.CodeUnimplemented, errors.New("config.ConfigurationService.RegenerateGithubAppProfile is not implemented")) } -func (UnimplementedConfigurationServiceHandler) DeleteGithubAppProfileAndGithubInstallationID(context.Context, *connect.Request[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDRequest]) (*connect.Response[sdp_go.DeleteGithubAppProfileAndGithubInstallationIDResponse], error) { - return nil, connect.NewError(connect.CodeUnimplemented, errors.New("config.ConfigurationService.DeleteGithubAppProfileAndGithubInstallationID is not implemented")) -} - func (UnimplementedConfigurationServiceHandler) CreateGithubInstallURL(context.Context, *connect.Request[sdp_go.CreateGithubInstallURLRequest]) (*connect.Response[sdp_go.CreateGithubInstallURLResponse], error) { return nil, connect.NewError(connect.CodeUnimplemented, errors.New("config.ConfigurationService.CreateGithubInstallURL is not implemented")) } diff --git a/go/sdp-go/sdpws/messagehandler.go b/go/sdp-go/sdpws/messagehandler.go index e25fffdc..3e9d348a 100644 --- a/go/sdp-go/sdpws/messagehandler.go +++ b/go/sdp-go/sdpws/messagehandler.go @@ -40,7 +40,14 @@ type LoggingGatewayMessageHandler struct { var _ GatewayMessageHandler = (*LoggingGatewayMessageHandler)(nil) func (l *LoggingGatewayMessageHandler) NewItem(ctx context.Context, item *sdp.Item) { - log.WithContext(ctx).WithField("item", item).Log(l.Level, "received new item") + entry := log.WithContext(ctx) + if item != nil { + entry = entry.WithField("item.globallyUniqueName", item.GloballyUniqueName()) + } + if l.Level >= log.DebugLevel { + entry = entry.WithField("item", item) + } + entry.Log(l.Level, "received new item") } func (l *LoggingGatewayMessageHandler) NewEdge(ctx context.Context, edge *sdp.Edge) { @@ -68,7 +75,14 @@ func (l *LoggingGatewayMessageHandler) DeleteEdge(ctx context.Context, edge *sdp } func (l *LoggingGatewayMessageHandler) UpdateItem(ctx context.Context, item *sdp.Item) { - log.WithContext(ctx).WithField("item", item).Log(l.Level, "received updated item") + entry := log.WithContext(ctx) + if item != nil { + entry = entry.WithField("item.globallyUniqueName", item.GloballyUniqueName()) + } + if l.Level >= log.DebugLevel { + entry = entry.WithField("item", item) + } + entry.Log(l.Level, "received updated item") } func (l *LoggingGatewayMessageHandler) SnapshotStoreResult(ctx context.Context, result *sdp.SnapshotStoreResult) { diff --git a/go/sdpcache/boltstore.go b/go/sdpcache/boltstore.go index 96219ff4..5d92e6c3 100644 --- a/go/sdpcache/boltstore.go +++ b/go/sdpcache/boltstore.go @@ -522,6 +522,8 @@ func (c *boltStore) Search(ctx context.Context, ck CacheKey) ([]*sdp.Item, error } if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "cache search failed") return nil, fmt.Errorf("search failed: %w", err) } @@ -642,11 +644,15 @@ func (c *boltStore) storeResult(ctx context.Context, res CachedResult) { entry, err := fromCachedResult(&res) if err != nil { - return // Silently fail on serialization errors + span.RecordError(err) + span.SetStatus(codes.Error, "failed to serialize cache result") + return } entryBytes, err := encodeCachedEntry(entry) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "failed to encode cache entry") return } @@ -922,6 +928,7 @@ func (c *boltStore) Purge(ctx context.Context, before time.Time) PurgeStats { span.SetAttributes(attribute.Bool("ovm.boltdb.compactionSuccess", true)) } else { span.RecordError(err) + span.SetStatus(codes.Error, "compaction failed") span.SetAttributes(attribute.Bool("ovm.boltdb.compactionSuccess", false)) } } else { @@ -952,7 +959,7 @@ func (c *boltStore) purgeLocked(ctx context.Context, before time.Time) PurgeStat } expired := make([]expiredEntry, 0) - _ = c.db.View(func(tx *bbolt.Tx) error { + if err := c.db.View(func(tx *bbolt.Tx) error { expiry := tx.Bucket(expiryBucketName) if expiry == nil { return nil @@ -990,11 +997,14 @@ func (c *boltStore) purgeLocked(ctx context.Context, before time.Time) PurgeStat } return nil - }) + }); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "purge scan failed") + } // Delete expired entries if len(expired) > 0 { - _ = c.db.Update(func(tx *bbolt.Tx) error { + if err := c.db.Update(func(tx *bbolt.Tx) error { items := tx.Bucket(itemsBucketName) expiry := tx.Bucket(expiryBucketName) @@ -1029,7 +1039,10 @@ func (c *boltStore) purgeLocked(ctx context.Context, before time.Time) PurgeStat // Save deleted bytes c.addDeletedBytes(totalDeleted) return c.saveDeletedBytes(tx) - }) + }); err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, "purge delete failed") + } } // Update final disk usage metrics diff --git a/go/tracing/main.go b/go/tracing/main.go index 41996aee..2f35980f 100644 --- a/go/tracing/main.go +++ b/go/tracing/main.go @@ -6,7 +6,6 @@ import ( "net/http" "os" "path/filepath" - "slices" "time" _ "embed" @@ -27,7 +26,6 @@ import ( sdktrace "go.opentelemetry.io/otel/sdk/trace" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" "go.opentelemetry.io/otel/trace" - "golang.org/x/sync/errgroup" ) // logrusOtelErrorHandler routes OpenTelemetry SDK errors through logrus so they @@ -162,27 +160,6 @@ func tracingResource(component string) *resource.Resource { } var tp *sdktrace.TracerProvider -var healthTp *sdktrace.TracerProvider - -// HealthCheckTracerProvider returns the tracer provider used for health checks. This has a built-in 1:100 sampler for health checks that are not captured by the default UserAgentSampler for ELB and kube-probe requests. -func HealthCheckTracerProvider() *sdktrace.TracerProvider { - if healthTp == nil { - panic("tracer providers not initialised") - } - return healthTp -} - -// healthCheckTracer is the tracer used for health checks. This is heavily sampled to avoid getting spammed by k8s or ELBs -func HealthCheckTracer() trace.Tracer { - return HealthCheckTracerProvider().Tracer( - instrumentationName, - trace.WithInstrumentationVersion(version), - trace.WithSchemaURL(semconv.SchemaURL), - trace.WithInstrumentationAttributes( - attribute.Bool("ovm.healthCheck", true), - ), - ) -} // InitTracerWithUpstreams initialises the tracer with uploading directly to Honeycomb and sentry if `honeycombApiKey` and `sentryDSN` is set respectively. `component` is used as the service name. func InitTracerWithUpstreams(component, honeycombApiKey, sentryDSN string, opts ...otlptracehttp.Option) error { @@ -255,18 +232,11 @@ func InitTracer(component string, opts ...otlptracehttp.Option) error { return fmt.Errorf("creating OTLP trace exporter: %w", err) } - healthExp, err := otlptrace.New(context.Background(), otlptracehttp.NewClient(opts...)) - if err != nil { - return fmt.Errorf("creating health OTLP trace exporter: %w", err) - } - - overmindSampler := NewOvermindSampler() res := tracingResource(component) tracerOpts := []sdktrace.TracerProviderOption{ sdktrace.WithBatcher(otlpExp, batcherOpts...), sdktrace.WithResource(res), - sdktrace.WithSampler(sdktrace.ParentBased(overmindSampler)), } if viper.GetBool("stdout-trace-dump") { stdoutExp, err := stdouttrace.New(stdouttrace.WithPrettyPrint()) @@ -279,12 +249,6 @@ func InitTracer(component string, opts ...otlptracehttp.Option) error { otel.SetTracerProvider(tp) - healthTp = sdktrace.NewTracerProvider( - sdktrace.WithBatcher(healthExp, batcherOpts...), - sdktrace.WithResource(res), - sdktrace.WithSampler(sdktrace.ParentBased(sdktrace.TraceIDRatioBased(0.1))), - ) - otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) return nil } @@ -295,121 +259,16 @@ func ShutdownTracer(ctx context.Context) { ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 10*time.Second) defer cancel() - var g errgroup.Group - - // Do not nil healthTp or tp here: concurrent callers (e.g. health check - // probes via HealthCheckTracerProvider) would panic on the nil guard. - // Calling Shutdown on an already-shutdown provider is a safe no-op. - if healthTp != nil { - localTp := healthTp - g.Go(func() error { - if err := localTp.ForceFlush(ctx); err != nil { - log.WithContext(ctx).WithError(err).Error("Error flushing health tracer provider") - } - if err := localTp.Shutdown(ctx); err != nil { - log.WithContext(ctx).WithError(err).Error("Error shutting down health tracer provider") - return err - } - return nil - }) - } - if tp != nil { - localTp := tp - g.Go(func() error { - if err := localTp.ForceFlush(ctx); err != nil { - log.WithContext(ctx).WithError(err).Error("Error flushing tracer provider") - } - if err := localTp.Shutdown(ctx); err != nil { - log.WithContext(ctx).WithError(err).Error("Error shutting down tracer provider") - return err - } - return nil - }) - } - - if err := g.Wait(); err != nil { - log.WithContext(ctx).WithError(err).Error("Error during tracer shutdown") - } - log.WithContext(ctx).Trace("tracing has shut down") -} - -// SamplingRule defines a single sampling rule with a rate and matching function -type SamplingRule struct { - SampleRate int - ShouldSample func(sdktrace.SamplingParameters) bool -} - -// OvermindSampler is a unified sampler that evaluates multiple sampling rules in order -type OvermindSampler struct { - rules []SamplingRule - ruleSamplers []sdktrace.Sampler -} - -// NewOvermindSampler creates a new unified sampler with the default rules -func NewOvermindSampler() *OvermindSampler { - rules := []SamplingRule{ - { - SampleRate: 200, - ShouldSample: UserAgentMatcher("ELB-HealthChecker/2.0", "kube-probe/1.27+"), - }, - } - - // Pre-allocate samplers for each rule - ruleSamplers := make([]sdktrace.Sampler, 0, len(rules)) - for _, rule := range rules { - var sampler sdktrace.Sampler - switch { - case rule.SampleRate <= 0: - sampler = sdktrace.NeverSample() - case rule.SampleRate == 1: - sampler = sdktrace.AlwaysSample() - default: - sampler = sdktrace.TraceIDRatioBased(1.0 / float64(rule.SampleRate)) - } - ruleSamplers = append(ruleSamplers, sampler) - } - - return &OvermindSampler{ - rules: rules, - ruleSamplers: ruleSamplers, - } -} - -// UserAgentMatcher returns a function that matches specific user agents -func UserAgentMatcher(userAgents ...string) func(sdktrace.SamplingParameters) bool { - return func(parameters sdktrace.SamplingParameters) bool { - for _, attr := range parameters.Attributes { - if (attr.Key == "http.user_agent" || attr.Key == "user_agent.original") && - slices.Contains(userAgents, attr.Value.AsString()) { - return true - } + if err := tp.ForceFlush(ctx); err != nil { + log.WithContext(ctx).WithError(err).Error("Error flushing tracer provider") } - return false - } -} - -// ShouldSample evaluates rules in order and returns the first matching decision -func (o *OvermindSampler) ShouldSample(parameters sdktrace.SamplingParameters) sdktrace.SamplingResult { - for i, rule := range o.rules { - if rule.ShouldSample(parameters) { - // Use the pre-allocated sampler for this rule - result := o.ruleSamplers[i].ShouldSample(parameters) - if result.Decision == sdktrace.RecordAndSample { - result.Attributes = append(result.Attributes, - attribute.Int("SampleRate", rule.SampleRate)) - } - return result + if err := tp.Shutdown(ctx); err != nil { + log.WithContext(ctx).WithError(err).Error("Error shutting down tracer provider") } } - // Default to AlwaysSample if no rules match - return sdktrace.AlwaysSample().ShouldSample(parameters) -} - -// Description returns information describing the Sampler -func (o *OvermindSampler) Description() string { - return "Unified Overmind sampler combining multiple sampling strategies" + log.WithContext(ctx).Trace("tracing has shut down") } // Version returns the version baked into the binary at build time. diff --git a/go/tracing/main_test.go b/go/tracing/main_test.go index 3d277079..3907d95a 100644 --- a/go/tracing/main_test.go +++ b/go/tracing/main_test.go @@ -21,34 +21,28 @@ func TestTracingResource(t *testing.T) { } } -func TestShutdownBothProviders(t *testing.T) { +func TestShutdownProvider(t *testing.T) { exp := tracetest.NewInMemoryExporter() tp = sdktrace.NewTracerProvider(sdktrace.WithBatcher(exp)) - healthTp = sdktrace.NewTracerProvider(sdktrace.WithBatcher(exp)) - if tp == nil || healthTp == nil { - t.Fatal("expected both tp and healthTp to be non-nil after init") + if tp == nil { + t.Fatal("expected tp to be non-nil after init") } ShutdownTracer(context.Background()) - // After shutdown, calling Shutdown again on the providers should be a - // safe no-op (the SDK guards with stopOnce). We do NOT nil the package - // vars because concurrent callers (e.g. health probes) would panic. + // After shutdown, calling Shutdown again should be a safe no-op + // (the SDK guards with stopOnce). if err := tp.Shutdown(context.Background()); err != nil { t.Errorf("second tp.Shutdown should be a no-op, got: %v", err) } - if err := healthTp.Shutdown(context.Background()); err != nil { - t.Errorf("second healthTp.Shutdown should be a no-op, got: %v", err) - } } func TestShutdownIdempotent(t *testing.T) { exp := tracetest.NewInMemoryExporter() tp = sdktrace.NewTracerProvider(sdktrace.WithBatcher(exp)) - healthTp = sdktrace.NewTracerProvider(sdktrace.WithBatcher(exp)) ShutdownTracer(context.Background()) ShutdownTracer(context.Background()) // must not panic diff --git a/k8s-source/build/package/Dockerfile b/k8s-source/build/package/Dockerfile index 566ca10e..bd730490 100644 --- a/k8s-source/build/package/Dockerfile +++ b/k8s-source/build/package/Dockerfile @@ -6,12 +6,16 @@ ARG BUILD_VERSION ARG BUILD_COMMIT # required for accessing the private dependencies and generating version descriptor -RUN apk upgrade --no-cache && apk add --no-cache git curl +RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY k8s-source/ k8s-source/ # Build RUN --mount=type=cache,target=/go/pkg \ diff --git a/sources/azure/build/package/Dockerfile b/sources/azure/build/package/Dockerfile index a31baa68..e7d504fc 100644 --- a/sources/azure/build/package/Dockerfile +++ b/sources/azure/build/package/Dockerfile @@ -10,8 +10,12 @@ RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY sources/ sources/ # Build RUN --mount=type=cache,target=/go/pkg \ diff --git a/sources/azure/clients/batch-private-endpoint-connection-client.go b/sources/azure/clients/batch-private-endpoint-connection-client.go new file mode 100644 index 00000000..eb77e77d --- /dev/null +++ b/sources/azure/clients/batch-private-endpoint-connection-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v4" +) + +//go:generate mockgen -destination=../shared/mocks/mock_batch_private_endpoint_connection_client.go -package=mocks -source=batch-private-endpoint-connection-client.go + +// BatchPrivateEndpointConnectionPager is a type alias for the generic Pager interface with Batch private endpoint connection list response type. +type BatchPrivateEndpointConnectionPager = Pager[armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse] + +// BatchPrivateEndpointConnectionClient is an interface for interacting with Azure Batch private endpoint connections. +type BatchPrivateEndpointConnectionClient interface { + Get(ctx context.Context, resourceGroupName string, accountName string, privateEndpointConnectionName string) (armbatch.PrivateEndpointConnectionClientGetResponse, error) + ListByBatchAccount(ctx context.Context, resourceGroupName string, accountName string) BatchPrivateEndpointConnectionPager +} + +type batchPrivateEndpointConnectionClient struct { + client *armbatch.PrivateEndpointConnectionClient +} + +func (c *batchPrivateEndpointConnectionClient) Get(ctx context.Context, resourceGroupName string, accountName string, privateEndpointConnectionName string) (armbatch.PrivateEndpointConnectionClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, accountName, privateEndpointConnectionName, nil) +} + +func (c *batchPrivateEndpointConnectionClient) ListByBatchAccount(ctx context.Context, resourceGroupName string, accountName string) BatchPrivateEndpointConnectionPager { + return c.client.NewListByBatchAccountPager(resourceGroupName, accountName, nil) +} + +// NewBatchPrivateEndpointConnectionClient creates a new BatchPrivateEndpointConnectionClient from the Azure SDK client. +func NewBatchPrivateEndpointConnectionClient(client *armbatch.PrivateEndpointConnectionClient) BatchPrivateEndpointConnectionClient { + return &batchPrivateEndpointConnectionClient{client: client} +} diff --git a/sources/azure/clients/dbforpostgresql-flexible-server-administrator-client.go b/sources/azure/clients/dbforpostgresql-flexible-server-administrator-client.go new file mode 100644 index 00000000..05be6049 --- /dev/null +++ b/sources/azure/clients/dbforpostgresql-flexible-server-administrator-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" +) + +//go:generate mockgen -destination=../shared/mocks/mock_dbforpostgresql_flexible_server_administrator_client.go -package=mocks -source=dbforpostgresql-flexible-server-administrator-client.go + +// DBforPostgreSQLFlexibleServerAdministratorPager is a type alias for the generic Pager interface with administrator response type. +type DBforPostgreSQLFlexibleServerAdministratorPager = Pager[armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse] + +// DBforPostgreSQLFlexibleServerAdministratorClient is an interface for interacting with Azure PostgreSQL Flexible Server Administrators +type DBforPostgreSQLFlexibleServerAdministratorClient interface { + ListByServer(ctx context.Context, resourceGroupName string, serverName string) DBforPostgreSQLFlexibleServerAdministratorPager + Get(ctx context.Context, resourceGroupName string, serverName string, objectID string) (armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse, error) +} + +type dbforPostgreSQLFlexibleServerAdministratorClient struct { + client *armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClient +} + +func (a *dbforPostgreSQLFlexibleServerAdministratorClient) ListByServer(ctx context.Context, resourceGroupName string, serverName string) DBforPostgreSQLFlexibleServerAdministratorPager { + return a.client.NewListByServerPager(resourceGroupName, serverName, nil) +} + +func (a *dbforPostgreSQLFlexibleServerAdministratorClient) Get(ctx context.Context, resourceGroupName string, serverName string, objectID string) (armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse, error) { + return a.client.Get(ctx, resourceGroupName, serverName, objectID, nil) +} + +// NewDBforPostgreSQLFlexibleServerAdministratorClient creates a new DBforPostgreSQLFlexibleServerAdministratorClient from the Azure SDK client +func NewDBforPostgreSQLFlexibleServerAdministratorClient(client *armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClient) DBforPostgreSQLFlexibleServerAdministratorClient { + return &dbforPostgreSQLFlexibleServerAdministratorClient{client: client} +} diff --git a/sources/azure/clients/dbforpostgresql-flexible-server-virtual-endpoint-client.go b/sources/azure/clients/dbforpostgresql-flexible-server-virtual-endpoint-client.go new file mode 100644 index 00000000..b4b44bb8 --- /dev/null +++ b/sources/azure/clients/dbforpostgresql-flexible-server-virtual-endpoint-client.go @@ -0,0 +1,32 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" +) + +//go:generate mockgen -destination=../shared/mocks/mock_dbforpostgresql_flexible_server_virtual_endpoint_client.go -package=mocks -source=dbforpostgresql-flexible-server-virtual-endpoint-client.go + +type DBforPostgreSQLFlexibleServerVirtualEndpointPager = Pager[armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse] + +type DBforPostgreSQLFlexibleServerVirtualEndpointClient interface { + ListByServer(ctx context.Context, resourceGroupName string, serverName string) DBforPostgreSQLFlexibleServerVirtualEndpointPager + Get(ctx context.Context, resourceGroupName string, serverName string, virtualEndpointName string) (armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse, error) +} + +type dbforPostgreSQLFlexibleServerVirtualEndpointClient struct { + client *armpostgresqlflexibleservers.VirtualEndpointsClient +} + +func (a *dbforPostgreSQLFlexibleServerVirtualEndpointClient) ListByServer(ctx context.Context, resourceGroupName string, serverName string) DBforPostgreSQLFlexibleServerVirtualEndpointPager { + return a.client.NewListByServerPager(resourceGroupName, serverName, nil) +} + +func (a *dbforPostgreSQLFlexibleServerVirtualEndpointClient) Get(ctx context.Context, resourceGroupName string, serverName string, virtualEndpointName string) (armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse, error) { + return a.client.Get(ctx, resourceGroupName, serverName, virtualEndpointName, nil) +} + +func NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(client *armpostgresqlflexibleservers.VirtualEndpointsClient) DBforPostgreSQLFlexibleServerVirtualEndpointClient { + return &dbforPostgreSQLFlexibleServerVirtualEndpointClient{client: client} +} diff --git a/sources/azure/clients/elastic-san-volume-client.go b/sources/azure/clients/elastic-san-volume-client.go new file mode 100644 index 00000000..0547d90a --- /dev/null +++ b/sources/azure/clients/elastic-san-volume-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan" +) + +//go:generate mockgen -destination=../shared/mocks/mock_elastic_san_volume_client.go -package=mocks -source=elastic-san-volume-client.go + +// ElasticSanVolumePager is a type alias for the generic Pager interface with volume list response type. +type ElasticSanVolumePager = Pager[armelasticsan.VolumesClientListByVolumeGroupResponse] + +// ElasticSanVolumeClient is an interface for interacting with Azure Elastic SAN volumes. +type ElasticSanVolumeClient interface { + Get(ctx context.Context, resourceGroupName string, elasticSanName string, volumeGroupName string, volumeName string, options *armelasticsan.VolumesClientGetOptions) (armelasticsan.VolumesClientGetResponse, error) + NewListByVolumeGroupPager(resourceGroupName string, elasticSanName string, volumeGroupName string, options *armelasticsan.VolumesClientListByVolumeGroupOptions) ElasticSanVolumePager +} + +type elasticSanVolumeClient struct { + client *armelasticsan.VolumesClient +} + +func (c *elasticSanVolumeClient) Get(ctx context.Context, resourceGroupName string, elasticSanName string, volumeGroupName string, volumeName string, options *armelasticsan.VolumesClientGetOptions) (armelasticsan.VolumesClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, options) +} + +func (c *elasticSanVolumeClient) NewListByVolumeGroupPager(resourceGroupName string, elasticSanName string, volumeGroupName string, options *armelasticsan.VolumesClientListByVolumeGroupOptions) ElasticSanVolumePager { + return c.client.NewListByVolumeGroupPager(resourceGroupName, elasticSanName, volumeGroupName, options) +} + +// NewElasticSanVolumeClient creates a new ElasticSanVolumeClient from the Azure SDK client. +func NewElasticSanVolumeClient(client *armelasticsan.VolumesClient) ElasticSanVolumeClient { + return &elasticSanVolumeClient{client: client} +} diff --git a/sources/azure/clients/federated-identity-credentials-client.go b/sources/azure/clients/federated-identity-credentials-client.go new file mode 100644 index 00000000..cd0858ec --- /dev/null +++ b/sources/azure/clients/federated-identity-credentials-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" +) + +//go:generate mockgen -destination=../shared/mocks/mock_federated_identity_credentials_client.go -package=mocks -source=federated-identity-credentials-client.go + +// FederatedIdentityCredentialsPager is a pager for listing federated identity credentials +type FederatedIdentityCredentialsPager = Pager[armmsi.FederatedIdentityCredentialsClientListResponse] + +// FederatedIdentityCredentialsClient is the client interface for interacting with federated identity credentials +type FederatedIdentityCredentialsClient interface { + NewListPager(resourceGroupName string, resourceName string, options *armmsi.FederatedIdentityCredentialsClientListOptions) FederatedIdentityCredentialsPager + Get(ctx context.Context, resourceGroupName string, resourceName string, federatedIdentityCredentialResourceName string, options *armmsi.FederatedIdentityCredentialsClientGetOptions) (armmsi.FederatedIdentityCredentialsClientGetResponse, error) +} + +type federatedIdentityCredentialsClient struct { + client *armmsi.FederatedIdentityCredentialsClient +} + +func (f *federatedIdentityCredentialsClient) NewListPager(resourceGroupName string, resourceName string, options *armmsi.FederatedIdentityCredentialsClientListOptions) FederatedIdentityCredentialsPager { + return f.client.NewListPager(resourceGroupName, resourceName, options) +} + +func (f *federatedIdentityCredentialsClient) Get(ctx context.Context, resourceGroupName string, resourceName string, federatedIdentityCredentialResourceName string, options *armmsi.FederatedIdentityCredentialsClientGetOptions) (armmsi.FederatedIdentityCredentialsClientGetResponse, error) { + return f.client.Get(ctx, resourceGroupName, resourceName, federatedIdentityCredentialResourceName, options) +} + +// NewFederatedIdentityCredentialsClient creates a new FederatedIdentityCredentialsClient +func NewFederatedIdentityCredentialsClient(client *armmsi.FederatedIdentityCredentialsClient) FederatedIdentityCredentialsClient { + return &federatedIdentityCredentialsClient{client: client} +} diff --git a/sources/azure/clients/interface-ip-configurations-client.go b/sources/azure/clients/interface-ip-configurations-client.go new file mode 100644 index 00000000..cd1101e8 --- /dev/null +++ b/sources/azure/clients/interface-ip-configurations-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" +) + +//go:generate mockgen -destination=../shared/mocks/mock_interface_ip_configurations_client.go -package=mocks -source=interface-ip-configurations-client.go + +// InterfaceIPConfigurationsPager is a type alias for the generic Pager interface with InterfaceIPConfiguration response type. +type InterfaceIPConfigurationsPager = Pager[armnetwork.InterfaceIPConfigurationsClientListResponse] + +// InterfaceIPConfigurationsClient is an interface for interacting with Azure network interface IP configurations +type InterfaceIPConfigurationsClient interface { + Get(ctx context.Context, resourceGroupName string, networkInterfaceName string, ipConfigurationName string) (armnetwork.InterfaceIPConfigurationsClientGetResponse, error) + List(ctx context.Context, resourceGroupName string, networkInterfaceName string) InterfaceIPConfigurationsPager +} + +type interfaceIPConfigurationsClient struct { + client *armnetwork.InterfaceIPConfigurationsClient +} + +func (a *interfaceIPConfigurationsClient) Get(ctx context.Context, resourceGroupName string, networkInterfaceName string, ipConfigurationName string) (armnetwork.InterfaceIPConfigurationsClientGetResponse, error) { + return a.client.Get(ctx, resourceGroupName, networkInterfaceName, ipConfigurationName, nil) +} + +func (a *interfaceIPConfigurationsClient) List(ctx context.Context, resourceGroupName string, networkInterfaceName string) InterfaceIPConfigurationsPager { + return a.client.NewListPager(resourceGroupName, networkInterfaceName, nil) +} + +// NewInterfaceIPConfigurationsClient creates a new InterfaceIPConfigurationsClient from the Azure SDK client +func NewInterfaceIPConfigurationsClient(client *armnetwork.InterfaceIPConfigurationsClient) InterfaceIPConfigurationsClient { + return &interfaceIPConfigurationsClient{client: client} +} diff --git a/sources/azure/clients/ip-groups-client.go b/sources/azure/clients/ip-groups-client.go new file mode 100644 index 00000000..78bd0a8c --- /dev/null +++ b/sources/azure/clients/ip-groups-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" +) + +//go:generate mockgen -destination=../shared/mocks/mock_ip_groups_client.go -package=mocks -source=ip-groups-client.go + +// IPGroupsPager is a type alias for the generic Pager interface with IP groups response type. +type IPGroupsPager = Pager[armnetwork.IPGroupsClientListByResourceGroupResponse] + +// IPGroupsClient is an interface for interacting with Azure IP Groups. +type IPGroupsClient interface { + Get(ctx context.Context, resourceGroupName string, ipGroupsName string, options *armnetwork.IPGroupsClientGetOptions) (armnetwork.IPGroupsClientGetResponse, error) + NewListByResourceGroupPager(resourceGroupName string, options *armnetwork.IPGroupsClientListByResourceGroupOptions) IPGroupsPager +} + +type ipGroupsClient struct { + client *armnetwork.IPGroupsClient +} + +func (c *ipGroupsClient) Get(ctx context.Context, resourceGroupName string, ipGroupsName string, options *armnetwork.IPGroupsClientGetOptions) (armnetwork.IPGroupsClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, ipGroupsName, options) +} + +func (c *ipGroupsClient) NewListByResourceGroupPager(resourceGroupName string, options *armnetwork.IPGroupsClientListByResourceGroupOptions) IPGroupsPager { + return c.client.NewListByResourceGroupPager(resourceGroupName, options) +} + +// NewIPGroupsClient creates a new IPGroupsClient from the Azure SDK client. +func NewIPGroupsClient(client *armnetwork.IPGroupsClient) IPGroupsClient { + return &ipGroupsClient{client: client} +} diff --git a/sources/azure/clients/local-network-gateways-client.go b/sources/azure/clients/local-network-gateways-client.go new file mode 100644 index 00000000..03f4bd33 --- /dev/null +++ b/sources/azure/clients/local-network-gateways-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" +) + +//go:generate mockgen -destination=../shared/mocks/mock_local_network_gateways_client.go -package=mocks -source=local-network-gateways-client.go + +// LocalNetworkGatewaysPager is a type alias for the generic Pager interface with local network gateway list response type. +type LocalNetworkGatewaysPager = Pager[armnetwork.LocalNetworkGatewaysClientListResponse] + +// LocalNetworkGatewaysClient is an interface for interacting with Azure local network gateways. +type LocalNetworkGatewaysClient interface { + Get(ctx context.Context, resourceGroupName string, localNetworkGatewayName string, options *armnetwork.LocalNetworkGatewaysClientGetOptions) (armnetwork.LocalNetworkGatewaysClientGetResponse, error) + NewListPager(resourceGroupName string, options *armnetwork.LocalNetworkGatewaysClientListOptions) LocalNetworkGatewaysPager +} + +type localNetworkGatewaysClient struct { + client *armnetwork.LocalNetworkGatewaysClient +} + +func (c *localNetworkGatewaysClient) Get(ctx context.Context, resourceGroupName string, localNetworkGatewayName string, options *armnetwork.LocalNetworkGatewaysClientGetOptions) (armnetwork.LocalNetworkGatewaysClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, localNetworkGatewayName, options) +} + +func (c *localNetworkGatewaysClient) NewListPager(resourceGroupName string, options *armnetwork.LocalNetworkGatewaysClientListOptions) LocalNetworkGatewaysPager { + return c.client.NewListPager(resourceGroupName, options) +} + +// NewLocalNetworkGatewaysClient creates a new LocalNetworkGatewaysClient from the Azure SDK client. +func NewLocalNetworkGatewaysClient(client *armnetwork.LocalNetworkGatewaysClient) LocalNetworkGatewaysClient { + return &localNetworkGatewaysClient{client: client} +} diff --git a/sources/azure/clients/network-watchers-client.go b/sources/azure/clients/network-watchers-client.go new file mode 100644 index 00000000..582ea58b --- /dev/null +++ b/sources/azure/clients/network-watchers-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" +) + +//go:generate mockgen -destination=../shared/mocks/mock_network_watchers_client.go -package=mocks -source=network-watchers-client.go + +// NetworkWatchersPager is a type alias for the generic Pager interface with network watchers response type. +type NetworkWatchersPager = Pager[armnetwork.WatchersClientListResponse] + +// NetworkWatchersClient is an interface for interacting with Azure Network Watchers +type NetworkWatchersClient interface { + NewListPager(resourceGroupName string, options *armnetwork.WatchersClientListOptions) NetworkWatchersPager + Get(ctx context.Context, resourceGroupName string, networkWatcherName string, options *armnetwork.WatchersClientGetOptions) (armnetwork.WatchersClientGetResponse, error) +} + +type networkWatchersClient struct { + client *armnetwork.WatchersClient +} + +func (c *networkWatchersClient) NewListPager(resourceGroupName string, options *armnetwork.WatchersClientListOptions) NetworkWatchersPager { + return c.client.NewListPager(resourceGroupName, options) +} + +func (c *networkWatchersClient) Get(ctx context.Context, resourceGroupName string, networkWatcherName string, options *armnetwork.WatchersClientGetOptions) (armnetwork.WatchersClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, networkWatcherName, options) +} + +// NewNetworkWatchersClient creates a new NetworkWatchersClient from the Azure SDK client +func NewNetworkWatchersClient(client *armnetwork.WatchersClient) NetworkWatchersClient { + return &networkWatchersClient{client: client} +} diff --git a/sources/azure/clients/operational-insights-workspace-client.go b/sources/azure/clients/operational-insights-workspace-client.go new file mode 100644 index 00000000..04b56120 --- /dev/null +++ b/sources/azure/clients/operational-insights-workspace-client.go @@ -0,0 +1,36 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" +) + +//go:generate mockgen -destination=../shared/mocks/mock_operational_insights_workspace_client.go -package=mocks -source=operational-insights-workspace-client.go + +// OperationalInsightsWorkspacePager is a type alias for the generic Pager interface with workspace response type. +// This uses the generic Pager[T] interface to avoid code duplication. +type OperationalInsightsWorkspacePager = Pager[armoperationalinsights.WorkspacesClientListByResourceGroupResponse] + +// OperationalInsightsWorkspaceClient is an interface for interacting with Azure Log Analytics Workspaces +type OperationalInsightsWorkspaceClient interface { + NewListByResourceGroupPager(resourceGroupName string, options *armoperationalinsights.WorkspacesClientListByResourceGroupOptions) OperationalInsightsWorkspacePager + Get(ctx context.Context, resourceGroupName string, workspaceName string, options *armoperationalinsights.WorkspacesClientGetOptions) (armoperationalinsights.WorkspacesClientGetResponse, error) +} + +type operationalInsightsWorkspaceClient struct { + client *armoperationalinsights.WorkspacesClient +} + +func (o *operationalInsightsWorkspaceClient) NewListByResourceGroupPager(resourceGroupName string, options *armoperationalinsights.WorkspacesClientListByResourceGroupOptions) OperationalInsightsWorkspacePager { + return o.client.NewListByResourceGroupPager(resourceGroupName, options) +} + +func (o *operationalInsightsWorkspaceClient) Get(ctx context.Context, resourceGroupName string, workspaceName string, options *armoperationalinsights.WorkspacesClientGetOptions) (armoperationalinsights.WorkspacesClientGetResponse, error) { + return o.client.Get(ctx, resourceGroupName, workspaceName, options) +} + +// NewOperationalInsightsWorkspaceClient creates a new OperationalInsightsWorkspaceClient from the Azure SDK client +func NewOperationalInsightsWorkspaceClient(client *armoperationalinsights.WorkspacesClient) OperationalInsightsWorkspaceClient { + return &operationalInsightsWorkspaceClient{client: client} +} diff --git a/sources/azure/clients/private-link-services-client.go b/sources/azure/clients/private-link-services-client.go new file mode 100644 index 00000000..1065cf4c --- /dev/null +++ b/sources/azure/clients/private-link-services-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" +) + +//go:generate mockgen -destination=../shared/mocks/mock_private_link_services_client.go -package=mocks -source=private-link-services-client.go + +// PrivateLinkServicesPager is a type alias for the generic Pager interface with private link service response type. +type PrivateLinkServicesPager = Pager[armnetwork.PrivateLinkServicesClientListResponse] + +// PrivateLinkServicesClient is an interface for interacting with Azure private link services. +type PrivateLinkServicesClient interface { + Get(ctx context.Context, resourceGroupName string, serviceName string) (armnetwork.PrivateLinkServicesClientGetResponse, error) + List(resourceGroupName string) PrivateLinkServicesPager +} + +type privateLinkServicesClient struct { + client *armnetwork.PrivateLinkServicesClient +} + +func (c *privateLinkServicesClient) Get(ctx context.Context, resourceGroupName string, serviceName string) (armnetwork.PrivateLinkServicesClientGetResponse, error) { + return c.client.Get(ctx, resourceGroupName, serviceName, nil) +} + +func (c *privateLinkServicesClient) List(resourceGroupName string) PrivateLinkServicesPager { + return c.client.NewListPager(resourceGroupName, nil) +} + +// NewPrivateLinkServicesClient creates a new PrivateLinkServicesClient from the Azure SDK client. +func NewPrivateLinkServicesClient(client *armnetwork.PrivateLinkServicesClient) PrivateLinkServicesClient { + return &privateLinkServicesClient{client: client} +} diff --git a/sources/azure/clients/role-definitions-client.go b/sources/azure/clients/role-definitions-client.go new file mode 100644 index 00000000..8fc2a9a6 --- /dev/null +++ b/sources/azure/clients/role-definitions-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" +) + +//go:generate mockgen -destination=../shared/mocks/mock_role_definitions_client.go -package=mocks -source=role-definitions-client.go + +// RoleDefinitionsPager is a type alias for the generic Pager interface with role definition response type. +type RoleDefinitionsPager = Pager[armauthorization.RoleDefinitionsClientListResponse] + +// RoleDefinitionsClient is an interface for interacting with Azure role definitions +type RoleDefinitionsClient interface { + NewListPager(scope string, options *armauthorization.RoleDefinitionsClientListOptions) RoleDefinitionsPager + Get(ctx context.Context, scope string, roleDefinitionID string, options *armauthorization.RoleDefinitionsClientGetOptions) (armauthorization.RoleDefinitionsClientGetResponse, error) +} + +type roleDefinitionsClient struct { + client *armauthorization.RoleDefinitionsClient +} + +func (c *roleDefinitionsClient) NewListPager(scope string, options *armauthorization.RoleDefinitionsClientListOptions) RoleDefinitionsPager { + return c.client.NewListPager(scope, options) +} + +func (c *roleDefinitionsClient) Get(ctx context.Context, scope string, roleDefinitionID string, options *armauthorization.RoleDefinitionsClientGetOptions) (armauthorization.RoleDefinitionsClientGetResponse, error) { + return c.client.Get(ctx, scope, roleDefinitionID, options) +} + +// NewRoleDefinitionsClient creates a new RoleDefinitionsClient from the Azure SDK client +func NewRoleDefinitionsClient(client *armauthorization.RoleDefinitionsClient) RoleDefinitionsClient { + return &roleDefinitionsClient{client: client} +} diff --git a/sources/azure/clients/sql-failover-groups-client.go b/sources/azure/clients/sql-failover-groups-client.go new file mode 100644 index 00000000..73a41b86 --- /dev/null +++ b/sources/azure/clients/sql-failover-groups-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" +) + +//go:generate mockgen -destination=../shared/mocks/mock_sql_failover_groups_client.go -package=mocks -source=sql-failover-groups-client.go + +// SqlFailoverGroupsPager is a type alias for the generic Pager interface with failover groups response type. +type SqlFailoverGroupsPager = Pager[armsql.FailoverGroupsClientListByServerResponse] + +// SqlFailoverGroupsClient is an interface for interacting with Azure SQL Server Failover Groups +type SqlFailoverGroupsClient interface { + ListByServer(ctx context.Context, resourceGroupName string, serverName string) SqlFailoverGroupsPager + Get(ctx context.Context, resourceGroupName string, serverName string, failoverGroupName string) (armsql.FailoverGroupsClientGetResponse, error) +} + +type sqlFailoverGroupsClient struct { + client *armsql.FailoverGroupsClient +} + +func (a *sqlFailoverGroupsClient) ListByServer(ctx context.Context, resourceGroupName string, serverName string) SqlFailoverGroupsPager { + return a.client.NewListByServerPager(resourceGroupName, serverName, nil) +} + +func (a *sqlFailoverGroupsClient) Get(ctx context.Context, resourceGroupName string, serverName string, failoverGroupName string) (armsql.FailoverGroupsClientGetResponse, error) { + return a.client.Get(ctx, resourceGroupName, serverName, failoverGroupName, nil) +} + +// NewSqlFailoverGroupsClient creates a new SqlFailoverGroupsClient from the Azure SDK client +func NewSqlFailoverGroupsClient(client *armsql.FailoverGroupsClient) SqlFailoverGroupsClient { + return &sqlFailoverGroupsClient{client: client} +} diff --git a/sources/azure/clients/sql-server-keys-client.go b/sources/azure/clients/sql-server-keys-client.go new file mode 100644 index 00000000..9481175d --- /dev/null +++ b/sources/azure/clients/sql-server-keys-client.go @@ -0,0 +1,35 @@ +package clients + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" +) + +//go:generate mockgen -destination=../shared/mocks/mock_sql_server_keys_client.go -package=mocks -source=sql-server-keys-client.go + +// SqlServerKeysPager is a type alias for the generic Pager interface with sql server keys response type. +type SqlServerKeysPager = Pager[armsql.ServerKeysClientListByServerResponse] + +// SqlServerKeysClient is an interface for interacting with Azure SQL server keys +type SqlServerKeysClient interface { + NewListByServerPager(resourceGroupName string, serverName string) SqlServerKeysPager + Get(ctx context.Context, resourceGroupName string, serverName string, keyName string) (armsql.ServerKeysClientGetResponse, error) +} + +type sqlServerKeysClient struct { + client *armsql.ServerKeysClient +} + +func (a *sqlServerKeysClient) NewListByServerPager(resourceGroupName string, serverName string) SqlServerKeysPager { + return a.client.NewListByServerPager(resourceGroupName, serverName, nil) +} + +func (a *sqlServerKeysClient) Get(ctx context.Context, resourceGroupName string, serverName string, keyName string) (armsql.ServerKeysClientGetResponse, error) { + return a.client.Get(ctx, resourceGroupName, serverName, keyName, nil) +} + +// NewSqlServerKeysClient creates a new SqlServerKeysClient from the Azure SDK client +func NewSqlServerKeysClient(client *armsql.ServerKeysClient) SqlServerKeysClient { + return &sqlServerKeysClient{client: client} +} diff --git a/sources/azure/cmd/root.go b/sources/azure/cmd/root.go index ae641c37..bf21f327 100644 --- a/sources/azure/cmd/root.go +++ b/sources/azure/cmd/root.go @@ -16,8 +16,8 @@ import ( "github.com/spf13/viper" "github.com/overmindtech/cli/go/discovery" - "github.com/overmindtech/cli/sources/azure/proc" "github.com/overmindtech/cli/go/tracing" + "github.com/overmindtech/cli/sources/azure/proc" ) var cfgFile string diff --git a/sources/azure/integration-tests/authorization-role-definition_test.go b/sources/azure/integration-tests/authorization-role-definition_test.go new file mode 100644 index 00000000..c179e16c --- /dev/null +++ b/sources/azure/integration-tests/authorization-role-definition_test.go @@ -0,0 +1,286 @@ +package integrationtests + +import ( + "fmt" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +func TestAuthorizationRoleDefinitionIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + roleDefinitionsClient, err := armauthorization.NewRoleDefinitionsClient(cred, nil) + if err != nil { + t.Fatalf("Failed to create Role Definitions client: %v", err) + } + + // Use a built-in role definition ID that always exists: "Reader" + // The Reader role ID is the same across all Azure subscriptions + readerRoleDefinitionID := "acdd72a7-3385-48ef-bd42-f606fba81ae7" + + t.Run("Setup", func(t *testing.T) { + // No setup required for role definitions - they are built-in Azure resources + log.Printf("Using built-in Reader role definition ID: %s", readerRoleDefinitionID) + }) + + t.Run("Run", func(t *testing.T) { + t.Run("GetRoleDefinition", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving role definition %s", readerRoleDefinitionID) + + wrapper := manual.NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, readerRoleDefinitionID, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.AuthorizationRoleDefinition.String() { + t.Errorf("Expected type %s, got %s", azureshared.AuthorizationRoleDefinition.String(), sdpItem.GetType()) + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + if uniqueAttrKey != "name" { + t.Errorf("Expected unique attribute 'name', got %s", uniqueAttrKey) + } + + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != readerRoleDefinitionID { + t.Errorf("Expected unique attribute value %s, got %s", readerRoleDefinitionID, uniqueAttrValue) + } + + if sdpItem.GetScope() != subscriptionID { + t.Errorf("Expected scope %s, got %s", subscriptionID, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Successfully retrieved role definition %s", readerRoleDefinitionID) + }) + + t.Run("ListRoleDefinitions", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing role definitions in subscription %s", subscriptionID) + + wrapper := manual.NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Failed to list role definitions: %v", err) + } + + // Azure has many built-in role definitions, expect at least a few + if len(sdpItems) < 5 { + t.Fatalf("Expected at least 5 role definitions, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil { + if v == readerRoleDefinitionID { + found = true + break + } + } + } + + if !found { + t.Fatalf("Expected to find Reader role definition %s in the list results", readerRoleDefinitionID) + } + + log.Printf("Found %d role definitions in list results", len(sdpItems)) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for role definition %s", readerRoleDefinitionID) + + wrapper := manual.NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, readerRoleDefinitionID, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify item type + if sdpItem.GetType() != azureshared.AuthorizationRoleDefinition.String() { + t.Errorf("Expected item type %s, got %s", azureshared.AuthorizationRoleDefinition.String(), sdpItem.GetType()) + } + + // Verify scope + if sdpItem.GetScope() != subscriptionID { + t.Errorf("Expected scope %s, got %s", subscriptionID, sdpItem.GetScope()) + } + + // Verify unique attribute + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + // Verify item validation + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + // Verify role name is Reader + roleName, err := sdpItem.GetAttributes().Get("properties.roleName") + if err != nil { + t.Logf("Warning: Could not get roleName attribute: %v", err) + } else if roleName != "Reader" { + t.Errorf("Expected role name 'Reader', got %s", roleName) + } + + log.Printf("Verified item attributes for role definition %s", readerRoleDefinitionID) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for role definition %s", readerRoleDefinitionID) + + wrapper := manual.NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, readerRoleDefinitionID, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Role definitions link to AssignableScopes (subscriptions and resource groups) + // The built-in Reader role has "/" as its assignable scope, which may not produce links + // Custom roles would have specific subscription/resource group scopes + linkedQueries := sdpItem.GetLinkedItemQueries() + + // Verify each linked query has proper attributes + for _, linkedQuery := range linkedQueries { + query := linkedQuery.GetQuery() + if query.GetType() == "" { + t.Error("Linked item query has empty Type") + } + if query.GetMethod() != sdp.QueryMethod_GET && query.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Linked item query has invalid Method: %v", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + + // Verify linked types are either subscription or resource group + validTypes := map[string]bool{ + azureshared.ResourcesSubscription.String(): true, + azureshared.ResourcesResourceGroup.String(): true, + } + if !validTypes[query.GetType()] { + t.Errorf("Unexpected linked item type: %s", query.GetType()) + } + } + + log.Printf("Verified linked items for role definition %s (found %d linked queries)", readerRoleDefinitionID, len(linkedQueries)) + }) + + t.Run("VerifyBuiltInRoles", func(t *testing.T) { + ctx := t.Context() + + // Verify some well-known built-in role definitions exist + builtInRoles := map[string]string{ + "acdd72a7-3385-48ef-bd42-f606fba81ae7": "Reader", + "b24988ac-6180-42a0-ab88-20f7382dd24c": "Contributor", + "8e3af657-a8ff-443c-a75c-2fe8c4bcb635": "Owner", + } + + wrapper := manual.NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ) + scope := wrapper.Scopes()[0] + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + for roleID, roleName := range builtInRoles { + t.Run(fmt.Sprintf("Get%sRole", roleName), func(t *testing.T) { + sdpItem, qErr := adapter.Get(ctx, scope, roleID, true) + if qErr != nil { + t.Fatalf("Failed to get %s role definition: %v", roleName, qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected %s role definition to be non-nil", roleName) + } + + actualRoleName, err := sdpItem.GetAttributes().Get("properties.roleName") + if err != nil { + t.Logf("Warning: Could not get roleName attribute for %s: %v", roleName, err) + } else if actualRoleName != roleName { + t.Errorf("Expected role name '%s', got '%s'", roleName, actualRoleName) + } + + log.Printf("Successfully verified built-in role: %s (ID: %s)", roleName, roleID) + }) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + // No teardown required - role definitions are built-in Azure resources + log.Printf("No teardown required for role definitions (built-in Azure resources)") + }) +} diff --git a/sources/azure/integration-tests/batch-private-endpoint-connection_test.go b/sources/azure/integration-tests/batch-private-endpoint-connection_test.go new file mode 100644 index 00000000..307daaca --- /dev/null +++ b/sources/azure/integration-tests/batch-private-endpoint-connection_test.go @@ -0,0 +1,627 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v4" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage/v3" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestBatchPECAccountName = "ovm-integ-test-bpec" + integrationTestBatchPECSAName = "ovm-integ-test-sa-bpec" + integrationTestBatchPECVNetName = "ovm-integ-test-vnet-bpec" + integrationTestBatchPECSubnetName = "ovm-integ-test-subnet-bpec" + integrationTestBatchPECPEName = "ovm-integ-test-pe-bpec" +) + +func TestBatchPrivateEndpointConnectionIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + batchClient, err := armbatch.NewAccountClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Batch Account client: %v", err) + } + + pecClient, err := armbatch.NewPrivateEndpointConnectionClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Batch Private Endpoint Connection client: %v", err) + } + + saClient, err := armstorage.NewAccountsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Storage Accounts client: %v", err) + } + + vnetClient, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Virtual Networks client: %v", err) + } + + subnetClient, err := armnetwork.NewSubnetsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Subnets client: %v", err) + } + + peClient, err := armnetwork.NewPrivateEndpointsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Private Endpoints client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + batchAccountName := generateBatchAccountName(integrationTestBatchPECAccountName) + storageAccountName := generateStorageAccountName(integrationTestBatchPECSAName) + vnetName := integrationTestBatchPECVNetName + subnetName := integrationTestBatchPECSubnetName + peName := integrationTestBatchPECPEName + + setupCompleted := false + var privateEndpointConnectionName string + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createStorageAccount(ctx, saClient, integrationTestResourceGroup, storageAccountName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create storage account: %v", err) + } + + err = waitForStorageAccountAvailable(ctx, saClient, integrationTestResourceGroup, storageAccountName) + if err != nil { + t.Fatalf("Failed waiting for storage account to be available: %v", err) + } + + saResp, err := saClient.GetProperties(ctx, integrationTestResourceGroup, storageAccountName, nil) + if err != nil { + t.Fatalf("Failed to get storage account properties: %v", err) + } + storageAccountID := *saResp.ID + + err = createBatchAccountWithPrivateEndpointPolicy(ctx, batchClient, integrationTestResourceGroup, batchAccountName, integrationTestLocation, storageAccountID) + if err != nil { + if errors.Is(err, errBatchQuotaExceeded) { + t.Skipf("Skipping Batch private endpoint connection integration test due to Azure subscription quota: %v", err) + } + t.Fatalf("Failed to create batch account: %v", err) + } + + err = waitForBatchAccountAvailable(ctx, batchClient, integrationTestResourceGroup, batchAccountName) + if err != nil { + t.Fatalf("Failed waiting for batch account to be available: %v", err) + } + + err = createVNetForBatchPEC(ctx, vnetClient, integrationTestResourceGroup, vnetName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create VNet: %v", err) + } + + err = createSubnetForBatchPEC(ctx, subnetClient, integrationTestResourceGroup, vnetName, subnetName) + if err != nil { + t.Fatalf("Failed to create subnet: %v", err) + } + + batchResp, err := batchClient.Get(ctx, integrationTestResourceGroup, batchAccountName, nil) + if err != nil { + t.Fatalf("Failed to get batch account: %v", err) + } + batchAccountID := *batchResp.ID + + err = createPrivateEndpointForBatch(ctx, peClient, integrationTestResourceGroup, peName, integrationTestLocation, batchAccountID, vnetName, subnetName) + if err != nil { + t.Fatalf("Failed to create private endpoint: %v", err) + } + + privateEndpointConnectionName, err = waitForBatchPrivateEndpointConnection(ctx, pecClient, integrationTestResourceGroup, batchAccountName) + if err != nil { + t.Fatalf("Failed waiting for private endpoint connection: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetPrivateEndpointConnection", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving batch private endpoint connection %s in account %s", privateEndpointConnectionName, batchAccountName) + + pecWrapper := manual.NewBatchPrivateEndpointConnection( + clients.NewBatchPrivateEndpointConnectionClient(pecClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := pecWrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(pecWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(batchAccountName, privateEndpointConnectionName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.BatchBatchPrivateEndpointConnection.String() { + t.Errorf("Expected type %s, got %s", azureshared.BatchBatchPrivateEndpointConnection, sdpItem.GetType()) + } + + expectedUniqueAttr := shared.CompositeLookupKey(batchAccountName, privateEndpointConnectionName) + if sdpItem.UniqueAttributeValue() != expectedUniqueAttr { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttr, sdpItem.UniqueAttributeValue()) + } + + log.Printf("Successfully retrieved private endpoint connection %s", privateEndpointConnectionName) + }) + + t.Run("SearchPrivateEndpointConnections", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching private endpoint connections in batch account %s", batchAccountName) + + pecWrapper := manual.NewBatchPrivateEndpointConnection( + clients.NewBatchPrivateEndpointConnectionClient(pecClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := pecWrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(pecWrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, batchAccountName, true) + if err != nil { + t.Fatalf("Failed to search private endpoint connections: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one private endpoint connection, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == shared.CompositeLookupKey(batchAccountName, privateEndpointConnectionName) { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find private endpoint connection %s in the search results", privateEndpointConnectionName) + } + + log.Printf("Found %d private endpoint connections in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for private endpoint connection %s", privateEndpointConnectionName) + + pecWrapper := manual.NewBatchPrivateEndpointConnection( + clients.NewBatchPrivateEndpointConnectionClient(pecClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := pecWrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(pecWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(batchAccountName, privateEndpointConnectionName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Error("LinkedItemQuery has empty Type") + } + if query.GetMethod() != sdp.QueryMethod_GET && query.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("LinkedItemQuery has invalid Method: %v", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("LinkedItemQuery has empty Query") + } + if query.GetScope() == "" { + t.Error("LinkedItemQuery has empty Scope") + } + } + + var hasBatchAccountLink bool + for _, liq := range linkedQueries { + if liq.GetQuery().GetType() == azureshared.BatchBatchAccount.String() { + hasBatchAccountLink = true + if liq.GetQuery().GetQuery() != batchAccountName { + t.Errorf("Expected linked query to batch account %s, got %s", batchAccountName, liq.GetQuery().GetQuery()) + } + break + } + } + + if !hasBatchAccountLink { + t.Error("Expected linked query to batch account, but didn't find one") + } + + log.Printf("Verified %d linked item queries for private endpoint connection %s", len(linkedQueries), privateEndpointConnectionName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + pecWrapper := manual.NewBatchPrivateEndpointConnection( + clients.NewBatchPrivateEndpointConnectionClient(pecClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := pecWrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(pecWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(batchAccountName, privateEndpointConnectionName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.BatchBatchPrivateEndpointConnection.String() { + t.Errorf("Expected type %s, got %s", azureshared.BatchBatchPrivateEndpointConnection, sdpItem.GetType()) + } + + expectedScope := subscriptionID + "." + integrationTestResourceGroup + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Errorf("Item validation failed: %v", err) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deletePrivateEndpointForBatch(ctx, peClient, integrationTestResourceGroup, peName) + if err != nil { + t.Errorf("Failed to delete private endpoint: %v", err) + } + + err = deleteBatchAccount(ctx, batchClient, integrationTestResourceGroup, batchAccountName) + if err != nil { + t.Errorf("Failed to delete batch account: %v", err) + } + + err = deleteSubnetForBatchPEC(ctx, subnetClient, integrationTestResourceGroup, vnetName, subnetName) + if err != nil { + t.Errorf("Failed to delete subnet: %v", err) + } + + err = deleteVNetForBatchPEC(ctx, vnetClient, integrationTestResourceGroup, vnetName) + if err != nil { + t.Errorf("Failed to delete VNet: %v", err) + } + + err = deleteStorageAccount(ctx, saClient, integrationTestResourceGroup, storageAccountName) + if err != nil { + t.Errorf("Failed to delete storage account: %v", err) + } + }) +} + +func createBatchAccountWithPrivateEndpointPolicy(ctx context.Context, client *armbatch.AccountClient, resourceGroupName, accountName, location, storageAccountID string) error { + _, err := client.Get(ctx, resourceGroupName, accountName, nil) + if err == nil { + log.Printf("Batch account %s already exists, skipping creation", accountName) + return nil + } + + publicNetworkDisabled := armbatch.PublicNetworkAccessTypeDisabled + + poller, err := client.BeginCreate(ctx, resourceGroupName, accountName, armbatch.AccountCreateParameters{ + Location: new(location), + Properties: &armbatch.AccountCreateProperties{ + AutoStorage: &armbatch.AutoStorageBaseProperties{ + StorageAccountID: new(storageAccountID), + }, + PoolAllocationMode: new(armbatch.PoolAllocationModeBatchService), + PublicNetworkAccess: &publicNetworkDisabled, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("batch-private-endpoint-connection"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + if respErr.StatusCode == http.StatusConflict { + log.Printf("Batch account %s already exists (conflict), skipping creation", accountName) + return nil + } + if respErr.ErrorCode == "SubscriptionQuotaExceeded" { + return fmt.Errorf("%w: %s", errBatchQuotaExceeded, respErr.Error()) + } + } + return fmt.Errorf("failed to begin creating batch account: %w", err) + } + + resp, err := poller.PollUntilDone(ctx, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.ErrorCode == "SubscriptionQuotaExceeded" { + return fmt.Errorf("%w: %s", errBatchQuotaExceeded, respErr.Error()) + } + return fmt.Errorf("failed to create batch account: %w", err) + } + + if resp.Properties == nil || resp.Properties.ProvisioningState == nil { + return fmt.Errorf("batch account created but provisioning state is unknown") + } + + provisioningState := *resp.Properties.ProvisioningState + if provisioningState != armbatch.ProvisioningStateSucceeded { + return fmt.Errorf("batch account provisioning state is %s, expected %s", provisioningState, armbatch.ProvisioningStateSucceeded) + } + + log.Printf("Batch account %s created successfully with private endpoint support", accountName) + return nil +} + +func createVNetForBatchPEC(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName, location string) error { + _, err := client.Get(ctx, resourceGroupName, vnetName, nil) + if err == nil { + log.Printf("VNet %s already exists, skipping creation", vnetName) + return nil + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, vnetName, armnetwork.VirtualNetwork{ + Location: new(location), + Properties: &armnetwork.VirtualNetworkPropertiesFormat{ + AddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{new("10.0.0.0/16")}, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("VNet %s already exists (conflict), skipping creation", vnetName) + return nil + } + return fmt.Errorf("failed to begin creating VNet: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create VNet: %w", err) + } + + log.Printf("VNet %s created successfully", vnetName) + return nil +} + +func createSubnetForBatchPEC(ctx context.Context, client *armnetwork.SubnetsClient, resourceGroupName, vnetName, subnetName string) error { + _, err := client.Get(ctx, resourceGroupName, vnetName, subnetName, nil) + if err == nil { + log.Printf("Subnet %s already exists, skipping creation", subnetName) + return nil + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, vnetName, subnetName, armnetwork.Subnet{ + Properties: &armnetwork.SubnetPropertiesFormat{ + AddressPrefix: new("10.0.1.0/24"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("Subnet %s already exists (conflict), skipping creation", subnetName) + return nil + } + return fmt.Errorf("failed to begin creating subnet: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create subnet: %w", err) + } + + log.Printf("Subnet %s created successfully", subnetName) + return nil +} + +func createPrivateEndpointForBatch(ctx context.Context, client *armnetwork.PrivateEndpointsClient, resourceGroupName, peName, location, batchAccountID, vnetName, subnetName string) error { + _, err := client.Get(ctx, resourceGroupName, peName, nil) + if err == nil { + log.Printf("Private endpoint %s already exists, skipping creation", peName) + return nil + } + + subnetID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/virtualNetworks/%s/subnets/%s", + os.Getenv("AZURE_SUBSCRIPTION_ID"), resourceGroupName, vnetName, subnetName) + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, peName, armnetwork.PrivateEndpoint{ + Location: new(location), + Properties: &armnetwork.PrivateEndpointProperties{ + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PrivateLinkServiceConnections: []*armnetwork.PrivateLinkServiceConnection{ + { + Name: new(peName + "-connection"), + Properties: &armnetwork.PrivateLinkServiceConnectionProperties{ + PrivateLinkServiceID: new(batchAccountID), + GroupIDs: []*string{new("batchAccount")}, + }, + }, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("Private endpoint %s already exists (conflict), skipping creation", peName) + return nil + } + return fmt.Errorf("failed to begin creating private endpoint: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create private endpoint: %w", err) + } + + log.Printf("Private endpoint %s created successfully", peName) + return nil +} + +func waitForBatchPrivateEndpointConnection(ctx context.Context, client *armbatch.PrivateEndpointConnectionClient, resourceGroupName, accountName string) (string, error) { + maxAttempts := 30 + pollInterval := 10 * time.Second + + log.Printf("Waiting for private endpoint connection on batch account %s...", accountName) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + pager := client.NewListByBatchAccountPager(resourceGroupName, accountName, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + log.Printf("Error listing private endpoint connections (attempt %d/%d): %v", attempt, maxAttempts, err) + break + } + for _, conn := range page.Value { + if conn != nil && conn.Name != nil { + log.Printf("Found private endpoint connection: %s", *conn.Name) + return *conn.Name, nil + } + } + } + log.Printf("No private endpoint connections found yet (attempt %d/%d), waiting...", attempt, maxAttempts) + time.Sleep(pollInterval) + } + + return "", fmt.Errorf("timeout waiting for private endpoint connection on batch account %s", accountName) +} + +func deletePrivateEndpointForBatch(ctx context.Context, client *armnetwork.PrivateEndpointsClient, resourceGroupName, peName string) error { + log.Printf("Deleting private endpoint %s...", peName) + + poller, err := client.BeginDelete(ctx, resourceGroupName, peName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Private endpoint %s not found, skipping deletion", peName) + return nil + } + return fmt.Errorf("failed to begin deleting private endpoint: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete private endpoint: %w", err) + } + + log.Printf("Private endpoint %s deleted successfully", peName) + return nil +} + +func deleteSubnetForBatchPEC(ctx context.Context, client *armnetwork.SubnetsClient, resourceGroupName, vnetName, subnetName string) error { + log.Printf("Deleting subnet %s...", subnetName) + + poller, err := client.BeginDelete(ctx, resourceGroupName, vnetName, subnetName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Subnet %s not found, skipping deletion", subnetName) + return nil + } + return fmt.Errorf("failed to begin deleting subnet: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete subnet: %w", err) + } + + log.Printf("Subnet %s deleted successfully", subnetName) + return nil +} + +func deleteVNetForBatchPEC(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName string) error { + log.Printf("Deleting VNet %s...", vnetName) + + poller, err := client.BeginDelete(ctx, resourceGroupName, vnetName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("VNet %s not found, skipping deletion", vnetName) + return nil + } + return fmt.Errorf("failed to begin deleting VNet: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete VNet: %w", err) + } + + log.Printf("VNet %s deleted successfully", vnetName) + return nil +} diff --git a/sources/azure/integration-tests/dbforpostgresql-flexible-server-administrator_test.go b/sources/azure/integration-tests/dbforpostgresql-flexible-server-administrator_test.go new file mode 100644 index 00000000..66356ce6 --- /dev/null +++ b/sources/azure/integration-tests/dbforpostgresql-flexible-server-administrator_test.go @@ -0,0 +1,474 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestPGAdminServerName = "ovm-integ-test-pg-admin" +) + +func TestDBforPostgreSQLFlexibleServerAdministratorIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + adminLogin := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD") + if adminLogin == "" || adminPassword == "" { + t.Skip("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN and AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD must be set for PostgreSQL tests") + } + + entraAdminObjectID := os.Getenv("AZURE_POSTGRESQL_ENTRA_ADMIN_OBJECT_ID") + entraAdminPrincipalName := os.Getenv("AZURE_POSTGRESQL_ENTRA_ADMIN_PRINCIPAL_NAME") + entraAdminTenantID := os.Getenv("AZURE_POSTGRESQL_ENTRA_ADMIN_TENANT_ID") + + if entraAdminObjectID == "" || entraAdminPrincipalName == "" || entraAdminTenantID == "" { + t.Skip("AZURE_POSTGRESQL_ENTRA_ADMIN_OBJECT_ID, AZURE_POSTGRESQL_ENTRA_ADMIN_PRINCIPAL_NAME, and AZURE_POSTGRESQL_ENTRA_ADMIN_TENANT_ID must be set for PostgreSQL Administrator tests") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + postgreSQLServerClient, err := armpostgresqlflexibleservers.NewServersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Flexible Servers client: %v", err) + } + + administratorsClient, err := armpostgresqlflexibleservers.NewAdministratorsMicrosoftEntraClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Administrators client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + pgServerName := generatePostgreSQLServerName(integrationTestPGAdminServerName) + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createPostgreSQLFlexibleServerWithMicrosoftEntraAuth(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Flexible Server: %v", err) + } + + err = waitForPostgreSQLServerAvailable(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName) + if err != nil { + t.Fatalf("Failed waiting for PostgreSQL server to be available: %v", err) + } + + err = createPostgreSQLAdministrator(ctx, administratorsClient, integrationTestResourceGroup, pgServerName, entraAdminObjectID, entraAdminPrincipalName, entraAdminTenantID) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Administrator: %v", err) + } + + err = waitForPostgreSQLAdministratorAvailable(ctx, administratorsClient, integrationTestResourceGroup, pgServerName, entraAdminObjectID) + if err != nil { + t.Fatalf("Failed waiting for PostgreSQL Administrator to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetPostgreSQLFlexibleServerAdministrator", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator( + clients.NewDBforPostgreSQLFlexibleServerAdministratorClient(administratorsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, entraAdminObjectID) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerAdministrator.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerAdministrator, sdpItem.GetType()) + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + if uniqueAttrKey != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", uniqueAttrKey) + } + + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(pgServerName, entraAdminObjectID) + if uniqueAttrValue != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, uniqueAttrValue) + } + + if sdpItem.GetScope() != fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) { + t.Errorf("Expected scope %s.%s, got %s", subscriptionID, integrationTestResourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Successfully retrieved administrator %s", entraAdminObjectID) + }) + + t.Run("SearchPostgreSQLFlexibleServerAdministrators", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator( + clients.NewDBforPostgreSQLFlexibleServerAdministratorClient(administratorsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, pgServerName, true) + if err != nil { + t.Fatalf("Failed to search administrators: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one administrator, got %d", len(sdpItems)) + } + + var foundAdmin bool + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + if item.GetType() != azureshared.DBforPostgreSQLFlexibleServerAdministrator.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerAdministrator, item.GetType()) + } + + expectedUniqueValue := shared.CompositeLookupKey(pgServerName, entraAdminObjectID) + if item.UniqueAttributeValue() == expectedUniqueValue { + foundAdmin = true + } + } + + if !foundAdmin { + t.Errorf("Expected to find administrator %s in search results", entraAdminObjectID) + } + + log.Printf("Found %d administrators in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator( + clients.NewDBforPostgreSQLFlexibleServerAdministratorClient(administratorsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, entraAdminObjectID) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + for _, liq := range linkedQueries { + if liq.GetQuery().GetType() == "" { + t.Error("Expected linked query Type to be non-empty") + } + if liq.GetQuery().GetMethod() != sdp.QueryMethod_GET && liq.GetQuery().GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Expected linked query Method to be GET or SEARCH, got %s", liq.GetQuery().GetMethod()) + } + if liq.GetQuery().GetQuery() == "" { + t.Error("Expected linked query Query to be non-empty") + } + if liq.GetQuery().GetScope() == "" { + t.Error("Expected linked query Scope to be non-empty") + } + } + + var hasServerLink bool + for _, liq := range linkedQueries { + if liq.GetQuery().GetType() == azureshared.DBforPostgreSQLFlexibleServer.String() { + hasServerLink = true + if liq.GetQuery().GetQuery() != pgServerName { + t.Errorf("Expected linked query to server %s, got %s", pgServerName, liq.GetQuery().GetQuery()) + } + if liq.GetQuery().GetMethod() != sdp.QueryMethod_GET { + t.Errorf("Expected linked query method GET, got %s", liq.GetQuery().GetMethod()) + } + if liq.GetQuery().GetScope() != scope { + t.Errorf("Expected linked query scope %s, got %s", scope, liq.GetQuery().GetScope()) + } + break + } + } + + if !hasServerLink { + t.Error("Expected linked query to PostgreSQL Flexible Server, but didn't find one") + } + + log.Printf("Verified %d linked item queries for administrator %s", len(linkedQueries), entraAdminObjectID) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator( + clients.NewDBforPostgreSQLFlexibleServerAdministratorClient(administratorsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, entraAdminObjectID) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerAdministrator.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerAdministrator, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deletePostgreSQLAdministrator(ctx, administratorsClient, integrationTestResourceGroup, pgServerName, entraAdminObjectID) + if err != nil { + log.Printf("Warning: Failed to delete PostgreSQL Administrator: %v", err) + } + + err = deletePostgreSQLFlexibleServer(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName) + if err != nil { + t.Fatalf("Failed to delete PostgreSQL Flexible Server: %v", err) + } + }) +} + +// createPostgreSQLFlexibleServerWithMicrosoftEntraAuth creates a PostgreSQL Flexible Server with Microsoft Entra authentication enabled +func createPostgreSQLFlexibleServerWithMicrosoftEntraAuth(ctx context.Context, client *armpostgresqlflexibleservers.ServersClient, resourceGroupName, serverName, location string) error { + _, err := client.Get(ctx, resourceGroupName, serverName, nil) + if err == nil { + log.Printf("PostgreSQL Flexible Server %s already exists, skipping creation", serverName) + return nil + } + + adminLogin := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD") + + if adminLogin == "" || adminPassword == "" { + return fmt.Errorf("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN and AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD environment variables must be set for integration tests") + } + + opCtx, cancel := context.WithTimeout(ctx, 25*time.Minute) + defer cancel() + + poller, err := client.BeginCreateOrUpdate(opCtx, resourceGroupName, serverName, armpostgresqlflexibleservers.Server{ + Location: new(location), + Properties: &armpostgresqlflexibleservers.ServerProperties{ + AdministratorLogin: new(adminLogin), + AdministratorLoginPassword: new(adminPassword), + Version: new(armpostgresqlflexibleservers.PostgresMajorVersion("14")), + Storage: &armpostgresqlflexibleservers.Storage{StorageSizeGB: new(int32(32))}, + Backup: &armpostgresqlflexibleservers.Backup{BackupRetentionDays: new(int32(7)), GeoRedundantBackup: new(armpostgresqlflexibleservers.GeographicallyRedundantBackupDisabled)}, + Network: &armpostgresqlflexibleservers.Network{PublicNetworkAccess: new(armpostgresqlflexibleservers.ServerPublicNetworkAccessStateEnabled)}, + HighAvailability: nil, + AuthConfig: &armpostgresqlflexibleservers.AuthConfig{ + ActiveDirectoryAuth: new(armpostgresqlflexibleservers.MicrosoftEntraAuthEnabled), + PasswordAuth: new(armpostgresqlflexibleservers.PasswordBasedAuthEnabled), + }, + }, + SKU: &armpostgresqlflexibleservers.SKU{ + Name: new("Standard_B1ms"), + Tier: new(armpostgresqlflexibleservers.SKUTierBurstable), + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("dbforpostgresql-administrator"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("PostgreSQL Flexible Server %s already exists, skipping creation", serverName) + return nil + } + return fmt.Errorf("failed to begin creating PostgreSQL Flexible Server: %w", err) + } + + resp, err := poller.PollUntilDone(opCtx, nil) + if err != nil { + return fmt.Errorf("failed to create PostgreSQL Flexible Server: %w", err) + } + + if resp.Properties == nil { + return fmt.Errorf("PostgreSQL Flexible Server created but properties are nil") + } + + log.Printf("PostgreSQL Flexible Server %s created successfully with Microsoft Entra authentication enabled", serverName) + return nil +} + +// createPostgreSQLAdministrator creates a Microsoft Entra administrator for a PostgreSQL Flexible Server +func createPostgreSQLAdministrator(ctx context.Context, client *armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClient, resourceGroupName, serverName, objectID, principalName, tenantID string) error { + _, err := client.Get(ctx, resourceGroupName, serverName, objectID, nil) + if err == nil { + log.Printf("PostgreSQL Administrator %s already exists on server %s, skipping creation", objectID, serverName) + return nil + } + + opCtx, cancel := context.WithTimeout(ctx, 10*time.Minute) + defer cancel() + + principalType := armpostgresqlflexibleservers.PrincipalTypeServicePrincipal + + poller, err := client.BeginCreateOrUpdate(opCtx, resourceGroupName, serverName, objectID, armpostgresqlflexibleservers.AdministratorMicrosoftEntraAdd{ + Properties: &armpostgresqlflexibleservers.AdministratorMicrosoftEntraPropertiesForAdd{ + PrincipalName: new(principalName), + PrincipalType: &principalType, + TenantID: new(tenantID), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("PostgreSQL Administrator %s already exists on server %s, skipping creation", objectID, serverName) + return nil + } + return fmt.Errorf("failed to begin creating PostgreSQL Administrator: %w", err) + } + + _, err = poller.PollUntilDone(opCtx, nil) + if err != nil { + return fmt.Errorf("failed to create PostgreSQL Administrator: %w", err) + } + + log.Printf("PostgreSQL Administrator %s created successfully on server %s", objectID, serverName) + return nil +} + +// waitForPostgreSQLAdministratorAvailable waits for a PostgreSQL Administrator to be fully available +func waitForPostgreSQLAdministratorAvailable(ctx context.Context, client *armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClient, resourceGroupName, serverName, objectID string) error { + maxAttempts := 30 + pollInterval := 10 * time.Second + + log.Printf("Waiting for PostgreSQL Administrator %s to be available on server %s...", objectID, serverName) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + _, err := client.Get(ctx, resourceGroupName, serverName, objectID, nil) + if err == nil { + log.Printf("PostgreSQL Administrator %s is available on server %s", objectID, serverName) + return nil + } + + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("PostgreSQL Administrator %s not yet available (attempt %d/%d), waiting %v...", objectID, attempt, maxAttempts, pollInterval) + time.Sleep(pollInterval) + continue + } + + return fmt.Errorf("error checking PostgreSQL Administrator availability: %w", err) + } + + return fmt.Errorf("timeout waiting for PostgreSQL Administrator %s to be available on server %s", objectID, serverName) +} + +// deletePostgreSQLAdministrator deletes a Microsoft Entra administrator from a PostgreSQL Flexible Server +func deletePostgreSQLAdministrator(ctx context.Context, client *armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClient, resourceGroupName, serverName, objectID string) error { + opCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + poller, err := client.BeginDelete(opCtx, resourceGroupName, serverName, objectID, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("PostgreSQL Administrator %s already deleted or does not exist on server %s", objectID, serverName) + return nil + } + return fmt.Errorf("failed to begin deleting PostgreSQL Administrator: %w", err) + } + + _, err = poller.PollUntilDone(opCtx, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("PostgreSQL Administrator %s already deleted", objectID) + return nil + } + return fmt.Errorf("failed to delete PostgreSQL Administrator: %w", err) + } + + log.Printf("PostgreSQL Administrator %s deleted successfully from server %s", objectID, serverName) + return nil +} diff --git a/sources/azure/integration-tests/dbforpostgresql-flexible-server-virtual-endpoint_test.go b/sources/azure/integration-tests/dbforpostgresql-flexible-server-virtual-endpoint_test.go new file mode 100644 index 00000000..d4483cb9 --- /dev/null +++ b/sources/azure/integration-tests/dbforpostgresql-flexible-server-virtual-endpoint_test.go @@ -0,0 +1,428 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestPGVirtualEndpointServerName = "ovm-integ-test-pg-vep" + integrationTestPGVirtualEndpointName = "ovm-integ-test-vep" +) + +func TestDBforPostgreSQLFlexibleServerVirtualEndpointIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + adminLogin := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD") + if adminLogin == "" || adminPassword == "" { + t.Skip("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN and AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD must be set for PostgreSQL tests") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + postgreSQLServerClient, err := armpostgresqlflexibleservers.NewServersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Flexible Servers client: %v", err) + } + + virtualEndpointsClient, err := armpostgresqlflexibleservers.NewVirtualEndpointsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Virtual Endpoints client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + pgServerName := generatePostgreSQLServerName(integrationTestPGVirtualEndpointServerName) + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createPostgreSQLFlexibleServerForVirtualEndpoint(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create PostgreSQL Flexible Server: %v", err) + } + + err = waitForPostgreSQLServerAvailable(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName) + if err != nil { + t.Fatalf("Failed waiting for PostgreSQL server to be available: %v", err) + } + + err = createVirtualEndpoint(ctx, virtualEndpointsClient, integrationTestResourceGroup, pgServerName, integrationTestPGVirtualEndpointName) + if err != nil { + t.Fatalf("Failed to create virtual endpoint: %v", err) + } + + err = waitForVirtualEndpointAvailable(ctx, virtualEndpointsClient, integrationTestResourceGroup, pgServerName, integrationTestPGVirtualEndpointName) + if err != nil { + t.Fatalf("Failed waiting for virtual endpoint to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetVirtualEndpoint", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint( + clients.NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(virtualEndpointsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, integrationTestPGVirtualEndpointName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint, sdpItem.GetType()) + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + if uniqueAttrKey != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", uniqueAttrKey) + } + + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(pgServerName, integrationTestPGVirtualEndpointName) + if uniqueAttrValue != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, uniqueAttrValue) + } + + if sdpItem.GetScope() != fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) { + t.Errorf("Expected scope %s.%s, got %s", subscriptionID, integrationTestResourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Successfully retrieved virtual endpoint %s", integrationTestPGVirtualEndpointName) + }) + + t.Run("SearchVirtualEndpoints", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint( + clients.NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(virtualEndpointsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, pgServerName, true) + if err != nil { + t.Fatalf("Failed to search virtual endpoints: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one virtual endpoint, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil { + expectedValue := shared.CompositeLookupKey(pgServerName, integrationTestPGVirtualEndpointName) + if v == expectedValue { + found = true + break + } + } + } + + if !found { + t.Fatalf("Expected to find virtual endpoint %s in the search results", integrationTestPGVirtualEndpointName) + } + + log.Printf("Found %d virtual endpoints in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint( + clients.NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(virtualEndpointsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, integrationTestPGVirtualEndpointName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + var hasServerLink bool + for _, liq := range linkedQueries { + q := liq.GetQuery() + if q.GetType() == azureshared.DBforPostgreSQLFlexibleServer.String() { + hasServerLink = true + if q.GetMethod() != sdp.QueryMethod_GET { + t.Errorf("Expected linked query method GET, got %s", q.GetMethod()) + } + if q.GetScope() != scope { + t.Errorf("Expected linked query scope %s, got %s", scope, q.GetScope()) + } + break + } + } + + if !hasServerLink { + t.Error("Expected linked query to PostgreSQL Flexible Server, but didn't find one") + } + + log.Printf("Verified %d linked item queries for virtual endpoint %s", len(linkedQueries), integrationTestPGVirtualEndpointName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint( + clients.NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(virtualEndpointsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(pgServerName, integrationTestPGVirtualEndpointName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deleteVirtualEndpoint(ctx, virtualEndpointsClient, integrationTestResourceGroup, pgServerName, integrationTestPGVirtualEndpointName) + if err != nil { + log.Printf("Warning: failed to delete virtual endpoint: %v", err) + } + + err = deletePostgreSQLFlexibleServer(ctx, postgreSQLServerClient, integrationTestResourceGroup, pgServerName) + if err != nil { + t.Fatalf("Failed to delete PostgreSQL Flexible Server: %v", err) + } + }) +} + +func createPostgreSQLFlexibleServerForVirtualEndpoint(ctx context.Context, client *armpostgresqlflexibleservers.ServersClient, resourceGroupName, serverName, location string) error { + _, err := client.Get(ctx, resourceGroupName, serverName, nil) + if err == nil { + log.Printf("PostgreSQL Flexible Server %s already exists, skipping creation", serverName) + return nil + } + + adminLogin := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD") + if adminLogin == "" || adminPassword == "" { + return fmt.Errorf("AZURE_POSTGRESQL_SERVER_ADMIN_LOGIN and AZURE_POSTGRESQL_SERVER_ADMIN_PASSWORD must be set") + } + + opCtx, cancel := context.WithTimeout(ctx, 25*time.Minute) + defer cancel() + + poller, err := client.BeginCreateOrUpdate(opCtx, resourceGroupName, serverName, armpostgresqlflexibleservers.Server{ + Location: new(location), + Properties: &armpostgresqlflexibleservers.ServerProperties{ + AdministratorLogin: new(adminLogin), + AdministratorLoginPassword: new(adminPassword), + Version: new(armpostgresqlflexibleservers.PostgresMajorVersion("14")), + Storage: &armpostgresqlflexibleservers.Storage{StorageSizeGB: new(int32(32))}, + Backup: &armpostgresqlflexibleservers.Backup{BackupRetentionDays: new(int32(7)), GeoRedundantBackup: new(armpostgresqlflexibleservers.GeographicallyRedundantBackupDisabled)}, + Network: &armpostgresqlflexibleservers.Network{PublicNetworkAccess: new(armpostgresqlflexibleservers.ServerPublicNetworkAccessStateEnabled)}, + }, + SKU: &armpostgresqlflexibleservers.SKU{ + Name: new("Standard_D2s_v3"), + Tier: new(armpostgresqlflexibleservers.SKUTierGeneralPurpose), + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("dbforpostgresql-virtual-endpoint"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("PostgreSQL Flexible Server %s already exists, skipping creation", serverName) + return nil + } + return fmt.Errorf("failed to begin creating PostgreSQL Flexible Server: %w", err) + } + + _, err = poller.PollUntilDone(opCtx, nil) + if err != nil { + return fmt.Errorf("failed to create PostgreSQL Flexible Server: %w", err) + } + + log.Printf("PostgreSQL Flexible Server %s (GeneralPurpose) created successfully", serverName) + return nil +} + +func createVirtualEndpoint(ctx context.Context, client *armpostgresqlflexibleservers.VirtualEndpointsClient, resourceGroupName, serverName, virtualEndpointName string) error { + _, err := client.Get(ctx, resourceGroupName, serverName, virtualEndpointName, nil) + if err == nil { + log.Printf("Virtual endpoint %s already exists, skipping creation", virtualEndpointName) + return nil + } + + opCtx, cancel := context.WithTimeout(ctx, 15*time.Minute) + defer cancel() + + endpointType := armpostgresqlflexibleservers.VirtualEndpointTypeReadWrite + poller, err := client.BeginCreate(opCtx, resourceGroupName, serverName, virtualEndpointName, armpostgresqlflexibleservers.VirtualEndpoint{ + Properties: &armpostgresqlflexibleservers.VirtualEndpointResourceProperties{ + EndpointType: &endpointType, + Members: []*string{new(serverName)}, + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, serverName, virtualEndpointName, nil); getErr == nil { + log.Printf("Virtual endpoint %s already exists (conflict), skipping", virtualEndpointName) + return nil + } + return fmt.Errorf("virtual endpoint %s conflict but not retrievable: %w", virtualEndpointName, err) + } + return fmt.Errorf("failed to begin creating virtual endpoint: %w", err) + } + + _, err = poller.PollUntilDone(opCtx, nil) + if err != nil { + return fmt.Errorf("failed to create virtual endpoint: %w", err) + } + + log.Printf("Virtual endpoint %s created successfully", virtualEndpointName) + return nil +} + +func waitForVirtualEndpointAvailable(ctx context.Context, client *armpostgresqlflexibleservers.VirtualEndpointsClient, resourceGroupName, serverName, virtualEndpointName string) error { + maxAttempts := 30 + pollInterval := 10 * time.Second + + for attempt := 1; attempt <= maxAttempts; attempt++ { + _, err := client.Get(ctx, resourceGroupName, serverName, virtualEndpointName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Virtual endpoint %s not yet available (attempt %d/%d), waiting...", virtualEndpointName, attempt, maxAttempts) + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking virtual endpoint availability: %w", err) + } + + log.Printf("Virtual endpoint %s is available", virtualEndpointName) + return nil + } + + return fmt.Errorf("timeout waiting for virtual endpoint %s to be available", virtualEndpointName) +} + +func deleteVirtualEndpoint(ctx context.Context, client *armpostgresqlflexibleservers.VirtualEndpointsClient, resourceGroupName, serverName, virtualEndpointName string) error { + _, err := client.Get(ctx, resourceGroupName, serverName, virtualEndpointName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Virtual endpoint %s does not exist, skipping deletion", virtualEndpointName) + return nil + } + return fmt.Errorf("error checking virtual endpoint existence: %w", err) + } + + poller, err := client.BeginDelete(ctx, resourceGroupName, serverName, virtualEndpointName, nil) + if err != nil { + return fmt.Errorf("failed to begin deleting virtual endpoint: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete virtual endpoint: %w", err) + } + + log.Printf("Virtual endpoint %s deleted successfully", virtualEndpointName) + return nil +} diff --git a/sources/azure/integration-tests/elastic-san-volume_test.go b/sources/azure/integration-tests/elastic-san-volume_test.go new file mode 100644 index 00000000..478b2fcc --- /dev/null +++ b/sources/azure/integration-tests/elastic-san-volume_test.go @@ -0,0 +1,590 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestElasticSanName = "ovm-integ-test-esan" + integrationTestVolumeGroupName = "ovm-integ-test-vg" + integrationTestVolumeName = "ovm-integ-test-vol" + integrationTestElasticSanBaseTiB = int64(1) + integrationTestVolumeSizeGiB = int64(1) +) + +func TestElasticSanVolumeIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + // Create Azure SDK clients + esClient, err := armelasticsan.NewElasticSansClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Elastic SAN client: %v", err) + } + + vgClient, err := armelasticsan.NewVolumeGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Volume Groups client: %v", err) + } + + volClient, err := armelasticsan.NewVolumesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Volumes client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // Create Elastic SAN + err = createElasticSan(ctx, esClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestLocation, integrationTestElasticSanBaseTiB) + if err != nil { + t.Fatalf("Failed to create Elastic SAN: %v", err) + } + + // Wait for Elastic SAN to be available + err = waitForElasticSanAvailable(ctx, esClient, integrationTestResourceGroup, integrationTestElasticSanName) + if err != nil { + t.Fatalf("Failed waiting for Elastic SAN to be available: %v", err) + } + + // Create Volume Group + err = createVolumeGroup(ctx, vgClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName) + if err != nil { + t.Fatalf("Failed to create Volume Group: %v", err) + } + + // Wait for Volume Group to be available + err = waitForVolumeGroupAvailable(ctx, vgClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName) + if err != nil { + t.Fatalf("Failed waiting for Volume Group to be available: %v", err) + } + + // Create Volume + err = createVolume(ctx, volClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName, integrationTestVolumeSizeGiB) + if err != nil { + t.Fatalf("Failed to create Volume: %v", err) + } + + // Wait for Volume to be available + err = waitForVolumeAvailable(ctx, volClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + if err != nil { + t.Fatalf("Failed waiting for Volume to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetVolume", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving volume %s in volume group %s, elastic san %s, subscription %s, resource group %s", + integrationTestVolumeName, integrationTestVolumeGroupName, integrationTestElasticSanName, subscriptionID, integrationTestResourceGroup) + + volWrapper := manual.NewElasticSanVolume( + clients.NewElasticSanVolumeClient(volClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := volWrapper.Scopes()[0] + + volAdapter := sources.WrapperToAdapter(volWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + sdpItem, qErr := volAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUnique := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + if uniqueAttrValue != expectedUnique { + t.Errorf("Expected unique attribute value %s, got %s", expectedUnique, uniqueAttrValue) + } + + log.Printf("Successfully retrieved volume %s", integrationTestVolumeName) + }) + + t.Run("SearchVolumes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching volumes in volume group %s, elastic san %s", integrationTestVolumeGroupName, integrationTestElasticSanName) + + volWrapper := manual.NewElasticSanVolume( + clients.NewElasticSanVolumeClient(volClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := volWrapper.Scopes()[0] + + volAdapter := sources.WrapperToAdapter(volWrapper, sdpcache.NewNoOpCache()) + + searchable, ok := volAdapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + query := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName) + sdpItems, err := searchable.Search(ctx, scope, query, true) + if err != nil { + t.Fatalf("Failed to search volumes: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one volume, got %d", len(sdpItems)) + } + + var found bool + expectedUnique := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == expectedUnique { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find volume %s in the search results", integrationTestVolumeName) + } + + log.Printf("Found %d volumes in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for volume %s", integrationTestVolumeName) + + volWrapper := manual.NewElasticSanVolume( + clients.NewElasticSanVolumeClient(volClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := volWrapper.Scopes()[0] + + volAdapter := sources.WrapperToAdapter(volWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + sdpItem, qErr := volAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + var hasElasticSanLink bool + var hasVolumeGroupLink bool + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Error("Linked item query has empty Type") + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + + if query.GetType() == azureshared.ElasticSan.String() { + hasElasticSanLink = true + if query.GetQuery() != integrationTestElasticSanName { + t.Errorf("Expected linked query to elastic san %s, got %s", integrationTestElasticSanName, query.GetQuery()) + } + } + if query.GetType() == azureshared.ElasticSanVolumeGroup.String() { + hasVolumeGroupLink = true + expectedQuery := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName) + if query.GetQuery() != expectedQuery { + t.Errorf("Expected linked query to volume group %s, got %s", expectedQuery, query.GetQuery()) + } + } + } + + if !hasElasticSanLink { + t.Error("Expected linked query to elastic san, but didn't find one") + } + if !hasVolumeGroupLink { + t.Error("Expected linked query to volume group, but didn't find one") + } + + log.Printf("Verified %d linked item queries for volume %s", len(linkedQueries), integrationTestVolumeName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + volWrapper := manual.NewElasticSanVolume( + clients.NewElasticSanVolumeClient(volClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := volWrapper.Scopes()[0] + + volAdapter := sources.WrapperToAdapter(volWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + sdpItem, qErr := volAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify item type + if sdpItem.GetType() != azureshared.ElasticSanVolume.String() { + t.Errorf("Expected type %s, got %s", azureshared.ElasticSanVolume.String(), sdpItem.GetType()) + } + + // Verify scope + expectedScope := subscriptionID + "." + integrationTestResourceGroup + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + // Verify unique attribute + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + // Validate item + if err := sdpItem.Validate(); err != nil { + t.Errorf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for volume %s", integrationTestVolumeName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Delete Volume + err := deleteVolume(ctx, volClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName, integrationTestVolumeName) + if err != nil { + t.Logf("Failed to delete volume: %v", err) + } + + // Delete Volume Group + err = deleteVolumeGroup(ctx, vgClient, integrationTestResourceGroup, integrationTestElasticSanName, integrationTestVolumeGroupName) + if err != nil { + t.Logf("Failed to delete volume group: %v", err) + } + + // Delete Elastic SAN + err = deleteElasticSan(ctx, esClient, integrationTestResourceGroup, integrationTestElasticSanName) + if err != nil { + t.Logf("Failed to delete elastic san: %v", err) + } + + // Resource group is kept for faster subsequent test runs + }) +} + +// createElasticSan creates an Azure Elastic SAN (idempotent) +func createElasticSan(ctx context.Context, client *armelasticsan.ElasticSansClient, resourceGroupName, elasticSanName, location string, baseSizeTiB int64) error { + _, err := client.Get(ctx, resourceGroupName, elasticSanName, nil) + if err == nil { + log.Printf("Elastic SAN %s already exists, skipping creation", elasticSanName) + return nil + } + + extendedCapacitySizeTiB := int64(0) + poller, err := client.BeginCreate(ctx, resourceGroupName, elasticSanName, armelasticsan.ElasticSan{ + Location: &location, + Properties: &armelasticsan.Properties{ + BaseSizeTiB: &baseSizeTiB, + ExtendedCapacitySizeTiB: &extendedCapacitySizeTiB, + SKU: &armelasticsan.SKU{ + Name: new(armelasticsan.SKUNamePremiumLRS), + }, + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, elasticSanName, nil); getErr == nil { + log.Printf("Elastic SAN %s already exists (conflict), skipping creation", elasticSanName) + return nil + } + return fmt.Errorf("elastic san %s conflict but not retrievable: %w", elasticSanName, err) + } + return fmt.Errorf("failed to create elastic san: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create elastic san: %w", err) + } + + log.Printf("Elastic SAN %s created successfully", elasticSanName) + return nil +} + +// waitForElasticSanAvailable waits for Elastic SAN to be available +func waitForElasticSanAvailable(ctx context.Context, client *armelasticsan.ElasticSansClient, resourceGroupName, elasticSanName string) error { + maxAttempts := 30 + pollInterval := 10 * time.Second + maxNotFoundAttempts := 5 + notFoundCount := 0 + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, elasticSanName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + notFoundCount++ + if notFoundCount >= maxNotFoundAttempts { + return fmt.Errorf("elastic san %s not found after %d attempts", elasticSanName, notFoundCount) + } + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking elastic san: %w", err) + } + notFoundCount = 0 + if resp.Properties != nil && resp.Properties.ProvisioningState != nil && *resp.Properties.ProvisioningState == armelasticsan.ProvisioningStatesSucceeded { + return nil + } + time.Sleep(pollInterval) + } + return fmt.Errorf("timeout waiting for elastic san %s", elasticSanName) +} + +// createVolumeGroup creates an Azure Elastic SAN Volume Group (idempotent) +func createVolumeGroup(ctx context.Context, client *armelasticsan.VolumeGroupsClient, resourceGroupName, elasticSanName, volumeGroupName string) error { + _, err := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, nil) + if err == nil { + log.Printf("Volume Group %s already exists, skipping creation", volumeGroupName) + return nil + } + + poller, err := client.BeginCreate(ctx, resourceGroupName, elasticSanName, volumeGroupName, armelasticsan.VolumeGroup{ + Properties: &armelasticsan.VolumeGroupProperties{ + ProtocolType: new(armelasticsan.StorageTargetTypeIscsi), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, nil); getErr == nil { + log.Printf("Volume Group %s already exists (conflict), skipping creation", volumeGroupName) + return nil + } + return fmt.Errorf("volume group %s conflict but not retrievable: %w", volumeGroupName, err) + } + return fmt.Errorf("failed to create volume group: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create volume group: %w", err) + } + + log.Printf("Volume Group %s created successfully", volumeGroupName) + return nil +} + +// waitForVolumeGroupAvailable waits for Volume Group to be available +func waitForVolumeGroupAvailable(ctx context.Context, client *armelasticsan.VolumeGroupsClient, resourceGroupName, elasticSanName, volumeGroupName string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + maxNotFoundAttempts := 5 + notFoundCount := 0 + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + notFoundCount++ + if notFoundCount >= maxNotFoundAttempts { + return fmt.Errorf("volume group %s not found after %d attempts", volumeGroupName, notFoundCount) + } + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking volume group: %w", err) + } + notFoundCount = 0 + if resp.Properties != nil && resp.Properties.ProvisioningState != nil && *resp.Properties.ProvisioningState == armelasticsan.ProvisioningStatesSucceeded { + return nil + } + time.Sleep(pollInterval) + } + return fmt.Errorf("timeout waiting for volume group %s", volumeGroupName) +} + +// createVolume creates an Azure Elastic SAN Volume (idempotent) +func createVolume(ctx context.Context, client *armelasticsan.VolumesClient, resourceGroupName, elasticSanName, volumeGroupName, volumeName string, sizeGiB int64) error { + _, err := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, nil) + if err == nil { + log.Printf("Volume %s already exists, skipping creation", volumeName) + return nil + } + + poller, err := client.BeginCreate(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, armelasticsan.Volume{ + Properties: &armelasticsan.VolumeProperties{ + SizeGiB: &sizeGiB, + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, nil); getErr == nil { + log.Printf("Volume %s already exists (conflict), skipping creation", volumeName) + return nil + } + return fmt.Errorf("volume %s conflict but not retrievable: %w", volumeName, err) + } + return fmt.Errorf("failed to create volume: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create volume: %w", err) + } + + log.Printf("Volume %s created successfully", volumeName) + return nil +} + +// waitForVolumeAvailable waits for Volume to be available +func waitForVolumeAvailable(ctx context.Context, client *armelasticsan.VolumesClient, resourceGroupName, elasticSanName, volumeGroupName, volumeName string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + maxNotFoundAttempts := 5 + notFoundCount := 0 + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + notFoundCount++ + if notFoundCount >= maxNotFoundAttempts { + return fmt.Errorf("volume %s not found after %d attempts", volumeName, notFoundCount) + } + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking volume: %w", err) + } + notFoundCount = 0 + if resp.Properties != nil && resp.Properties.ProvisioningState != nil && *resp.Properties.ProvisioningState == armelasticsan.ProvisioningStatesSucceeded { + return nil + } + time.Sleep(pollInterval) + } + return fmt.Errorf("timeout waiting for volume %s", volumeName) +} + +// deleteVolume deletes an Azure Elastic SAN Volume +func deleteVolume(ctx context.Context, client *armelasticsan.VolumesClient, resourceGroupName, elasticSanName, volumeGroupName, volumeName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Volume %s not found, skipping deletion", volumeName) + return nil + } + return fmt.Errorf("failed to delete volume: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete volume: %w", err) + } + + log.Printf("Volume %s deleted successfully", volumeName) + return nil +} + +// deleteVolumeGroup deletes an Azure Elastic SAN Volume Group +func deleteVolumeGroup(ctx context.Context, client *armelasticsan.VolumeGroupsClient, resourceGroupName, elasticSanName, volumeGroupName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, elasticSanName, volumeGroupName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Volume Group %s not found, skipping deletion", volumeGroupName) + return nil + } + return fmt.Errorf("failed to delete volume group: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete volume group: %w", err) + } + + log.Printf("Volume Group %s deleted successfully", volumeGroupName) + return nil +} + +// deleteElasticSan deletes an Azure Elastic SAN +func deleteElasticSan(ctx context.Context, client *armelasticsan.ElasticSansClient, resourceGroupName, elasticSanName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, elasticSanName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Elastic SAN %s not found, skipping deletion", elasticSanName) + return nil + } + return fmt.Errorf("failed to delete elastic san: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete elastic san: %w", err) + } + + log.Printf("Elastic SAN %s deleted successfully", elasticSanName) + return nil +} diff --git a/sources/azure/integration-tests/managedidentity-federated-identity-credential_test.go b/sources/azure/integration-tests/managedidentity-federated-identity-credential_test.go new file mode 100644 index 00000000..d8a6186e --- /dev/null +++ b/sources/azure/integration-tests/managedidentity-federated-identity-credential_test.go @@ -0,0 +1,334 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestIdentityName = "ovm-integ-test-identity" + integrationTestFedCredName = "ovm-integ-test-fed-cred" + integrationTestFedCredIssuer = "https://token.actions.githubusercontent.com" + integrationTestFedCredSubject = "repo:overmindtech/test-repo:ref:refs/heads/main" + integrationTestFedCredAudience = "api://AzureADTokenExchange" +) + +func TestManagedIdentityFederatedIdentityCredentialIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + uaiClient, err := armmsi.NewUserAssignedIdentitiesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create User Assigned Identities client: %v", err) + } + + ficClient, err := armmsi.NewFederatedIdentityCredentialsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Federated Identity Credentials client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createUserAssignedIdentity(ctx, uaiClient, integrationTestResourceGroup, integrationTestIdentityName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create user assigned identity: %v", err) + } + + err = waitForUserAssignedIdentityAvailable(ctx, uaiClient, integrationTestResourceGroup, integrationTestIdentityName) + if err != nil { + t.Fatalf("Failed waiting for user assigned identity to be available: %v", err) + } + + err = createFederatedIdentityCredential(ctx, ficClient, integrationTestResourceGroup, integrationTestIdentityName, integrationTestFedCredName) + if err != nil { + t.Fatalf("Failed to create federated identity credential: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetFederatedIdentityCredential", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving federated identity credential %s for identity %s, subscription %s, resource group %s", + integrationTestFedCredName, integrationTestIdentityName, subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential( + clients.NewFederatedIdentityCredentialsClient(ficClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIdentityName, integrationTestFedCredName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueValue := shared.CompositeLookupKey(integrationTestIdentityName, integrationTestFedCredName) + if uniqueAttrValue != expectedUniqueValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueValue, uniqueAttrValue) + } + + log.Printf("Successfully retrieved federated identity credential %s", integrationTestFedCredName) + }) + + t.Run("SearchFederatedIdentityCredentials", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching federated identity credentials for identity %s", integrationTestIdentityName) + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential( + clients.NewFederatedIdentityCredentialsClient(ficClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, integrationTestIdentityName, true) + if err != nil { + t.Fatalf("Failed to search federated identity credentials: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one federated identity credential, got %d", len(sdpItems)) + } + + var found bool + expectedValue := shared.CompositeLookupKey(integrationTestIdentityName, integrationTestFedCredName) + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == expectedValue { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find credential %s in the search results", integrationTestFedCredName) + } + + log.Printf("Found %d federated identity credentials in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for federated identity credential %s", integrationTestFedCredName) + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential( + clients.NewFederatedIdentityCredentialsClient(ficClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIdentityName, integrationTestFedCredName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + var hasIdentityLink bool + var hasDNSLink bool + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Error("Linked query has empty Type") + } + if query.GetQuery() == "" { + t.Error("Linked query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked query has empty Scope") + } + + if query.GetType() == azureshared.ManagedIdentityUserAssignedIdentity.String() { + hasIdentityLink = true + if query.GetQuery() != integrationTestIdentityName { + t.Errorf("Expected linked query to identity %s, got %s", integrationTestIdentityName, query.GetQuery()) + } + } + + if query.GetType() == "dns" { + hasDNSLink = true + if query.GetQuery() != "token.actions.githubusercontent.com" { + t.Errorf("Expected DNS query to token.actions.githubusercontent.com, got %s", query.GetQuery()) + } + if query.GetScope() != "global" { + t.Errorf("Expected DNS query scope to be global, got %s", query.GetScope()) + } + } + } + + if !hasIdentityLink { + t.Error("Expected linked query to user assigned identity, but didn't find one") + } + + if !hasDNSLink { + t.Error("Expected linked query to DNS (from Issuer URL), but didn't find one") + } + + log.Printf("Verified %d linked item queries for federated identity credential %s", len(linkedQueries), integrationTestFedCredName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential( + clients.NewFederatedIdentityCredentialsClient(ficClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIdentityName, integrationTestFedCredName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.ManagedIdentityFederatedIdentityCredential.String() { + t.Errorf("Expected type %s, got %s", azureshared.ManagedIdentityFederatedIdentityCredential, sdpItem.GetType()) + } + + expectedScope := subscriptionID + "." + integrationTestResourceGroup + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Errorf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for federated identity credential %s", integrationTestFedCredName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deleteFederatedIdentityCredential(ctx, ficClient, integrationTestResourceGroup, integrationTestIdentityName, integrationTestFedCredName) + if err != nil { + t.Fatalf("Failed to delete federated identity credential: %v", err) + } + + err = deleteUserAssignedIdentity(ctx, uaiClient, integrationTestResourceGroup, integrationTestIdentityName) + if err != nil { + t.Fatalf("Failed to delete user assigned identity: %v", err) + } + }) +} + +func createFederatedIdentityCredential(ctx context.Context, client *armmsi.FederatedIdentityCredentialsClient, resourceGroupName, identityName, credentialName string) error { + _, err := client.Get(ctx, resourceGroupName, identityName, credentialName, nil) + if err == nil { + log.Printf("Federated identity credential %s already exists, skipping creation", credentialName) + return nil + } + + _, err = client.CreateOrUpdate(ctx, resourceGroupName, identityName, credentialName, armmsi.FederatedIdentityCredential{ + Properties: &armmsi.FederatedIdentityCredentialProperties{ + Issuer: new(integrationTestFedCredIssuer), + Subject: new(integrationTestFedCredSubject), + Audiences: []*string{new(integrationTestFedCredAudience)}, + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, identityName, credentialName, nil); getErr == nil { + log.Printf("Federated identity credential %s already exists (conflict), skipping creation", credentialName) + return nil + } + return fmt.Errorf("federated identity credential %s conflict but not retrievable: %w", credentialName, err) + } + return fmt.Errorf("failed to create federated identity credential: %w", err) + } + + log.Printf("Federated identity credential %s created successfully", credentialName) + return nil +} + +func deleteFederatedIdentityCredential(ctx context.Context, client *armmsi.FederatedIdentityCredentialsClient, resourceGroupName, identityName, credentialName string) error { + _, err := client.Delete(ctx, resourceGroupName, identityName, credentialName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Federated identity credential %s not found, skipping deletion", credentialName) + return nil + } + return fmt.Errorf("failed to delete federated identity credential: %w", err) + } + + log.Printf("Federated identity credential %s deleted successfully", credentialName) + return nil +} diff --git a/sources/azure/integration-tests/network-ip-group_test.go b/sources/azure/integration-tests/network-ip-group_test.go new file mode 100644 index 00000000..30854ad3 --- /dev/null +++ b/sources/azure/integration-tests/network-ip-group_test.go @@ -0,0 +1,364 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +const ( + integrationTestIPGroupName = "ovm-integ-test-ip-group" +) + +func TestNetworkIPGroupIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + ipGroupsClient, err := armnetwork.NewIPGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create IP Groups client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createIPGroup(ctx, ipGroupsClient, integrationTestResourceGroup, integrationTestIPGroupName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create IP group: %v", err) + } + + err = waitForIPGroupAvailable(ctx, ipGroupsClient, integrationTestResourceGroup, integrationTestIPGroupName) + if err != nil { + t.Fatalf("Failed waiting for IP group to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetIPGroup", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving IP group %s in subscription %s, resource group %s", + integrationTestIPGroupName, subscriptionID, integrationTestResourceGroup) + + ipGroupWrapper := manual.NewNetworkIPGroup( + clients.NewIPGroupsClient(ipGroupsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipGroupWrapper.Scopes()[0] + + ipGroupAdapter := sources.WrapperToAdapter(ipGroupWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := ipGroupAdapter.Get(ctx, scope, integrationTestIPGroupName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != integrationTestIPGroupName { + t.Fatalf("Expected unique attribute value to be %s, got %s", integrationTestIPGroupName, uniqueAttrValue) + } + + log.Printf("Successfully retrieved IP group %s", integrationTestIPGroupName) + }) + + t.Run("ListIPGroups", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing IP groups in subscription %s, resource group %s", + subscriptionID, integrationTestResourceGroup) + + ipGroupWrapper := manual.NewNetworkIPGroup( + clients.NewIPGroupsClient(ipGroupsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipGroupWrapper.Scopes()[0] + + ipGroupAdapter := sources.WrapperToAdapter(ipGroupWrapper, sdpcache.NewNoOpCache()) + + listable, ok := ipGroupAdapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Failed to list IP groups: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one IP group, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == integrationTestIPGroupName { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find IP group %s in the list", integrationTestIPGroupName) + } + + log.Printf("Found %d IP groups in resource group %s", len(sdpItems), integrationTestResourceGroup) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for IP group %s", integrationTestIPGroupName) + + ipGroupWrapper := manual.NewNetworkIPGroup( + clients.NewIPGroupsClient(ipGroupsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipGroupWrapper.Scopes()[0] + + ipGroupAdapter := sources.WrapperToAdapter(ipGroupWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := ipGroupAdapter.Get(ctx, scope, integrationTestIPGroupName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkIPGroup.String() { + t.Errorf("Expected item type %s, got %s", azureshared.NetworkIPGroup, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for IP group %s", integrationTestIPGroupName) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for IP group %s", integrationTestIPGroupName) + + ipGroupWrapper := manual.NewNetworkIPGroup( + clients.NewIPGroupsClient(ipGroupsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipGroupWrapper.Scopes()[0] + + ipGroupAdapter := sources.WrapperToAdapter(ipGroupWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := ipGroupAdapter.Get(ctx, scope, integrationTestIPGroupName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + log.Printf("Found %d linked item queries for IP group %s", len(linkedQueries), integrationTestIPGroupName) + + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query == nil { + t.Error("Linked item query has nil Query") + continue + } + + if query.GetType() == "" { + t.Error("Linked item query has empty Type") + } + if query.GetMethod() != sdp.QueryMethod_GET && query.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Linked item query has unexpected Method: %v", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + + log.Printf("Verified linked item query: Type=%s, Method=%s, Query=%s, Scope=%s", + query.GetType(), query.GetMethod(), query.GetQuery(), query.GetScope()) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deleteIPGroup(ctx, ipGroupsClient, integrationTestResourceGroup, integrationTestIPGroupName) + if err != nil { + t.Fatalf("Failed to delete IP group: %v", err) + } + }) +} + +func createIPGroup(ctx context.Context, client *armnetwork.IPGroupsClient, resourceGroupName, ipGroupName, location string) error { + existingIPGroup, err := client.Get(ctx, resourceGroupName, ipGroupName, nil) + if err == nil { + if existingIPGroup.Properties != nil && existingIPGroup.Properties.ProvisioningState != nil { + state := *existingIPGroup.Properties.ProvisioningState + if state == armnetwork.ProvisioningStateSucceeded { + log.Printf("IP group %s already exists with state %s, skipping creation", ipGroupName, state) + return nil + } + log.Printf("IP group %s exists but in state %s, will wait for it", ipGroupName, state) + } else { + log.Printf("IP group %s already exists, skipping creation", ipGroupName) + return nil + } + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, ipGroupName, armnetwork.IPGroup{ + Location: new(location), + Properties: &armnetwork.IPGroupPropertiesFormat{ + IPAddresses: []*string{ + new("10.0.0.0/24"), + new("192.168.1.1"), + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("network-ip-group"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + log.Printf("IP group %s already exists (conflict), skipping creation", ipGroupName) + return nil + } + return fmt.Errorf("failed to begin creating IP group: %w", err) + } + + resp, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create IP group: %w", err) + } + + if resp.Properties == nil || resp.Properties.ProvisioningState == nil { + return fmt.Errorf("IP group created but provisioning state is unknown") + } + + provisioningState := *resp.Properties.ProvisioningState + if provisioningState != armnetwork.ProvisioningStateSucceeded { + return fmt.Errorf("IP group provisioning state is %s, expected Succeeded", provisioningState) + } + + log.Printf("IP group %s created successfully with provisioning state: %s", ipGroupName, provisioningState) + return nil +} + +func waitForIPGroupAvailable(ctx context.Context, client *armnetwork.IPGroupsClient, resourceGroupName, ipGroupName string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + + log.Printf("Waiting for IP group %s to be available via API...", ipGroupName) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, ipGroupName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("IP group %s not yet available (attempt %d/%d), waiting %v...", ipGroupName, attempt, maxAttempts, pollInterval) + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking IP group availability: %w", err) + } + + if resp.Properties != nil && resp.Properties.ProvisioningState != nil { + state := *resp.Properties.ProvisioningState + if state == armnetwork.ProvisioningStateSucceeded { + log.Printf("IP group %s is available with provisioning state: %s", ipGroupName, state) + return nil + } + if state == armnetwork.ProvisioningStateFailed { + return fmt.Errorf("IP group provisioning failed with state: %s", state) + } + log.Printf("IP group %s provisioning state: %s (attempt %d/%d), waiting...", ipGroupName, state, attempt, maxAttempts) + time.Sleep(pollInterval) + continue + } + + log.Printf("IP group %s is available", ipGroupName) + return nil + } + + return fmt.Errorf("timeout waiting for IP group %s to be available after %d attempts", ipGroupName, maxAttempts) +} + +func deleteIPGroup(ctx context.Context, client *armnetwork.IPGroupsClient, resourceGroupName, ipGroupName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, ipGroupName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("IP group %s not found, skipping deletion", ipGroupName) + return nil + } + return fmt.Errorf("failed to begin deleting IP group: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete IP group: %w", err) + } + + log.Printf("IP group %s deleted successfully", ipGroupName) + return nil +} diff --git a/sources/azure/integration-tests/network-local-network-gateway_test.go b/sources/azure/integration-tests/network-local-network-gateway_test.go new file mode 100644 index 00000000..1dd39d17 --- /dev/null +++ b/sources/azure/integration-tests/network-local-network-gateway_test.go @@ -0,0 +1,370 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +const ( + integrationTestLocalNetworkGatewayName = "ovm-integ-test-lng" +) + +func TestNetworkLocalNetworkGatewayIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + localNetworkGatewaysClient, err := armnetwork.NewLocalNetworkGatewaysClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Local Network Gateways client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createLocalNetworkGateway(ctx, localNetworkGatewaysClient, integrationTestResourceGroup, integrationTestLocalNetworkGatewayName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create local network gateway: %v", err) + } + + err = waitForLocalNetworkGatewayAvailable(ctx, localNetworkGatewaysClient, integrationTestResourceGroup, integrationTestLocalNetworkGatewayName) + if err != nil { + t.Fatalf("Failed waiting for local network gateway to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetLocalNetworkGateway", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving local network gateway %s in subscription %s, resource group %s", + integrationTestLocalNetworkGatewayName, subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewNetworkLocalNetworkGateway( + clients.NewLocalNetworkGatewaysClient(localNetworkGatewaysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestLocalNetworkGatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != integrationTestLocalNetworkGatewayName { + t.Fatalf("Expected unique attribute value to be %s, got %s", integrationTestLocalNetworkGatewayName, uniqueAttrValue) + } + + log.Printf("Successfully retrieved local network gateway %s", integrationTestLocalNetworkGatewayName) + }) + + t.Run("ListLocalNetworkGateways", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing local network gateways in subscription %s, resource group %s", + subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewNetworkLocalNetworkGateway( + clients.NewLocalNetworkGatewaysClient(localNetworkGatewaysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Failed to list local network gateways: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one local network gateway, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == integrationTestLocalNetworkGatewayName { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find local network gateway %s in the list", integrationTestLocalNetworkGatewayName) + } + + log.Printf("Found %d local network gateways in resource group %s", len(sdpItems), integrationTestResourceGroup) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for local network gateway %s", integrationTestLocalNetworkGatewayName) + + wrapper := manual.NewNetworkLocalNetworkGateway( + clients.NewLocalNetworkGatewaysClient(localNetworkGatewaysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestLocalNetworkGatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkLocalNetworkGateway.String() { + t.Errorf("Expected item type %s, got %s", azureshared.NetworkLocalNetworkGateway, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for local network gateway %s", integrationTestLocalNetworkGatewayName) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for local network gateway %s", integrationTestLocalNetworkGatewayName) + + wrapper := manual.NewNetworkLocalNetworkGateway( + clients.NewLocalNetworkGatewaysClient(localNetworkGatewaysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestLocalNetworkGatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + log.Printf("Found %d linked item queries for local network gateway %s", len(linkedQueries), integrationTestLocalNetworkGatewayName) + + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query == nil { + t.Error("Linked item query has nil Query") + continue + } + + if query.GetType() == "" { + t.Error("Linked item query has empty Type") + } + if query.GetMethod() != sdp.QueryMethod_GET && query.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Linked item query has unexpected Method: %v", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + + log.Printf("Verified linked item query: Type=%s, Method=%s, Query=%s, Scope=%s", + query.GetType(), query.GetMethod(), query.GetQuery(), query.GetScope()) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deleteLocalNetworkGateway(ctx, localNetworkGatewaysClient, integrationTestResourceGroup, integrationTestLocalNetworkGatewayName) + if err != nil { + t.Fatalf("Failed to delete local network gateway: %v", err) + } + }) +} + +func createLocalNetworkGateway(ctx context.Context, client *armnetwork.LocalNetworkGatewaysClient, resourceGroupName, gatewayName, location string) error { + existingGateway, err := client.Get(ctx, resourceGroupName, gatewayName, nil) + if err == nil { + if existingGateway.Properties != nil && existingGateway.Properties.ProvisioningState != nil { + state := string(*existingGateway.Properties.ProvisioningState) + if state == "Succeeded" { + log.Printf("Local network gateway %s already exists with state %s, skipping creation", gatewayName, state) + return nil + } + log.Printf("Local network gateway %s exists but in state %s, will wait for it", gatewayName, state) + } else { + log.Printf("Local network gateway %s already exists, skipping creation", gatewayName) + return nil + } + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, gatewayName, armnetwork.LocalNetworkGateway{ + Location: new(location), + Properties: &armnetwork.LocalNetworkGatewayPropertiesFormat{ + GatewayIPAddress: new("203.0.113.1"), + LocalNetworkAddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{ + new("10.1.0.0/16"), + new("10.2.0.0/16"), + }, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("network-local-network-gateway"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroupName, gatewayName, nil); getErr == nil { + log.Printf("Local network gateway %s already exists (conflict), skipping creation", gatewayName) + return nil + } + return fmt.Errorf("local network gateway %s conflict but not retrievable: %w", gatewayName, err) + } + return fmt.Errorf("failed to begin creating local network gateway: %w", err) + } + + resp, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create local network gateway: %w", err) + } + + if resp.Properties == nil || resp.Properties.ProvisioningState == nil { + return fmt.Errorf("local network gateway created but provisioning state is unknown") + } + + provisioningState := string(*resp.Properties.ProvisioningState) + if provisioningState != "Succeeded" { + return fmt.Errorf("local network gateway provisioning state is %s, expected Succeeded", provisioningState) + } + + log.Printf("Local network gateway %s created successfully with provisioning state: %s", gatewayName, provisioningState) + return nil +} + +func waitForLocalNetworkGatewayAvailable(ctx context.Context, client *armnetwork.LocalNetworkGatewaysClient, resourceGroupName, gatewayName string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + + log.Printf("Waiting for local network gateway %s to be available via API...", gatewayName) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, gatewayName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Local network gateway %s not yet available (attempt %d/%d), waiting %v...", gatewayName, attempt, maxAttempts, pollInterval) + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking local network gateway availability: %w", err) + } + + if resp.Properties != nil && resp.Properties.ProvisioningState != nil { + state := string(*resp.Properties.ProvisioningState) + if state == "Succeeded" { + log.Printf("Local network gateway %s is available with provisioning state: %s", gatewayName, state) + return nil + } + if state == "Failed" { + return fmt.Errorf("local network gateway provisioning failed with state: %s", state) + } + log.Printf("Local network gateway %s provisioning state: %s (attempt %d/%d), waiting...", gatewayName, state, attempt, maxAttempts) + time.Sleep(pollInterval) + continue + } + + log.Printf("Local network gateway %s is available", gatewayName) + return nil + } + + return fmt.Errorf("timeout waiting for local network gateway %s to be available after %d attempts", gatewayName, maxAttempts) +} + +func deleteLocalNetworkGateway(ctx context.Context, client *armnetwork.LocalNetworkGatewaysClient, resourceGroupName, gatewayName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, gatewayName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Local network gateway %s not found, skipping deletion", gatewayName) + return nil + } + return fmt.Errorf("failed to begin deleting local network gateway: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete local network gateway: %w", err) + } + + log.Printf("Local network gateway %s deleted successfully", gatewayName) + return nil +} diff --git a/sources/azure/integration-tests/network-network-interface-ip-configuration_test.go b/sources/azure/integration-tests/network-network-interface-ip-configuration_test.go new file mode 100644 index 00000000..4390f1f9 --- /dev/null +++ b/sources/azure/integration-tests/network-network-interface-ip-configuration_test.go @@ -0,0 +1,410 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestIPConfigNICName = "ovm-integ-test-nic-for-ipconfig" + integrationTestIPConfigVNetName = "ovm-integ-test-vnet-for-ipconfig" + integrationTestIPConfigSubnetName = "default" + integrationTestIPConfigIPConfigName = "ipconfig1" +) + +func TestNetworkNetworkInterfaceIPConfigurationIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + vnetClient, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Virtual Networks client: %v", err) + } + + subnetClient, err := armnetwork.NewSubnetsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Subnets client: %v", err) + } + + nicClient, err := armnetwork.NewInterfacesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Network Interfaces client: %v", err) + } + + ipConfigClient, err := armnetwork.NewInterfaceIPConfigurationsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Interface IP Configurations client: %v", err) + } + + setupCompleted := false + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + err = createVirtualNetworkForIPConfig(ctx, vnetClient, integrationTestResourceGroup, integrationTestIPConfigVNetName, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create virtual network: %v", err) + } + + subnetResp, err := subnetClient.Get(ctx, integrationTestResourceGroup, integrationTestIPConfigVNetName, integrationTestIPConfigSubnetName, nil) + if err != nil { + t.Fatalf("Failed to get subnet: %v", err) + } + + err = createNetworkInterfaceForIPConfig(ctx, nicClient, integrationTestResourceGroup, integrationTestIPConfigNICName, integrationTestLocation, *subnetResp.ID) + if err != nil { + t.Fatalf("Failed to create network interface: %v", err) + } + + setupCompleted = true + log.Printf("Setup completed: Network interface %s created with IP configuration %s", integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetIPConfiguration", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving IP configuration %s from NIC %s in subscription %s, resource group %s", + integrationTestIPConfigIPConfigName, integrationTestIPConfigNICName, subscriptionID, integrationTestResourceGroup) + + ipConfigWrapper := manual.NewNetworkNetworkInterfaceIPConfiguration( + clients.NewInterfaceIPConfigurationsClient(ipConfigClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipConfigWrapper.Scopes()[0] + + ipConfigAdapter := sources.WrapperToAdapter(ipConfigWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + sdpItem, qErr := ipConfigAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueValue := shared.CompositeLookupKey(integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + if uniqueAttrValue != expectedUniqueValue { + t.Fatalf("Expected unique attribute value to be %s, got %s", expectedUniqueValue, uniqueAttrValue) + } + + log.Printf("Successfully retrieved IP configuration %s", integrationTestIPConfigIPConfigName) + }) + + t.Run("SearchIPConfigurations", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching IP configurations in NIC %s", integrationTestIPConfigNICName) + + ipConfigWrapper := manual.NewNetworkNetworkInterfaceIPConfiguration( + clients.NewInterfaceIPConfigurationsClient(ipConfigClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipConfigWrapper.Scopes()[0] + + ipConfigAdapter := sources.WrapperToAdapter(ipConfigWrapper, sdpcache.NewNoOpCache()) + searchable, ok := ipConfigAdapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, integrationTestIPConfigNICName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) == 0 { + t.Fatalf("Expected at least 1 IP configuration, got: %d", len(sdpItems)) + } + + var found bool + expectedUniqueValue := shared.CompositeLookupKey(integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == expectedUniqueValue { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find IP configuration %s in search results", integrationTestIPConfigIPConfigName) + } + + log.Printf("Successfully found %d IP configurations in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for IP configuration %s", integrationTestIPConfigIPConfigName) + + ipConfigWrapper := manual.NewNetworkNetworkInterfaceIPConfiguration( + clients.NewInterfaceIPConfigurationsClient(ipConfigClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipConfigWrapper.Scopes()[0] + + ipConfigAdapter := sources.WrapperToAdapter(ipConfigWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + sdpItem, qErr := ipConfigAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + var hasNICLink bool + var hasSubnetLink bool + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Error("Linked item query has empty type") + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty scope") + } + + switch query.GetType() { + case azureshared.NetworkNetworkInterface.String(): + hasNICLink = true + if query.GetQuery() != integrationTestIPConfigNICName { + t.Errorf("Expected linked query to NIC %s, got %s", integrationTestIPConfigNICName, query.GetQuery()) + } + case azureshared.NetworkSubnet.String(): + hasSubnetLink = true + } + } + + if !hasNICLink { + t.Error("Expected linked query to parent network interface, but didn't find one") + } + + if !hasSubnetLink { + t.Error("Expected linked query to subnet, but didn't find one") + } + + log.Printf("Verified %d linked item queries for IP configuration %s", len(linkedQueries), integrationTestIPConfigIPConfigName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + ipConfigWrapper := manual.NewNetworkNetworkInterfaceIPConfiguration( + clients.NewInterfaceIPConfigurationsClient(ipConfigClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := ipConfigWrapper.Scopes()[0] + + ipConfigAdapter := sources.WrapperToAdapter(ipConfigWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(integrationTestIPConfigNICName, integrationTestIPConfigIPConfigName) + sdpItem, qErr := ipConfigAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkNetworkInterfaceIPConfiguration.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkNetworkInterfaceIPConfiguration, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Errorf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for IP configuration %s", integrationTestIPConfigIPConfigName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + err := deleteNetworkInterfaceForIPConfig(ctx, nicClient, integrationTestResourceGroup, integrationTestIPConfigNICName) + if err != nil { + t.Fatalf("Failed to delete network interface: %v", err) + } + + err = deleteVirtualNetworkForIPConfig(ctx, vnetClient, integrationTestResourceGroup, integrationTestIPConfigVNetName) + if err != nil { + t.Fatalf("Failed to delete virtual network: %v", err) + } + }) +} + +func createVirtualNetworkForIPConfig(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName, location string) error { + _, err := client.Get(ctx, resourceGroupName, vnetName, nil) + if err == nil { + log.Printf("Virtual network %s already exists, skipping creation", vnetName) + return nil + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, vnetName, armnetwork.VirtualNetwork{ + Location: new(location), + Properties: &armnetwork.VirtualNetworkPropertiesFormat{ + AddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{new("10.2.0.0/16")}, + }, + Subnets: []*armnetwork.Subnet{ + { + Name: new(integrationTestIPConfigSubnetName), + Properties: &armnetwork.SubnetPropertiesFormat{ + AddressPrefix: new("10.2.0.0/24"), + }, + }, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to begin creating virtual network: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create virtual network: %w", err) + } + + log.Printf("Virtual network %s created successfully", vnetName) + return nil +} + +func deleteVirtualNetworkForIPConfig(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, vnetName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Virtual network %s not found, skipping deletion", vnetName) + return nil + } + return fmt.Errorf("failed to begin deleting virtual network: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete virtual network: %w", err) + } + + log.Printf("Virtual network %s deleted successfully", vnetName) + return nil +} + +func createNetworkInterfaceForIPConfig(ctx context.Context, client *armnetwork.InterfacesClient, resourceGroupName, nicName, location, subnetID string) error { + _, err := client.Get(ctx, resourceGroupName, nicName, nil) + if err == nil { + log.Printf("Network interface %s already exists, skipping creation", nicName) + return nil + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, nicName, armnetwork.Interface{ + Location: new(location), + Properties: &armnetwork.InterfacePropertiesFormat{ + IPConfigurations: []*armnetwork.InterfaceIPConfiguration{ + { + Name: new(integrationTestIPConfigIPConfigName), + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PrivateIPAllocationMethod: new(armnetwork.IPAllocationMethodDynamic), + Primary: new(true), + }, + }, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to begin creating network interface: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create network interface: %w", err) + } + + log.Printf("Network interface %s created successfully", nicName) + return nil +} + +func deleteNetworkInterfaceForIPConfig(ctx context.Context, client *armnetwork.InterfacesClient, resourceGroupName, nicName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, nicName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Network interface %s not found, skipping deletion", nicName) + return nil + } + return fmt.Errorf("failed to begin deleting network interface: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete network interface: %w", err) + } + + log.Printf("Network interface %s deleted successfully", nicName) + return nil +} diff --git a/sources/azure/integration-tests/network-network-watcher_test.go b/sources/azure/integration-tests/network-network-watcher_test.go new file mode 100644 index 00000000..b8b00918 --- /dev/null +++ b/sources/azure/integration-tests/network-network-watcher_test.go @@ -0,0 +1,339 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +const ( + // Azure only allows one Network Watcher per region per subscription. + // We create a test Network Watcher in our integration test resource group. + integrationTestNetworkWatcherTestName = "ovm-integ-test-nw" +) + +func TestNetworkNetworkWatcherIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + networkWatchersClient, err := armnetwork.NewWatchersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Network Watchers client: %v", err) + } + + setupCompleted := false + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // Create network watcher - Azure only allows one per region per subscription + err = createNetworkWatcher(ctx, networkWatchersClient, integrationTestResourceGroup, integrationTestNetworkWatcherTestName, integrationTestLocation) + if err != nil { + // If we hit the limit, it means a Network Watcher already exists in another RG + if strings.Contains(err.Error(), "NetworkWatcherCountLimitReached") { + t.Skipf("Skipping: Azure allows only one Network Watcher per region. One already exists: %v", err) + } + t.Fatalf("Failed to create network watcher: %v", err) + } + + // Wait for network watcher to be available + err = waitForNetworkWatcherAvailable(ctx, networkWatchersClient, integrationTestResourceGroup, integrationTestNetworkWatcherTestName) + if err != nil { + t.Fatalf("Failed waiting for network watcher: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetNetworkWatcher", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving network watcher %s in subscription %s, resource group %s", + integrationTestNetworkWatcherTestName, subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewNetworkNetworkWatcher( + clients.NewNetworkWatchersClient(networkWatchersClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestNetworkWatcherTestName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != integrationTestNetworkWatcherTestName { + t.Fatalf("Expected unique attribute value to be %s, got %s", integrationTestNetworkWatcherTestName, uniqueAttrValue) + } + + log.Printf("Successfully retrieved network watcher %s", integrationTestNetworkWatcherTestName) + }) + + t.Run("ListNetworkWatchers", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing network watchers in subscription %s, resource group %s", + subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewNetworkNetworkWatcher( + clients.NewNetworkWatchersClient(networkWatchersClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Failed to list network watchers: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one network watcher, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == integrationTestNetworkWatcherTestName { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find network watcher %s in the list", integrationTestNetworkWatcherTestName) + } + + log.Printf("Found %d network watchers in resource group %s", len(sdpItems), integrationTestResourceGroup) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for network watcher %s", integrationTestNetworkWatcherTestName) + + wrapper := manual.NewNetworkNetworkWatcher( + clients.NewNetworkWatchersClient(networkWatchersClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestNetworkWatcherTestName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + + for _, query := range linkedQueries { + q := query.GetQuery() + if q == nil { + t.Error("LinkedItemQuery has nil Query") + continue + } + + if q.GetType() == "" { + t.Error("LinkedItemQuery has empty Type") + } + + if q.GetMethod() != sdp.QueryMethod_GET && q.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("LinkedItemQuery has invalid Method: %v", q.GetMethod()) + } + + if q.GetQuery() == "" { + t.Error("LinkedItemQuery has empty Query") + } + + if q.GetScope() == "" { + t.Error("LinkedItemQuery has empty Scope") + } + } + + log.Printf("Verified %d linked item queries for network watcher %s", len(linkedQueries), integrationTestNetworkWatcherTestName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for network watcher %s", integrationTestNetworkWatcherTestName) + + wrapper := manual.NewNetworkNetworkWatcher( + clients.NewNetworkWatchersClient(networkWatchersClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, integrationTestNetworkWatcherTestName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkNetworkWatcher.String() { + t.Errorf("Expected item type %s, got %s", azureshared.NetworkNetworkWatcher, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for network watcher %s", integrationTestNetworkWatcherTestName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Delete the network watcher we created + err := deleteNetworkWatcher(ctx, networkWatchersClient, integrationTestResourceGroup, integrationTestNetworkWatcherTestName) + if err != nil { + t.Logf("Warning: Failed to delete network watcher %s: %v", integrationTestNetworkWatcherTestName, err) + } + }) +} + +func createNetworkWatcher(ctx context.Context, client *armnetwork.WatchersClient, resourceGroup, name, location string) error { + _, err := client.Get(ctx, resourceGroup, name, nil) + if err == nil { + log.Printf("Network watcher %s already exists, skipping creation", name) + return nil + } + + result, err := client.CreateOrUpdate(ctx, resourceGroup, name, armnetwork.Watcher{ + Location: &location, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + if _, getErr := client.Get(ctx, resourceGroup, name, nil); getErr == nil { + log.Printf("Network watcher %s already exists (conflict), skipping", name) + return nil + } + return fmt.Errorf("network watcher %s conflict but not retrievable: %w", name, err) + } + return fmt.Errorf("failed to create network watcher: %w", err) + } + + log.Printf("Network watcher %s created: %v", name, result.Watcher.Name) + return nil +} + +func waitForNetworkWatcherAvailable(ctx context.Context, client *armnetwork.WatchersClient, resourceGroup, name string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + maxNotFoundAttempts := 5 + notFoundCount := 0 + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroup, name, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + notFoundCount++ + if notFoundCount >= maxNotFoundAttempts { + return fmt.Errorf("network watcher %s not found after %d attempts", name, notFoundCount) + } + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking network watcher: %w", err) + } + notFoundCount = 0 + if resp.Properties != nil && resp.Properties.ProvisioningState != nil && *resp.Properties.ProvisioningState == armnetwork.ProvisioningStateSucceeded { + log.Printf("Network watcher %s is available", name) + return nil + } + time.Sleep(pollInterval) + } + return fmt.Errorf("timeout waiting for network watcher %s", name) +} + +func deleteNetworkWatcher(ctx context.Context, client *armnetwork.WatchersClient, resourceGroup, name string) error { + poller, err := client.BeginDelete(ctx, resourceGroup, name, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Network watcher %s already deleted", name) + return nil + } + return fmt.Errorf("failed to begin delete network watcher: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete network watcher: %w", err) + } + + log.Printf("Network watcher %s deleted successfully", name) + return nil +} diff --git a/sources/azure/integration-tests/network-private-link-service_test.go b/sources/azure/integration-tests/network-private-link-service_test.go new file mode 100644 index 00000000..9644aa09 --- /dev/null +++ b/sources/azure/integration-tests/network-private-link-service_test.go @@ -0,0 +1,555 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +const ( + integrationTestPLSName = "ovm-integ-test-pls" + integrationTestVNetNameForPLS = "ovm-integ-test-vnet-for-pls" + integrationTestSubnetNameForPLS = "pls-subnet" + integrationTestLBNameForPLS = "ovm-integ-test-lb-for-pls" + integrationTestFrontendIPForPLS = "frontend-ip-config" + integrationTestBackendPoolForPLS = "backend-pool" +) + +func TestNetworkPrivateLinkServiceIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + // Initialize Azure credentials using DefaultAzureCredential + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + // Create Azure SDK clients + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + vnetClient, err := armnetwork.NewVirtualNetworksClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Virtual Networks client: %v", err) + } + + subnetClient, err := armnetwork.NewSubnetsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Subnets client: %v", err) + } + + lbClient, err := armnetwork.NewLoadBalancersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Load Balancers client: %v", err) + } + + plsClient, err := armnetwork.NewPrivateLinkServicesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Private Link Services client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // Create virtual network for private link service (with special subnet settings) + err = createVirtualNetworkForPLS(ctx, vnetClient, integrationTestResourceGroup, integrationTestVNetNameForPLS, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create virtual network: %v", err) + } + + // Get subnet ID for load balancer and private link service + subnetResp, err := subnetClient.Get(ctx, integrationTestResourceGroup, integrationTestVNetNameForPLS, integrationTestSubnetNameForPLS, nil) + if err != nil { + t.Fatalf("Failed to get subnet: %v", err) + } + + // Create internal load balancer for private link service + err = createInternalLoadBalancerForPLS(ctx, lbClient, subscriptionID, integrationTestResourceGroup, integrationTestLBNameForPLS, integrationTestLocation, *subnetResp.ID) + if err != nil { + t.Fatalf("Failed to create internal load balancer: %v", err) + } + + // Get load balancer frontend IP configuration ID + lbResp, err := lbClient.Get(ctx, integrationTestResourceGroup, integrationTestLBNameForPLS, nil) + if err != nil { + t.Fatalf("Failed to get load balancer: %v", err) + } + + var frontendIPConfigID string + if lbResp.Properties != nil && len(lbResp.Properties.FrontendIPConfigurations) > 0 { + frontendIPConfigID = *lbResp.Properties.FrontendIPConfigurations[0].ID + } + if frontendIPConfigID == "" { + t.Fatalf("Failed to get frontend IP configuration ID") + } + + // Create private link service + err = createPrivateLinkService(ctx, plsClient, integrationTestResourceGroup, integrationTestPLSName, integrationTestLocation, *subnetResp.ID, frontendIPConfigID) + if err != nil { + t.Fatalf("Failed to create private link service: %v", err) + } + + setupCompleted = true + log.Printf("Setup completed: Private Link Service %s created", integrationTestPLSName) + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetPrivateLinkService", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving private link service %s in subscription %s, resource group %s", + integrationTestPLSName, subscriptionID, integrationTestResourceGroup) + + plsWrapper := manual.NewNetworkPrivateLinkService( + clients.NewPrivateLinkServicesClient(plsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := plsWrapper.Scopes()[0] + + plsAdapter := sources.WrapperToAdapter(plsWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := plsAdapter.Get(ctx, scope, integrationTestPLSName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != integrationTestPLSName { + t.Fatalf("Expected unique attribute value to be %s, got %s", integrationTestPLSName, uniqueAttrValue) + } + + if sdpItem.GetType() != azureshared.NetworkPrivateLinkService.String() { + t.Fatalf("Expected type %s, got %s", azureshared.NetworkPrivateLinkService, sdpItem.GetType()) + } + + log.Printf("Successfully retrieved private link service %s", integrationTestPLSName) + }) + + t.Run("ListPrivateLinkServices", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing private link services in subscription %s, resource group %s", + subscriptionID, integrationTestResourceGroup) + + plsWrapper := manual.NewNetworkPrivateLinkService( + clients.NewPrivateLinkServicesClient(plsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := plsWrapper.Scopes()[0] + + plsAdapter := sources.WrapperToAdapter(plsWrapper, sdpcache.NewNoOpCache()) + listable, ok := plsAdapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least 1 private link service, got: %d", len(sdpItems)) + } + + // Find our test private link service + found := false + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil { + if v == integrationTestPLSName { + found = true + if item.GetType() != azureshared.NetworkPrivateLinkService.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkPrivateLinkService, item.GetType()) + } + break + } + } + } + + if !found { + t.Fatalf("Expected to find private link service %s in list, but didn't", integrationTestPLSName) + } + + log.Printf("Successfully listed %d private link services", len(sdpItems)) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + plsWrapper := manual.NewNetworkPrivateLinkService( + clients.NewPrivateLinkServicesClient(plsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := plsWrapper.Scopes()[0] + + plsAdapter := sources.WrapperToAdapter(plsWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := plsAdapter.Get(ctx, scope, integrationTestPLSName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify item type + if sdpItem.GetType() != azureshared.NetworkPrivateLinkService.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkPrivateLinkService, sdpItem.GetType()) + } + + // Verify scope + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + // Verify unique attribute + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + // Verify Validate() passes + if err := sdpItem.Validate(); err != nil { + t.Errorf("Expected item to validate, got error: %v", err) + } + + log.Printf("Verified item attributes for private link service %s", integrationTestPLSName) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + plsWrapper := manual.NewNetworkPrivateLinkService( + clients.NewPrivateLinkServicesClient(plsClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := plsWrapper.Scopes()[0] + + plsAdapter := sources.WrapperToAdapter(plsWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := plsAdapter.Get(ctx, scope, integrationTestPLSName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + // Verify each linked item query has required fields + for i, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Errorf("Linked query %d has empty Type", i) + } + if query.GetQuery() == "" { + t.Errorf("Linked query %d has empty Query", i) + } + if query.GetScope() == "" { + t.Errorf("Linked query %d has empty Scope", i) + } + } + + // Verify expected linked item types + expectedLinkedTypes := map[string]bool{ + azureshared.NetworkSubnet.String(): false, + azureshared.NetworkVirtualNetwork.String(): false, + azureshared.NetworkLoadBalancerFrontendIPConfiguration.String(): false, + azureshared.NetworkLoadBalancer.String(): false, + } + + for _, liq := range linkedQueries { + linkedType := liq.GetQuery().GetType() + if _, exists := expectedLinkedTypes[linkedType]; exists { + expectedLinkedTypes[linkedType] = true + } + } + + for linkedType, found := range expectedLinkedTypes { + if !found { + t.Errorf("Expected linked query to %s, but didn't find one", linkedType) + } + } + + log.Printf("Verified %d linked item queries for private link service %s", len(linkedQueries), integrationTestPLSName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Delete private link service + err := deletePrivateLinkService(ctx, plsClient, integrationTestResourceGroup, integrationTestPLSName) + if err != nil { + t.Logf("Warning: Failed to delete private link service: %v", err) + } + + // Delete load balancer + err = deleteLoadBalancer(ctx, lbClient, integrationTestResourceGroup, integrationTestLBNameForPLS) + if err != nil { + t.Logf("Warning: Failed to delete load balancer: %v", err) + } + + // Delete VNet (this also deletes the subnet) + err = deleteVirtualNetworkForPLS(ctx, vnetClient, integrationTestResourceGroup, integrationTestVNetNameForPLS) + if err != nil { + t.Logf("Warning: Failed to delete virtual network: %v", err) + } + + log.Printf("Teardown completed") + }) +} + +// createVirtualNetworkForPLS creates an Azure virtual network with a subnet that has privateLinkServiceNetworkPolicies disabled +func createVirtualNetworkForPLS(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName, location string) error { + // Check if VNet already exists + _, err := client.Get(ctx, resourceGroupName, vnetName, nil) + if err == nil { + log.Printf("Virtual network %s already exists, skipping creation", vnetName) + return nil + } + + // Create the VNet with a subnet that has privateLinkServiceNetworkPolicies disabled + disabled := armnetwork.VirtualNetworkPrivateLinkServiceNetworkPoliciesDisabled + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, vnetName, armnetwork.VirtualNetwork{ + Location: new(location), + Properties: &armnetwork.VirtualNetworkPropertiesFormat{ + AddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{new("10.3.0.0/16")}, + }, + Subnets: []*armnetwork.Subnet{ + { + Name: new(integrationTestSubnetNameForPLS), + Properties: &armnetwork.SubnetPropertiesFormat{ + AddressPrefix: new("10.3.0.0/24"), + PrivateLinkServiceNetworkPolicies: &disabled, + }, + }, + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to begin creating virtual network: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create virtual network: %w", err) + } + + log.Printf("Virtual network %s created successfully", vnetName) + return nil +} + +// deleteVirtualNetworkForPLS deletes an Azure virtual network +func deleteVirtualNetworkForPLS(ctx context.Context, client *armnetwork.VirtualNetworksClient, resourceGroupName, vnetName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, vnetName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Virtual network %s not found, skipping deletion", vnetName) + return nil + } + return fmt.Errorf("failed to begin deleting virtual network: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete virtual network: %w", err) + } + + log.Printf("Virtual network %s deleted successfully", vnetName) + return nil +} + +// createInternalLoadBalancerForPLS creates an Azure internal load balancer for private link service +func createInternalLoadBalancerForPLS(ctx context.Context, client *armnetwork.LoadBalancersClient, subscriptionID, resourceGroupName, lbName, location, subnetID string) error { + // Check if load balancer already exists + _, err := client.Get(ctx, resourceGroupName, lbName, nil) + if err == nil { + log.Printf("Load balancer %s already exists, skipping creation", lbName) + return nil + } + + // Create the internal load balancer + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, lbName, armnetwork.LoadBalancer{ + Location: new(location), + Properties: &armnetwork.LoadBalancerPropertiesFormat{ + FrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ + { + Name: new(integrationTestFrontendIPForPLS), + Properties: &armnetwork.FrontendIPConfigurationPropertiesFormat{ + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PrivateIPAllocationMethod: new(armnetwork.IPAllocationMethodDynamic), + }, + }, + }, + BackendAddressPools: []*armnetwork.BackendAddressPool{ + { + Name: new(integrationTestBackendPoolForPLS), + }, + }, + LoadBalancingRules: []*armnetwork.LoadBalancingRule{ + { + Name: new("lb-rule"), + Properties: &armnetwork.LoadBalancingRulePropertiesFormat{ + FrontendIPConfiguration: &armnetwork.SubResource{ + ID: new(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/frontendIPConfigurations/%s", + subscriptionID, resourceGroupName, lbName, integrationTestFrontendIPForPLS)), + }, + BackendAddressPool: &armnetwork.SubResource{ + ID: new(fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/%s/backendAddressPools/%s", + subscriptionID, resourceGroupName, lbName, integrationTestBackendPoolForPLS)), + }, + Protocol: new(armnetwork.TransportProtocolTCP), + FrontendPort: new(int32(80)), + BackendPort: new(int32(80)), + EnableFloatingIP: new(false), + IdleTimeoutInMinutes: new(int32(4)), + }, + }, + }, + }, + SKU: &armnetwork.LoadBalancerSKU{ + Name: new(armnetwork.LoadBalancerSKUNameStandard), + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to begin creating load balancer: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create load balancer: %w", err) + } + + log.Printf("Load balancer %s created successfully", lbName) + return nil +} + +// createPrivateLinkService creates an Azure Private Link Service +func createPrivateLinkService(ctx context.Context, client *armnetwork.PrivateLinkServicesClient, resourceGroupName, plsName, location, subnetID, frontendIPConfigID string) error { + // Check if private link service already exists + _, err := client.Get(ctx, resourceGroupName, plsName, nil) + if err == nil { + log.Printf("Private link service %s already exists, skipping creation", plsName) + return nil + } + + // Create the private link service + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, plsName, armnetwork.PrivateLinkService{ + Location: new(location), + Properties: &armnetwork.PrivateLinkServiceProperties{ + LoadBalancerFrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ + { + ID: new(frontendIPConfigID), + }, + }, + IPConfigurations: []*armnetwork.PrivateLinkServiceIPConfiguration{ + { + Name: new("pls-ip-config"), + Properties: &armnetwork.PrivateLinkServiceIPConfigurationProperties{ + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PrivateIPAllocationMethod: new(armnetwork.IPAllocationMethodDynamic), + Primary: new(true), + }, + }, + }, + EnableProxyProtocol: new(false), + Fqdns: []*string{ + new("test-pls.example.com"), + }, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusConflict { + // Verify the resource actually exists before treating conflict as success + if _, getErr := client.Get(ctx, resourceGroupName, plsName, nil); getErr == nil { + log.Printf("Private link service %s already exists (conflict), skipping creation", plsName) + return nil + } + return fmt.Errorf("private link service %s conflict but not retrievable: %w", plsName, err) + } + return fmt.Errorf("failed to begin creating private link service: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create private link service: %w", err) + } + + log.Printf("Private link service %s created successfully", plsName) + return nil +} + +// deletePrivateLinkService deletes an Azure Private Link Service +func deletePrivateLinkService(ctx context.Context, client *armnetwork.PrivateLinkServicesClient, resourceGroupName, plsName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, plsName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Private link service %s not found, skipping deletion", plsName) + return nil + } + return fmt.Errorf("failed to begin deleting private link service: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete private link service: %w", err) + } + + log.Printf("Private link service %s deleted successfully", plsName) + return nil +} diff --git a/sources/azure/integration-tests/operational-insights-workspace_test.go b/sources/azure/integration-tests/operational-insights-workspace_test.go new file mode 100644 index 00000000..4df2f838 --- /dev/null +++ b/sources/azure/integration-tests/operational-insights-workspace_test.go @@ -0,0 +1,421 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" +) + +const ( + integrationTestWorkspaceName = "ovm-integ-test-workspace" +) + +// errOperationalInsightsAuthorizationFailed is a sentinel error for authorization failures +var errOperationalInsightsAuthorizationFailed = errors.New("authorization failed for Operational Insights resource provider") + +func TestOperationalInsightsWorkspaceIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + // Initialize Azure credentials using DefaultAzureCredential + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + // Create Azure SDK clients + workspacesClient, err := armoperationalinsights.NewWorkspacesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Workspaces client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // Create workspace + err = createOperationalInsightsWorkspace(ctx, workspacesClient, integrationTestResourceGroup, integrationTestWorkspaceName, integrationTestLocation) + if err != nil { + if errors.Is(err, errOperationalInsightsAuthorizationFailed) { + t.Skipf("Skipping test: %v (service principal lacks permission to register Microsoft.OperationalInsights resource provider)", err) + } + t.Fatalf("Failed to create workspace: %v", err) + } + + // Wait for workspace to be fully available + err = waitForOperationalInsightsWorkspaceAvailable(ctx, workspacesClient, integrationTestResourceGroup, integrationTestWorkspaceName) + if err != nil { + t.Fatalf("Failed waiting for workspace to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetWorkspace", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving workspace %s in subscription %s, resource group %s", + integrationTestWorkspaceName, subscriptionID, integrationTestResourceGroup) + + workspaceWrapper := manual.NewOperationalInsightsWorkspace( + clients.NewOperationalInsightsWorkspaceClient(workspacesClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := workspaceWrapper.Scopes()[0] + + workspaceAdapter := sources.WrapperToAdapter(workspaceWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := workspaceAdapter.Get(ctx, scope, integrationTestWorkspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + if uniqueAttrValue != integrationTestWorkspaceName { + t.Fatalf("Expected unique attribute value to be %s, got %s", integrationTestWorkspaceName, uniqueAttrValue) + } + + log.Printf("Successfully retrieved workspace %s", integrationTestWorkspaceName) + }) + + t.Run("ListWorkspaces", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Listing workspaces in subscription %s, resource group %s", + subscriptionID, integrationTestResourceGroup) + + workspaceWrapper := manual.NewOperationalInsightsWorkspace( + clients.NewOperationalInsightsWorkspaceClient(workspacesClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := workspaceWrapper.Scopes()[0] + + workspaceAdapter := sources.WrapperToAdapter(workspaceWrapper, sdpcache.NewNoOpCache()) + + // Check if adapter supports listing + listable, ok := workspaceAdapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Failed to list workspaces: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one workspace, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil && v == integrationTestWorkspaceName { + found = true + break + } + } + + if !found { + t.Fatalf("Expected to find workspace %s in the list of workspaces", integrationTestWorkspaceName) + } + + log.Printf("Found %d workspaces in resource group %s", len(sdpItems), integrationTestResourceGroup) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for workspace %s", integrationTestWorkspaceName) + + workspaceWrapper := manual.NewOperationalInsightsWorkspace( + clients.NewOperationalInsightsWorkspaceClient(workspacesClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := workspaceWrapper.Scopes()[0] + + workspaceAdapter := sources.WrapperToAdapter(workspaceWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := workspaceAdapter.Get(ctx, scope, integrationTestWorkspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify item type + if sdpItem.GetType() != azureshared.OperationalInsightsWorkspace.String() { + t.Errorf("Expected item type %s, got %s", azureshared.OperationalInsightsWorkspace, sdpItem.GetType()) + } + + // Verify scope + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + // Verify unique attribute + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + // Verify item validation + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for workspace %s", integrationTestWorkspaceName) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for workspace %s", integrationTestWorkspaceName) + + workspaceWrapper := manual.NewOperationalInsightsWorkspace( + clients.NewOperationalInsightsWorkspaceClient(workspacesClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := workspaceWrapper.Scopes()[0] + + workspaceAdapter := sources.WrapperToAdapter(workspaceWrapper, sdpcache.NewNoOpCache()) + sdpItem, qErr := workspaceAdapter.Get(ctx, scope, integrationTestWorkspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify that linked items exist (if any) + linkedQueries := sdpItem.GetLinkedItemQueries() + log.Printf("Found %d linked item queries for workspace %s", len(linkedQueries), integrationTestWorkspaceName) + + // For a standalone workspace without private link, there may not be any linked items + // But we should verify the structure is correct if links exist + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query == nil { + t.Error("Linked item query has nil Query") + continue + } + + // Verify query has required fields + if query.GetType() == "" { + t.Error("Linked item query has empty Type") + } + // Method should be GET or SEARCH (not empty) + if query.GetMethod() == sdp.QueryMethod_GET || query.GetMethod() == sdp.QueryMethod_SEARCH { + // Valid method + } else { + t.Errorf("Linked item query has unexpected Method: %v", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if query.GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + + log.Printf("Verified linked item query: Type=%s, Method=%s, Query=%s, Scope=%s", + query.GetType(), query.GetMethod(), query.GetQuery(), query.GetScope()) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Delete workspace + err := deleteOperationalInsightsWorkspace(ctx, workspacesClient, integrationTestResourceGroup, integrationTestWorkspaceName) + if err != nil { + t.Fatalf("Failed to delete workspace: %v", err) + } + + // Note: We keep the resource group for faster subsequent test runs + }) +} + +// createOperationalInsightsWorkspace creates an Azure Log Analytics workspace (idempotent) +func createOperationalInsightsWorkspace(ctx context.Context, client *armoperationalinsights.WorkspacesClient, resourceGroupName, workspaceName, location string) error { + // Check if workspace already exists + existingWorkspace, err := client.Get(ctx, resourceGroupName, workspaceName, nil) + if err == nil { + // Workspace exists, check its state + if existingWorkspace.Properties != nil && existingWorkspace.Properties.ProvisioningState != nil { + state := *existingWorkspace.Properties.ProvisioningState + if state == armoperationalinsights.WorkspaceEntityStatusSucceeded { + log.Printf("Workspace %s already exists with state %s, skipping creation", workspaceName, state) + return nil + } + log.Printf("Workspace %s exists but in state %s, will wait for it", workspaceName, state) + } else { + log.Printf("Workspace %s already exists, skipping creation", workspaceName) + return nil + } + } + + // Create the workspace + retentionDays := int32(30) + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroupName, workspaceName, armoperationalinsights.Workspace{ + Location: new(location), + Properties: &armoperationalinsights.WorkspaceProperties{ + RetentionInDays: &retentionDays, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "test": new("operational-insights-workspace"), + }, + }, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) { + // Check for authorization failure (resource provider not registered) + if respErr.StatusCode == http.StatusForbidden && respErr.ErrorCode == "AuthorizationFailed" { + return fmt.Errorf("%w: %s", errOperationalInsightsAuthorizationFailed, respErr.Error()) + } + // Check for missing resource provider registration + if strings.Contains(respErr.Error(), "register/action") { + return fmt.Errorf("%w: %s", errOperationalInsightsAuthorizationFailed, respErr.Error()) + } + // Check if workspace already exists (conflict) + if respErr.StatusCode == http.StatusConflict { + // Verify conflict is real before treating it as success. + if _, getErr := client.Get(ctx, resourceGroupName, workspaceName, nil); getErr == nil { + log.Printf("Workspace %s already exists (conflict), skipping", workspaceName) + return nil + } + return fmt.Errorf("workspace %s conflict but not retrievable: %w", workspaceName, err) + } + } + return fmt.Errorf("failed to begin creating workspace: %w", err) + } + + resp, err := poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create workspace: %w", err) + } + + // Verify the workspace was created successfully + if resp.Properties == nil || resp.Properties.ProvisioningState == nil { + return fmt.Errorf("workspace created but provisioning state is unknown") + } + + provisioningState := *resp.Properties.ProvisioningState + if provisioningState != armoperationalinsights.WorkspaceEntityStatusSucceeded { + return fmt.Errorf("workspace provisioning state is %s, expected Succeeded", provisioningState) + } + + log.Printf("Workspace %s created successfully with provisioning state: %s", workspaceName, provisioningState) + return nil +} + +// waitForOperationalInsightsWorkspaceAvailable polls until the workspace is available via the Get API +func waitForOperationalInsightsWorkspaceAvailable(ctx context.Context, client *armoperationalinsights.WorkspacesClient, resourceGroupName, workspaceName string) error { + maxAttempts := 20 + pollInterval := 5 * time.Second + maxNotFoundAttempts := 5 + notFoundCount := 0 + + log.Printf("Waiting for workspace %s to be available via API...", workspaceName) + + for attempt := 1; attempt <= maxAttempts; attempt++ { + resp, err := client.Get(ctx, resourceGroupName, workspaceName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + notFoundCount++ + if notFoundCount >= maxNotFoundAttempts { + return fmt.Errorf("workspace %s not found after %d attempts", workspaceName, notFoundCount) + } + log.Printf("Workspace %s not yet available (attempt %d/%d), waiting %v...", workspaceName, attempt, maxAttempts, pollInterval) + time.Sleep(pollInterval) + continue + } + return fmt.Errorf("error checking workspace availability: %w", err) + } + notFoundCount = 0 + + // Check provisioning state + if resp.Properties != nil && resp.Properties.ProvisioningState != nil { + state := *resp.Properties.ProvisioningState + if state == armoperationalinsights.WorkspaceEntityStatusSucceeded { + log.Printf("Workspace %s is available with provisioning state: %s", workspaceName, state) + return nil + } + if state == armoperationalinsights.WorkspaceEntityStatusFailed { + return fmt.Errorf("workspace provisioning failed with state: %s", state) + } + // Still provisioning, wait and retry + log.Printf("Workspace %s provisioning state: %s (attempt %d/%d), waiting...", workspaceName, state, attempt, maxAttempts) + time.Sleep(pollInterval) + continue + } + + // Workspace exists but no provisioning state - consider it available + log.Printf("Workspace %s is available", workspaceName) + return nil + } + + return fmt.Errorf("timeout waiting for workspace %s to be available after %d attempts", workspaceName, maxAttempts) +} + +// deleteOperationalInsightsWorkspace deletes an Azure Log Analytics workspace +func deleteOperationalInsightsWorkspace(ctx context.Context, client *armoperationalinsights.WorkspacesClient, resourceGroupName, workspaceName string) error { + poller, err := client.BeginDelete(ctx, resourceGroupName, workspaceName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Workspace %s not found, skipping deletion", workspaceName) + return nil + } + return fmt.Errorf("failed to begin deleting workspace: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete workspace: %w", err) + } + + log.Printf("Workspace %s deleted successfully", workspaceName) + return nil +} diff --git a/sources/azure/integration-tests/sql-server-failover-group_test.go b/sources/azure/integration-tests/sql-server-failover-group_test.go new file mode 100644 index 00000000..156b9e72 --- /dev/null +++ b/sources/azure/integration-tests/sql-server-failover-group_test.go @@ -0,0 +1,653 @@ +package integrationtests + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net/http" + "os" + "strings" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +const ( + integrationTestFailoverGroupName = "ovm-integ-test-failover-group" + integrationTestPrimaryServerName = "ovm-integ-test-primary-server" + integrationTestSecondaryServerName = "ovm-integ-test-secondary-server" + integrationTestPrimaryLocation = "westus2" + integrationTestSecondaryLocation = "eastus" + integrationTestFailoverGroupDBName = "ovm-integ-test-fg-database" +) + +func TestSQLServerFailoverGroupIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + // SQL server admin credentials are required for creating SQL servers + adminLogin := os.Getenv("AZURE_SQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_SQL_SERVER_ADMIN_PASSWORD") + if adminLogin == "" || adminPassword == "" { + t.Skip("AZURE_SQL_SERVER_ADMIN_LOGIN and AZURE_SQL_SERVER_ADMIN_PASSWORD environment variables must be set for SQL failover group integration tests") + } + + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + // Create Azure SDK clients + sqlServerClient, err := armsql.NewServersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create SQL Servers client: %v", err) + } + + sqlDatabaseClient, err := armsql.NewDatabasesClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create SQL Databases client: %v", err) + } + + sqlFailoverGroupClient, err := armsql.NewFailoverGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create SQL Failover Groups client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + // Generate unique names for SQL servers (must be globally unique) + primaryServerName := generateFailoverGroupServerName(integrationTestPrimaryServerName) + secondaryServerName := generateFailoverGroupServerName(integrationTestSecondaryServerName) + + var setupCompleted bool + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // Create primary SQL server + err = createFailoverGroupSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, primaryServerName, integrationTestPrimaryLocation) + if err != nil { + t.Fatalf("Failed to create primary SQL server: %v", err) + } + + // Wait for primary SQL server to be available + err = waitForFailoverGroupSQLServerAvailable(ctx, sqlServerClient, integrationTestResourceGroup, primaryServerName) + if err != nil { + t.Fatalf("Failed waiting for primary SQL server to be available: %v", err) + } + + // Create secondary SQL server (in a different region) + err = createFailoverGroupSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, secondaryServerName, integrationTestSecondaryLocation) + if err != nil { + t.Fatalf("Failed to create secondary SQL server: %v", err) + } + + // Wait for secondary SQL server to be available + err = waitForFailoverGroupSQLServerAvailable(ctx, sqlServerClient, integrationTestResourceGroup, secondaryServerName) + if err != nil { + t.Fatalf("Failed waiting for secondary SQL server to be available: %v", err) + } + + // Create a database on the primary server (failover groups need at least one database) + err = createFailoverGroupDatabase(ctx, sqlDatabaseClient, integrationTestResourceGroup, primaryServerName, integrationTestFailoverGroupDBName, integrationTestPrimaryLocation) + if err != nil { + t.Fatalf("Failed to create database: %v", err) + } + + // Wait for database to be available + err = waitForFailoverGroupDatabaseAvailable(ctx, sqlDatabaseClient, integrationTestResourceGroup, primaryServerName, integrationTestFailoverGroupDBName) + if err != nil { + t.Fatalf("Failed waiting for database to be available: %v", err) + } + + // Create the failover group + err = createFailoverGroup(ctx, sqlFailoverGroupClient, integrationTestResourceGroup, primaryServerName, secondaryServerName, integrationTestFailoverGroupName, subscriptionID) + if err != nil { + t.Fatalf("Failed to create failover group: %v", err) + } + + // Wait for the failover group to be available + err = waitForFailoverGroupAvailable(ctx, sqlFailoverGroupClient, integrationTestResourceGroup, primaryServerName, integrationTestFailoverGroupName) + if err != nil { + t.Fatalf("Failed waiting for failover group to be available: %v", err) + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetSQLServerFailoverGroup", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving failover group %s in SQL server %s, subscription %s, resource group %s", + integrationTestFailoverGroupName, primaryServerName, subscriptionID, integrationTestResourceGroup) + + wrapper := manual.NewSqlServerFailoverGroup( + clients.NewSqlFailoverGroupsClient(sqlFailoverGroupClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(primaryServerName, integrationTestFailoverGroupName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.SQLServerFailoverGroup.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerFailoverGroup, sdpItem.GetType()) + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + if uniqueAttrKey != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", uniqueAttrKey) + } + + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(primaryServerName, integrationTestFailoverGroupName) + if uniqueAttrValue != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, uniqueAttrValue) + } + + if sdpItem.GetScope() != fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) { + t.Errorf("Expected scope %s.%s, got %s", subscriptionID, integrationTestResourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Successfully retrieved failover group %s", integrationTestFailoverGroupName) + }) + + t.Run("SearchSQLServerFailoverGroups", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching failover groups in SQL server %s", primaryServerName) + + wrapper := manual.NewSqlServerFailoverGroup( + clients.NewSqlFailoverGroupsClient(sqlFailoverGroupClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, primaryServerName, true) + if err != nil { + t.Fatalf("Failed to search failover groups: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one failover group, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil { + expectedValue := shared.CompositeLookupKey(primaryServerName, integrationTestFailoverGroupName) + if v == expectedValue { + found = true + break + } + } + } + + if !found { + t.Fatalf("Expected to find failover group %s in the search results", integrationTestFailoverGroupName) + } + + log.Printf("Found %d failover groups in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for failover group %s", integrationTestFailoverGroupName) + + wrapper := manual.NewSqlServerFailoverGroup( + clients.NewSqlFailoverGroupsClient(sqlFailoverGroupClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(primaryServerName, integrationTestFailoverGroupName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify that linked items exist + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + var hasPrimaryServerLink bool + var hasPartnerServerLink bool + for _, liq := range linkedQueries { + query := liq.GetQuery() + if query.GetType() == "" { + t.Error("Found linked query with empty type") + } + if query.GetMethod() != sdp.QueryMethod_GET && query.GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Found linked query with invalid method: %s", query.GetMethod()) + } + if query.GetQuery() == "" { + t.Error("Found linked query with empty query") + } + if query.GetScope() == "" { + t.Error("Found linked query with empty scope") + } + + if query.GetType() == azureshared.SQLServer.String() { + if query.GetQuery() == primaryServerName { + hasPrimaryServerLink = true + } + if query.GetQuery() == secondaryServerName { + hasPartnerServerLink = true + } + } + } + + if !hasPrimaryServerLink { + t.Error("Expected linked query to primary SQL server, but didn't find one") + } + + if !hasPartnerServerLink { + t.Error("Expected linked query to partner (secondary) SQL server, but didn't find one") + } + + log.Printf("Verified %d linked item queries for failover group %s", len(linkedQueries), integrationTestFailoverGroupName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + wrapper := manual.NewSqlServerFailoverGroup( + clients.NewSqlFailoverGroupsClient(sqlFailoverGroupClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := wrapper.Scopes()[0] + + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(primaryServerName, integrationTestFailoverGroupName) + sdpItem, qErr := adapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.SQLServerFailoverGroup.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerFailoverGroup, sdpItem.GetType()) + } + + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Delete the failover group first + err := deleteFailoverGroup(ctx, sqlFailoverGroupClient, integrationTestResourceGroup, primaryServerName, integrationTestFailoverGroupName) + if err != nil { + t.Logf("Warning: Failed to delete failover group: %v", err) + } + + // Delete the database + err = deleteFailoverGroupDatabase(ctx, sqlDatabaseClient, integrationTestResourceGroup, primaryServerName, integrationTestFailoverGroupDBName) + if err != nil { + t.Logf("Warning: Failed to delete database: %v", err) + } + + // Delete secondary SQL server first (since failover group is deleted) + err = deleteFailoverGroupSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, secondaryServerName) + if err != nil { + t.Logf("Warning: Failed to delete secondary SQL server: %v", err) + } + + // Delete primary SQL server + err = deleteFailoverGroupSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, primaryServerName) + if err != nil { + t.Logf("Warning: Failed to delete primary SQL server: %v", err) + } + }) +} + +// generateFailoverGroupServerName generates a unique SQL server name for failover group tests +func generateFailoverGroupServerName(baseName string) string { + baseName = strings.ToLower(baseName) + baseName = strings.ReplaceAll(baseName, "_", "-") + baseName = strings.ReplaceAll(baseName, " ", "-") + + rng := rand.New(rand.NewSource(time.Now().UnixNano() + int64(os.Getpid()))) + suffix := rng.Intn(10000) + return fmt.Sprintf("%s-%04d", baseName, suffix) +} + +// createFailoverGroupSQLServer creates an Azure SQL server for failover group testing +func createFailoverGroupSQLServer(ctx context.Context, client *armsql.ServersClient, resourceGroup, serverName, location string) error { + _, err := client.Get(ctx, resourceGroup, serverName, nil) + if err == nil { + log.Printf("SQL server %s already exists, skipping creation", serverName) + return nil + } + + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + return fmt.Errorf("failed to check if SQL server exists: %w", err) + } + if respErr != nil && respErr.StatusCode != http.StatusNotFound { + return fmt.Errorf("failed to check if SQL server exists: %w", err) + } + + adminLogin := os.Getenv("AZURE_SQL_SERVER_ADMIN_LOGIN") + adminPassword := os.Getenv("AZURE_SQL_SERVER_ADMIN_PASSWORD") + + if adminLogin == "" || adminPassword == "" { + return fmt.Errorf("AZURE_SQL_SERVER_ADMIN_LOGIN and AZURE_SQL_SERVER_ADMIN_PASSWORD environment variables must be set for integration tests") + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroup, serverName, armsql.Server{ + Location: &location, + Properties: &armsql.ServerProperties{ + AdministratorLogin: &adminLogin, + AdministratorLoginPassword: &adminPassword, + Version: new("12.0"), + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "managed": new("true"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to start SQL server creation: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create SQL server: %w", err) + } + + log.Printf("SQL server %s created successfully in location %s", serverName, location) + return nil +} + +// waitForFailoverGroupSQLServerAvailable waits for a SQL server to be available +func waitForFailoverGroupSQLServerAvailable(ctx context.Context, client *armsql.ServersClient, resourceGroup, serverName string) error { + maxAttempts := 60 // Longer timeout for failover group tests + for range maxAttempts { + server, err := client.Get(ctx, resourceGroup, serverName, nil) + if err == nil { + if server.Properties != nil && server.Properties.State != nil && *server.Properties.State == "Ready" { + return nil + } + } + time.Sleep(5 * time.Second) + } + return fmt.Errorf("SQL server %s did not become available within expected time", serverName) +} + +// createFailoverGroupDatabase creates an Azure SQL database for failover group +func createFailoverGroupDatabase(ctx context.Context, client *armsql.DatabasesClient, resourceGroup, serverName, databaseName, location string) error { + _, err := client.Get(ctx, resourceGroup, serverName, databaseName, nil) + if err == nil { + log.Printf("SQL database %s already exists, skipping creation", databaseName) + return nil + } + + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + return fmt.Errorf("failed to check if SQL database exists: %w", err) + } + if respErr != nil && respErr.StatusCode != http.StatusNotFound { + return fmt.Errorf("failed to check if SQL database exists: %w", err) + } + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroup, serverName, databaseName, armsql.Database{ + Location: &location, + Properties: &armsql.DatabaseProperties{ + RequestedServiceObjectiveName: new("Basic"), + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "managed": new("true"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to start SQL database creation: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create SQL database: %w", err) + } + + log.Printf("SQL database %s created successfully in server %s", databaseName, serverName) + return nil +} + +// waitForFailoverGroupDatabaseAvailable waits for a SQL database to be available +func waitForFailoverGroupDatabaseAvailable(ctx context.Context, client *armsql.DatabasesClient, resourceGroup, serverName, databaseName string) error { + maxAttempts := 60 + for range maxAttempts { + database, err := client.Get(ctx, resourceGroup, serverName, databaseName, nil) + if err == nil { + if database.Properties != nil && database.Properties.Status != nil && *database.Properties.Status == armsql.DatabaseStatusOnline { + return nil + } + } + time.Sleep(5 * time.Second) + } + return fmt.Errorf("SQL database %s did not become available within expected time", databaseName) +} + +// createFailoverGroup creates an Azure SQL Failover Group +func createFailoverGroup(ctx context.Context, client *armsql.FailoverGroupsClient, resourceGroup, primaryServerName, secondaryServerName, failoverGroupName, subscriptionID string) error { + _, err := client.Get(ctx, resourceGroup, primaryServerName, failoverGroupName, nil) + if err == nil { + log.Printf("Failover group %s already exists, skipping creation", failoverGroupName) + return nil + } + + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + return fmt.Errorf("failed to check if failover group exists: %w", err) + } + if respErr != nil && respErr.StatusCode != http.StatusNotFound { + return fmt.Errorf("failed to check if failover group exists: %w", err) + } + + secondaryServerID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Sql/servers/%s", + subscriptionID, resourceGroup, secondaryServerName) + + poller, err := client.BeginCreateOrUpdate(ctx, resourceGroup, primaryServerName, failoverGroupName, armsql.FailoverGroup{ + Properties: &armsql.FailoverGroupProperties{ + PartnerServers: []*armsql.PartnerInfo{ + { + ID: &secondaryServerID, + }, + }, + ReadWriteEndpoint: &armsql.FailoverGroupReadWriteEndpoint{ + FailoverPolicy: new(armsql.ReadWriteEndpointFailoverPolicyAutomatic), + FailoverWithDataLossGracePeriodMinutes: new(int32(60)), + }, + ReadOnlyEndpoint: &armsql.FailoverGroupReadOnlyEndpoint{ + FailoverPolicy: new(armsql.ReadOnlyEndpointFailoverPolicyDisabled), + }, + Databases: []*string{}, + }, + Tags: map[string]*string{ + "purpose": new("overmind-integration-tests"), + "managed": new("true"), + }, + }, nil) + if err != nil { + return fmt.Errorf("failed to start failover group creation: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to create failover group: %w", err) + } + + log.Printf("Failover group %s created successfully", failoverGroupName) + return nil +} + +// waitForFailoverGroupAvailable waits for a failover group to be available +func waitForFailoverGroupAvailable(ctx context.Context, client *armsql.FailoverGroupsClient, resourceGroup, serverName, failoverGroupName string) error { + maxAttempts := 60 + for range maxAttempts { + fg, err := client.Get(ctx, resourceGroup, serverName, failoverGroupName, nil) + if err == nil { + // Replication state can be empty string (ready), "CATCH_UP", "PENDING", "SEEDING", "SUSPENDED" + if fg.Properties != nil && fg.Properties.ReplicationState != nil { + state := *fg.Properties.ReplicationState + if state == "" || state == "CATCH_UP" { + // Empty string or CATCH_UP indicates the failover group is functional + return nil + } + } else if fg.Properties != nil { + // ReplicationState is nil, check if properties exist (group created) + return nil + } + } + time.Sleep(5 * time.Second) + } + return fmt.Errorf("failover group %s did not become available within expected time", failoverGroupName) +} + +// deleteFailoverGroup deletes an Azure SQL Failover Group +func deleteFailoverGroup(ctx context.Context, client *armsql.FailoverGroupsClient, resourceGroup, serverName, failoverGroupName string) error { + _, err := client.Get(ctx, resourceGroup, serverName, failoverGroupName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("Failover group %s does not exist, skipping deletion", failoverGroupName) + return nil + } + return fmt.Errorf("failed to check if failover group exists: %w", err) + } + + poller, err := client.BeginDelete(ctx, resourceGroup, serverName, failoverGroupName, nil) + if err != nil { + return fmt.Errorf("failed to start failover group deletion: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete failover group: %w", err) + } + + log.Printf("Failover group %s deleted successfully", failoverGroupName) + return nil +} + +// deleteFailoverGroupDatabase deletes an Azure SQL database +func deleteFailoverGroupDatabase(ctx context.Context, client *armsql.DatabasesClient, resourceGroup, serverName, databaseName string) error { + _, err := client.Get(ctx, resourceGroup, serverName, databaseName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("SQL database %s does not exist, skipping deletion", databaseName) + return nil + } + return fmt.Errorf("failed to check if SQL database exists: %w", err) + } + + poller, err := client.BeginDelete(ctx, resourceGroup, serverName, databaseName, nil) + if err != nil { + return fmt.Errorf("failed to start SQL database deletion: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete SQL database: %w", err) + } + + log.Printf("SQL database %s deleted successfully", databaseName) + return nil +} + +// deleteFailoverGroupSQLServer deletes an Azure SQL server +func deleteFailoverGroupSQLServer(ctx context.Context, client *armsql.ServersClient, resourceGroup, serverName string) error { + _, err := client.Get(ctx, resourceGroup, serverName, nil) + if err != nil { + var respErr *azcore.ResponseError + if errors.As(err, &respErr) && respErr.StatusCode == http.StatusNotFound { + log.Printf("SQL server %s does not exist, skipping deletion", serverName) + return nil + } + return fmt.Errorf("failed to check if SQL server exists: %w", err) + } + + poller, err := client.BeginDelete(ctx, resourceGroup, serverName, nil) + if err != nil { + return fmt.Errorf("failed to start SQL server deletion: %w", err) + } + + _, err = poller.PollUntilDone(ctx, nil) + if err != nil { + return fmt.Errorf("failed to delete SQL server: %w", err) + } + + log.Printf("SQL server %s deleted successfully", serverName) + return nil +} diff --git a/sources/azure/integration-tests/sql-server-key_test.go b/sources/azure/integration-tests/sql-server-key_test.go new file mode 100644 index 00000000..26fa7ace --- /dev/null +++ b/sources/azure/integration-tests/sql-server-key_test.go @@ -0,0 +1,341 @@ +package integrationtests + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + log "github.com/sirupsen/logrus" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +// findExistingSQLServer searches for an existing SQL server in the resource group +// Returns the server name if found, empty string otherwise +func findExistingSQLServer(ctx context.Context, client *armsql.ServersClient, resourceGroup string) string { + pager := client.NewListByResourceGroupPager(resourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + log.Printf("Failed to list SQL servers: %v", err) + return "" + } + for _, server := range page.Value { + if server.Name != nil && *server.Name != "" { + log.Printf("Found existing SQL server: %s", *server.Name) + return *server.Name + } + } + } + return "" +} + +func TestSQLServerKeyIntegration(t *testing.T) { + subscriptionID := os.Getenv("AZURE_SUBSCRIPTION_ID") + if subscriptionID == "" { + t.Skip("AZURE_SUBSCRIPTION_ID environment variable not set") + } + + // Initialize Azure credentials using DefaultAzureCredential + cred, err := azureshared.NewAzureCredential(t.Context()) + if err != nil { + t.Fatalf("Failed to create Azure credential: %v", err) + } + + // Create Azure SDK clients + sqlServerClient, err := armsql.NewServersClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create SQL Servers client: %v", err) + } + + serverKeysClient, err := armsql.NewServerKeysClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create SQL Server Keys client: %v", err) + } + + rgClient, err := armresources.NewResourceGroupsClient(subscriptionID, cred, nil) + if err != nil { + t.Fatalf("Failed to create Resource Groups client: %v", err) + } + + // Track setup completion for skipping Run if Setup fails + setupCompleted := false + + // Track if we created the server (for cleanup) + serverCreated := false + + // SQL server name - will be set in Setup + var sqlServerName string + + // The ServiceManaged key name is always "ServiceManaged" + const serviceManagedKeyName = "ServiceManaged" + + t.Run("Setup", func(t *testing.T) { + ctx := t.Context() + + // Create resource group if it doesn't exist + err := createResourceGroup(ctx, rgClient, integrationTestResourceGroup, integrationTestLocation) + if err != nil { + t.Fatalf("Failed to create resource group: %v", err) + } + + // First, try to find an existing SQL server to reuse + // This helps when admin credentials are not available + sqlServerName = findExistingSQLServer(ctx, sqlServerClient, integrationTestResourceGroup) + + if sqlServerName == "" { + // No existing server found, try to create one + sqlServerName = generateSQLServerName(integrationTestSQLServerName) + err = createSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, sqlServerName, integrationTestLocation) + if err != nil { + t.Skipf("Skipping test: Failed to create SQL server (admin credentials may be missing): %v", err) + } + serverCreated = true + + // Wait for SQL server to be available + err = waitForSQLServerAvailable(ctx, sqlServerClient, integrationTestResourceGroup, sqlServerName) + if err != nil { + t.Fatalf("Failed waiting for SQL server to be available: %v", err) + } + } + + setupCompleted = true + }) + + t.Run("Run", func(t *testing.T) { + if !setupCompleted { + t.Skip("Skipping Run: Setup did not complete successfully") + } + + t.Run("GetSQLServerKey", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Retrieving SQL server key %s for server %s in subscription %s, resource group %s", + serviceManagedKeyName, sqlServerName, subscriptionID, integrationTestResourceGroup) + + serverKeyWrapper := manual.NewSqlServerKey( + clients.NewSqlServerKeysClient(serverKeysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := serverKeyWrapper.Scopes()[0] + + serverKeyAdapter := sources.WrapperToAdapter(serverKeyWrapper, sdpcache.NewNoOpCache()) + // Get requires serverName and keyName as query parts + query := shared.CompositeLookupKey(sqlServerName, serviceManagedKeyName) + sdpItem, qErr := serverKeyAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem == nil { + t.Fatalf("Expected sdpItem to be non-nil") + } + + if sdpItem.GetType() != azureshared.SQLServerKey.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerKey, sdpItem.GetType()) + } + + uniqueAttrKey := sdpItem.GetUniqueAttribute() + if uniqueAttrKey != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", uniqueAttrKey) + } + + uniqueAttrValue, err := sdpItem.GetAttributes().Get(uniqueAttrKey) + if err != nil { + t.Fatalf("Failed to get unique attribute: %v", err) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(sqlServerName, serviceManagedKeyName) + if uniqueAttrValue != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, uniqueAttrValue) + } + + if sdpItem.GetScope() != fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) { + t.Errorf("Expected scope %s.%s, got %s", subscriptionID, integrationTestResourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Successfully retrieved SQL server key %s", serviceManagedKeyName) + }) + + t.Run("SearchSQLServerKeys", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Searching SQL server keys for server %s", sqlServerName) + + serverKeyWrapper := manual.NewSqlServerKey( + clients.NewSqlServerKeysClient(serverKeysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := serverKeyWrapper.Scopes()[0] + + serverKeyAdapter := sources.WrapperToAdapter(serverKeyWrapper, sdpcache.NewNoOpCache()) + + // Check if adapter supports search + searchable, ok := serverKeyAdapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, scope, sqlServerName, true) + if err != nil { + t.Fatalf("Failed to search SQL server keys: %v", err) + } + + if len(sdpItems) < 1 { + t.Fatalf("Expected at least one SQL server key, got %d", len(sdpItems)) + } + + var found bool + for _, item := range sdpItems { + uniqueAttrKey := item.GetUniqueAttribute() + if v, err := item.GetAttributes().Get(uniqueAttrKey); err == nil { + expectedValue := shared.CompositeLookupKey(sqlServerName, serviceManagedKeyName) + if v == expectedValue { + found = true + break + } + } + } + + if !found { + t.Fatalf("Expected to find key %s in the search results", serviceManagedKeyName) + } + + log.Printf("Found %d SQL server keys in search results", len(sdpItems)) + }) + + t.Run("VerifyLinkedItems", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying linked items for SQL server key %s", serviceManagedKeyName) + + serverKeyWrapper := manual.NewSqlServerKey( + clients.NewSqlServerKeysClient(serverKeysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := serverKeyWrapper.Scopes()[0] + + serverKeyAdapter := sources.WrapperToAdapter(serverKeyWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(sqlServerName, serviceManagedKeyName) + sdpItem, qErr := serverKeyAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify that linked items exist (SQL server should be linked as parent) + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) == 0 { + t.Fatalf("Expected linked item queries, but got none") + } + + // Verify each linked item query has required fields + for _, liq := range linkedQueries { + if liq.GetQuery().GetType() == "" { + t.Error("Linked item query has empty Type") + } + if liq.GetQuery().GetMethod() != sdp.QueryMethod_GET && liq.GetQuery().GetMethod() != sdp.QueryMethod_SEARCH { + t.Errorf("Linked item query has invalid Method: %v", liq.GetQuery().GetMethod()) + } + if liq.GetQuery().GetQuery() == "" { + t.Error("Linked item query has empty Query") + } + if liq.GetQuery().GetScope() == "" { + t.Error("Linked item query has empty Scope") + } + } + + // Verify parent SQL Server link exists + var hasSQLServerLink bool + for _, liq := range linkedQueries { + if liq.GetQuery().GetType() == azureshared.SQLServer.String() { + hasSQLServerLink = true + if liq.GetQuery().GetQuery() != sqlServerName { + t.Errorf("Expected linked query to SQL server %s, got %s", sqlServerName, liq.GetQuery().GetQuery()) + } + if liq.GetQuery().GetMethod() != sdp.QueryMethod_GET { + t.Errorf("Expected linked query method GET for SQL server, got %v", liq.GetQuery().GetMethod()) + } + break + } + } + + if !hasSQLServerLink { + t.Error("Expected linked query to parent SQL server, but didn't find one") + } + + log.Printf("Verified %d linked item queries for SQL server key %s", len(linkedQueries), serviceManagedKeyName) + }) + + t.Run("VerifyItemAttributes", func(t *testing.T) { + ctx := t.Context() + + log.Printf("Verifying item attributes for SQL server key %s", serviceManagedKeyName) + + serverKeyWrapper := manual.NewSqlServerKey( + clients.NewSqlServerKeysClient(serverKeysClient), + []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, integrationTestResourceGroup)}, + ) + scope := serverKeyWrapper.Scopes()[0] + + serverKeyAdapter := sources.WrapperToAdapter(serverKeyWrapper, sdpcache.NewNoOpCache()) + query := shared.CompositeLookupKey(sqlServerName, serviceManagedKeyName) + sdpItem, qErr := serverKeyAdapter.Get(ctx, scope, query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify GetType returns the expected item type + if sdpItem.GetType() != azureshared.SQLServerKey.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerKey, sdpItem.GetType()) + } + + // Verify GetScope returns the expected scope + expectedScope := fmt.Sprintf("%s.%s", subscriptionID, integrationTestResourceGroup) + if sdpItem.GetScope() != expectedScope { + t.Errorf("Expected scope %s, got %s", expectedScope, sdpItem.GetScope()) + } + + // Verify GetUniqueAttribute returns the correct attribute + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + // Verify Validate passes + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Item validation failed: %v", err) + } + + log.Printf("Verified item attributes for SQL server key %s", serviceManagedKeyName) + }) + }) + + t.Run("Teardown", func(t *testing.T) { + ctx := t.Context() + + // Only delete the SQL server if we created it + if serverCreated && sqlServerName != "" { + err := deleteSQLServer(ctx, sqlServerClient, integrationTestResourceGroup, sqlServerName) + if err != nil { + t.Fatalf("Failed to delete SQL server: %v", err) + } + } else { + log.Printf("Skipping SQL server deletion (using pre-existing server)") + } + + // We don't delete the resource group to allow faster subsequent test runs + }) +} diff --git a/sources/azure/manual/adapters.go b/sources/azure/manual/adapters.go index d3f8ff7e..889e647d 100644 --- a/sources/azure/manual/adapters.go +++ b/sources/azure/manual/adapters.go @@ -14,6 +14,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/keyvault/armkeyvault/v2" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns" "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources/v2" @@ -138,6 +139,11 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create network interfaces client: %w", err) } + interfaceIPConfigurationsClient, err := armnetwork.NewInterfaceIPConfigurationsClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create interface IP configurations client: %w", err) + } + sqlDatabasesClient, err := armsql.NewDatabasesClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create sql databases client: %w", err) @@ -208,6 +214,11 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create private endpoints client: %w", err) } + privateLinkServicesClient, err := armnetwork.NewPrivateLinkServicesClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create private link services client: %w", err) + } + batchAccountsClient, err := armbatch.NewAccountClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create batch accounts client: %w", err) @@ -228,6 +239,11 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create batch application package client: %w", err) } + batchPrivateEndpointConnectionClient, err := armbatch.NewPrivateEndpointConnectionClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create batch private endpoint connection client: %w", err) + } + virtualMachineScaleSetsClient, err := armcompute.NewVirtualMachineScaleSetsClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create virtual machine scale sets client: %w", err) @@ -277,11 +293,21 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create application security groups client: %w", err) } + ipGroupsClient, err := armnetwork.NewIPGroupsClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create IP groups client: %w", err) + } + virtualNetworkGatewaysClient, err := armnetwork.NewVirtualNetworkGatewaysClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create virtual network gateways client: %w", err) } + localNetworkGatewaysClient, err := armnetwork.NewLocalNetworkGatewaysClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create local network gateways client: %w", err) + } + natGatewaysClient, err := armnetwork.NewNatGatewaysClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create nat gateways client: %w", err) @@ -292,6 +318,11 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create flow logs client: %w", err) } + networkWatchersClient, err := armnetwork.NewWatchersClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create network watchers client: %w", err) + } + managedHSMsClient, err := armkeyvault.NewManagedHsmsClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create managed hsms client: %w", err) @@ -327,6 +358,16 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create sql private endpoint connections client: %w", err) } + sqlFailoverGroupsClient, err := armsql.NewFailoverGroupsClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create sql failover groups client: %w", err) + } + + sqlServerKeysClient, err := armsql.NewServerKeysClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create sql server keys client: %w", err) + } + postgresqlFlexibleServersClient, err := armpostgresqlflexibleservers.NewServersClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create postgresql flexible servers client: %w", err) @@ -357,6 +398,16 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create postgresql flexible server configurations client: %w", err) } + postgresqlVirtualEndpointsClient, err := armpostgresqlflexibleservers.NewVirtualEndpointsClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create postgresql flexible server virtual endpoints client: %w", err) + } + + postgresqlAdministratorsClient, err := armpostgresqlflexibleservers.NewAdministratorsMicrosoftEntraClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create postgresql flexible server administrators client: %w", err) + } + secretsClient, err := armkeyvault.NewSecretsClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create secrets client: %w", err) @@ -372,11 +423,21 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create user assigned identities client: %w", err) } + federatedIdentityCredentialsClient, err := armmsi.NewFederatedIdentityCredentialsClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create federated identity credentials client: %w", err) + } + roleAssignmentsClient, err := armauthorization.NewRoleAssignmentsClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create role assignments client: %w", err) } + roleDefinitionsClient, err := armauthorization.NewRoleDefinitionsClient(cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create role definitions client: %w", err) + } + diskEncryptionSetsClient, err := armcompute.NewDiskEncryptionSetsClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create disk encryption sets client: %w", err) @@ -482,11 +543,21 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred return nil, fmt.Errorf("failed to create elastic san volume groups client: %w", err) } + elasticSanVolumesClient, err := armelasticsan.NewVolumesClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create elastic san volumes client: %w", err) + } + sharedGalleryImagesClient, err := armcompute.NewSharedGalleryImagesClient(subscriptionID, cred, nil) if err != nil { return nil, fmt.Errorf("failed to create shared gallery images client: %w", err) } + operationalInsightsWorkspacesClient, err := armoperationalinsights.NewWorkspacesClient(subscriptionID, cred, nil) + if err != nil { + return nil, fmt.Errorf("failed to create operational insights workspaces client: %w", err) + } + // Multi-scope resource group adapters (one adapter per type handling all resource groups) if len(resourceGroupScopes) > 0 { adapters = append(adapters, @@ -538,6 +609,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewNetworkInterfacesClient(networkInterfacesClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewNetworkNetworkInterfaceIPConfiguration( + clients.NewInterfaceIPConfigurationsClient(interfaceIPConfigurationsClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewSqlDatabase( clients.NewSqlDatabasesClient(sqlDatabasesClient), resourceGroupScopes, @@ -562,6 +637,14 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewSQLServerPrivateEndpointConnectionsClient(sqlPrivateEndpointConnectionsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewSqlServerFailoverGroup( + clients.NewSqlFailoverGroupsClient(sqlFailoverGroupsClient), + resourceGroupScopes, + ), cache), + sources.WrapperToAdapter(NewSqlServerKey( + clients.NewSqlServerKeysClient(sqlServerKeysClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewDocumentDBDatabaseAccounts( clients.NewDocumentDBDatabaseAccountsClient(documentDBDatabaseAccountsClient), resourceGroupScopes, @@ -618,6 +701,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewPrivateEndpointsClient(privateEndpointsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewNetworkPrivateLinkService( + clients.NewPrivateLinkServicesClient(privateLinkServicesClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewNetworkZone( clients.NewZonesClient(zonesClient), resourceGroupScopes, @@ -650,6 +737,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewBatchApplicationPackagesClient(batchApplicationPackageClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewBatchPrivateEndpointConnection( + clients.NewBatchPrivateEndpointConnectionClient(batchPrivateEndpointConnectionClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewComputeVirtualMachineScaleSet( clients.NewVirtualMachineScaleSetsClient(virtualMachineScaleSetsClient), resourceGroupScopes, @@ -670,6 +761,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewApplicationSecurityGroupsClient(applicationSecurityGroupsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewNetworkIPGroup( + clients.NewIPGroupsClient(ipGroupsClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewNetworkRouteTable( clients.NewRouteTablesClient(routeTablesClient), resourceGroupScopes, @@ -694,6 +789,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewVirtualNetworkGatewaysClient(virtualNetworkGatewaysClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewNetworkLocalNetworkGateway( + clients.NewLocalNetworkGatewaysClient(localNetworkGatewaysClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewNetworkNatGateway( clients.NewNatGatewaysClient(natGatewaysClient), resourceGroupScopes, @@ -702,6 +801,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewFlowLogsClient(flowLogsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewNetworkNetworkWatcher( + clients.NewNetworkWatchersClient(networkWatchersClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewSqlServer( clients.NewSqlServersClient(sqlServersClient), resourceGroupScopes, @@ -730,6 +833,14 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewPostgreSQLConfigurationsClient(postgresqlConfigurationsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerVirtualEndpoint( + clients.NewDBforPostgreSQLFlexibleServerVirtualEndpointClient(postgresqlVirtualEndpointsClient), + resourceGroupScopes, + ), cache), + sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerAdministrator( + clients.NewDBforPostgreSQLFlexibleServerAdministratorClient(postgresqlAdministratorsClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewKeyVaultSecret( clients.NewSecretsClient(secretsClient), resourceGroupScopes, @@ -742,6 +853,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewUserAssignedIdentitiesClient(userAssignedIdentitiesClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewManagedIdentityFederatedIdentityCredential( + clients.NewFederatedIdentityCredentialsClient(federatedIdentityCredentialsClient), + resourceGroupScopes, + ), cache), sources.WrapperToAdapter(NewAuthorizationRoleAssignment( clients.NewRoleAssignmentsClient(roleAssignmentsClient), resourceGroupScopes, @@ -822,6 +937,14 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewElasticSanVolumeGroupClient(elasticSanVolumeGroupsClient), resourceGroupScopes, ), cache), + sources.WrapperToAdapter(NewElasticSanVolume( + clients.NewElasticSanVolumeClient(elasticSanVolumesClient), + resourceGroupScopes, + ), cache), + sources.WrapperToAdapter(NewOperationalInsightsWorkspace( + clients.NewOperationalInsightsWorkspaceClient(operationalInsightsWorkspacesClient), + resourceGroupScopes, + ), cache), ) } @@ -831,6 +954,10 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred clients.NewSharedGalleryImagesClient(sharedGalleryImagesClient), subscriptionID, ), cache), + sources.WrapperToAdapter(NewAuthorizationRoleDefinition( + clients.NewRoleDefinitionsClient(roleDefinitionsClient), + subscriptionID, + ), cache), ) log.WithFields(log.Fields{ @@ -857,11 +984,14 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred sources.WrapperToAdapter(NewNetworkSubnet(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkVirtualNetworkPeering(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkNetworkInterface(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkNetworkInterfaceIPConfiguration(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSqlDatabase(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSqlDatabaseSchema(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSqlServerFirewallRule(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSqlServerVirtualNetworkRule(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSQLServerPrivateEndpointConnection(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewSqlServerFailoverGroup(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewSqlServerKey(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDocumentDBDatabaseAccounts(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDocumentDBPrivateEndpointConnection(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewKeyVaultVault(nil, placeholderResourceGroupScopes), noOpCache), @@ -883,18 +1013,22 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred sources.WrapperToAdapter(NewBatchBatchApplication(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewBatchBatchPool(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewBatchBatchApplicationPackage(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewBatchPrivateEndpointConnection(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeVirtualMachineScaleSet(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeAvailabilitySet(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeDisk(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkNetworkSecurityGroup(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkApplicationSecurityGroup(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkIPGroup(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkSecurityRule(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkDefaultSecurityRule(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkRouteTable(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkApplicationGateway(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkVirtualNetworkGateway(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkLocalNetworkGateway(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkNatGateway(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewNetworkFlowLog(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkNetworkWatcher(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewSqlServer(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServer(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerFirewallRule(nil, placeholderResourceGroupScopes), noOpCache), @@ -902,9 +1036,12 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerBackup(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerReplica(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerConfiguration(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerVirtualEndpoint(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewDBforPostgreSQLFlexibleServerAdministrator(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewKeyVaultSecret(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewKeyVaultKey(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewManagedIdentityUserAssignedIdentity(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewManagedIdentityFederatedIdentityCredential(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewAuthorizationRoleAssignment(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeDiskEncryptionSet(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeImage(nil, placeholderResourceGroupScopes), noOpCache), @@ -925,8 +1062,12 @@ func Adapters(ctx context.Context, subscriptionID string, regions []string, cred sources.WrapperToAdapter(NewElasticSan(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewElasticSanVolumeSnapshot(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewElasticSanVolumeGroup(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewElasticSanVolume(nil, placeholderResourceGroupScopes), noOpCache), sources.WrapperToAdapter(NewComputeSharedGalleryImage(nil, subscriptionID), noOpCache), + sources.WrapperToAdapter(NewAuthorizationRoleDefinition(nil, subscriptionID), noOpCache), sources.WrapperToAdapter(NewNetworkPrivateEndpoint(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewNetworkPrivateLinkService(nil, placeholderResourceGroupScopes), noOpCache), + sources.WrapperToAdapter(NewOperationalInsightsWorkspace(nil, placeholderResourceGroupScopes), noOpCache), ) _ = regions diff --git a/sources/azure/manual/authorization-role-definition.go b/sources/azure/manual/authorization-role-definition.go new file mode 100644 index 00000000..c678e64e --- /dev/null +++ b/sources/azure/manual/authorization-role-definition.go @@ -0,0 +1,203 @@ +package manual + +import ( + "context" + "errors" + "fmt" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var AuthorizationRoleDefinitionLookupByName = shared.NewItemTypeLookup("name", azureshared.AuthorizationRoleDefinition) + +type authorizationRoleDefinitionWrapper struct { + client clients.RoleDefinitionsClient + + *azureshared.SubscriptionBase +} + +func NewAuthorizationRoleDefinition(client clients.RoleDefinitionsClient, subscriptionID string) sources.ListableWrapper { + return &authorizationRoleDefinitionWrapper{ + client: client, + SubscriptionBase: azureshared.NewSubscriptionBase( + subscriptionID, + sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, + azureshared.AuthorizationRoleDefinition, + ), + } +} + +// List retrieves all role definitions within the subscription scope. +// ref: https://learn.microsoft.com/en-us/rest/api/authorization/role-definitions/list +func (c authorizationRoleDefinitionWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + if scope == "" { + return nil, azureshared.QueryError(errors.New("scope cannot be empty"), scope, c.Type()) + } + + azureScope := fmt.Sprintf("/subscriptions/%s", c.SubscriptionID()) + pager := c.client.NewListPager(azureScope, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, roleDefinition := range page.Value { + if roleDefinition == nil || roleDefinition.Name == nil { + continue + } + item, sdpErr := c.azureRoleDefinitionToSDPItem(roleDefinition, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +// ListStream streams all role definitions within the subscription scope. +func (c authorizationRoleDefinitionWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + if scope == "" { + stream.SendError(azureshared.QueryError(errors.New("scope cannot be empty"), scope, c.Type())) + return + } + + azureScope := fmt.Sprintf("/subscriptions/%s", c.SubscriptionID()) + pager := c.client.NewListPager(azureScope, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + for _, roleDefinition := range page.Value { + if roleDefinition == nil || roleDefinition.Name == nil { + continue + } + item, sdpErr := c.azureRoleDefinitionToSDPItem(roleDefinition, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +// Get retrieves a role definition by its ID (GUID). +// ref: https://learn.microsoft.com/en-us/rest/api/authorization/role-definitions/get +func (c authorizationRoleDefinitionWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if scope == "" { + return nil, azureshared.QueryError(errors.New("scope cannot be empty"), scope, c.Type()) + } + if len(queryParts) != 1 { + return nil, azureshared.QueryError(errors.New("Get requires 1 query part: roleDefinitionID"), scope, c.Type()) + } + + roleDefinitionID := queryParts[0] + if roleDefinitionID == "" { + return nil, azureshared.QueryError(errors.New("roleDefinitionID cannot be empty"), scope, c.Type()) + } + + azureScope := fmt.Sprintf("/subscriptions/%s", c.SubscriptionID()) + resp, err := c.client.Get(ctx, azureScope, roleDefinitionID, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + return c.azureRoleDefinitionToSDPItem(&resp.RoleDefinition, scope) +} + +func (c authorizationRoleDefinitionWrapper) azureRoleDefinitionToSDPItem(roleDefinition *armauthorization.RoleDefinition, scope string) (*sdp.Item, *sdp.QueryError) { + if roleDefinition.Name == nil { + return nil, azureshared.QueryError(errors.New("role definition name is nil"), scope, c.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(roleDefinition) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.AuthorizationRoleDefinition.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + } + + // Link to AssignableScopes (subscriptions and resource groups) + if roleDefinition.Properties != nil && roleDefinition.Properties.AssignableScopes != nil { + for _, assignableScope := range roleDefinition.Properties.AssignableScopes { + if assignableScope == nil || *assignableScope == "" { + continue + } + scopePath := *assignableScope + + // Determine if this is a subscription or resource group scope + // Format: /subscriptions/{subscriptionId} or /subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName} + if rgName := azureshared.ExtractResourceGroupFromResourceID(scopePath); rgName != "" { + // Resource group scope + subscriptionID := azureshared.ExtractSubscriptionIDFromResourceID(scopePath) + if subscriptionID != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ResourcesResourceGroup.String(), + Method: sdp.QueryMethod_GET, + Query: rgName, + Scope: subscriptionID, + }, + }) + } + } else if subscriptionID := azureshared.ExtractSubscriptionIDFromResourceID(scopePath); subscriptionID != "" { + // Subscription scope only + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ResourcesSubscription.String(), + Method: sdp.QueryMethod_GET, + Query: subscriptionID, + Scope: "global", + }, + }) + } + } + } + + return sdpItem, nil +} + +func (c authorizationRoleDefinitionWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + AuthorizationRoleDefinitionLookupByName, + } +} + +// PotentialLinks returns all resource types this adapter can link to. +func (c authorizationRoleDefinitionWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.ResourcesSubscription, + azureshared.ResourcesResourceGroup, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/permissions/management-and-governance#microsoftauthorization +func (c authorizationRoleDefinitionWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Authorization/roleDefinitions/read", + } +} + +func (c authorizationRoleDefinitionWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/authorization-role-definition_test.go b/sources/azure/manual/authorization-role-definition_test.go new file mode 100644 index 00000000..93cfa318 --- /dev/null +++ b/sources/azure/manual/authorization-role-definition_test.go @@ -0,0 +1,478 @@ +package manual_test + +import ( + "context" + "errors" + "reflect" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +func TestAuthorizationRoleDefinition(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + scope := subscriptionID + azureScope := "/subscriptions/" + subscriptionID + + t.Run("Get", func(t *testing.T) { + roleDefinitionID := "b24988ac-6180-42a0-ab88-20f7382dd24c" + roleDefinition := createAzureRoleDefinition(roleDefinitionID, "Reader") + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockClient.EXPECT().Get(ctx, azureScope, roleDefinitionID, nil).Return( + armauthorization.RoleDefinitionsClientGetResponse{ + RoleDefinition: *roleDefinition, + }, nil) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, roleDefinitionID, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.AuthorizationRoleDefinition.String() { + t.Errorf("Expected type %s, got %s", azureshared.AuthorizationRoleDefinition.String(), sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != roleDefinitionID { + t.Errorf("Expected unique attribute value %s, got %s", roleDefinitionID, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != scope { + t.Errorf("Expected scope %s, got %s", scope, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + // Verify linked item queries for AssignableScopes + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // Subscription scope link + ExpectedType: azureshared.ResourcesSubscription.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: subscriptionID, + ExpectedScope: "global", + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_EmptyScope", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, "", "test-role-definition", true) + if qErr == nil { + t.Error("Expected error when getting role definition with empty scope, but got nil") + } + }) + + t.Run("Get_EmptyRoleDefinitionID", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, scope, "", true) + if qErr == nil { + t.Error("Expected error when getting role definition with empty ID, but got nil") + } + }) + + t.Run("Get_ClientError", func(t *testing.T) { + roleDefinitionID := "test-role-definition" + expectedError := errors.New("client error") + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockClient.EXPECT().Get(ctx, azureScope, roleDefinitionID, nil).Return( + armauthorization.RoleDefinitionsClientGetResponse{}, + expectedError) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, scope, roleDefinitionID, true) + if qErr == nil { + t.Error("Expected error when client returns error, but got nil") + } + }) + + t.Run("Get_NilName", func(t *testing.T) { + roleDefinition := &armauthorization.RoleDefinition{ + Name: nil, + Properties: &armauthorization.RoleDefinitionProperties{ + RoleName: new("Reader"), + }, + } + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + roleDefinitionID := "test-role-definition" + mockClient.EXPECT().Get(ctx, azureScope, roleDefinitionID, nil).Return( + armauthorization.RoleDefinitionsClientGetResponse{ + RoleDefinition: *roleDefinition, + }, nil) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, scope, roleDefinitionID, true) + if qErr == nil { + t.Error("Expected error when role definition has nil name, but got nil") + } + }) + + t.Run("List", func(t *testing.T) { + roleDefinition1 := createAzureRoleDefinition("guid-1", "Reader") + roleDefinition2 := createAzureRoleDefinition("guid-2", "Contributor") + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockPager := NewMockRoleDefinitionsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armauthorization.RoleDefinitionsClientListResponse{ + RoleDefinitionListResult: armauthorization.RoleDefinitionListResult{ + Value: []*armauthorization.RoleDefinition{roleDefinition1, roleDefinition2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().NewListPager(azureScope, nil).Return(mockPager) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + + if item.GetType() != azureshared.AuthorizationRoleDefinition.String() { + t.Fatalf("Expected type %s, got: %s", azureshared.AuthorizationRoleDefinition.String(), item.GetType()) + } + } + }) + + t.Run("List_EmptyScope", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + _, err := listable.List(ctx, "", true) + if err == nil { + t.Error("Expected error when listing role definitions with empty scope, but got nil") + } + }) + + t.Run("List_PagerError", func(t *testing.T) { + expectedError := errors.New("pager error") + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockPager := NewMockRoleDefinitionsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armauthorization.RoleDefinitionsClientListResponse{}, + expectedError), + ) + + mockClient.EXPECT().NewListPager(azureScope, nil).Return(mockPager) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + _, err := listable.List(ctx, scope, true) + if err == nil { + t.Error("Expected error when pager returns error, but got nil") + } + }) + + t.Run("List_WithNilName", func(t *testing.T) { + roleDefinition1 := createAzureRoleDefinition("guid-1", "Reader") + roleDefinition2 := &armauthorization.RoleDefinition{ + Name: nil, + Properties: &armauthorization.RoleDefinitionProperties{ + RoleName: new("Contributor"), + }, + } + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockPager := NewMockRoleDefinitionsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armauthorization.RoleDefinitionsClientListResponse{ + RoleDefinitionListResult: armauthorization.RoleDefinitionListResult{ + Value: []*armauthorization.RoleDefinition{roleDefinition1, roleDefinition2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().NewListPager(azureScope, nil).Return(mockPager) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Should skip nil name items + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name should be skipped), got: %d", len(sdpItems)) + } + }) + + t.Run("ListStream", func(t *testing.T) { + roleDefinition1 := createAzureRoleDefinition("guid-1", "Reader") + roleDefinition2 := createAzureRoleDefinition("guid-2", "Contributor") + + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + mockPager := NewMockRoleDefinitionsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armauthorization.RoleDefinitionsClientListResponse{ + RoleDefinitionListResult: armauthorization.RoleDefinitionListResult{ + Value: []*armauthorization.RoleDefinition{roleDefinition1, roleDefinition2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().NewListPager(azureScope, nil).Return(mockPager) + + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listStreamable, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + listStreamable.ListStream(ctx, scope, true, stream) + wg.Wait() + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %d", len(errs)) + } + }) + + t.Run("GetLookups", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + + lookups := wrapper.GetLookups() + if len(lookups) != 1 { + t.Errorf("Expected 1 lookup, got: %d", len(lookups)) + } + + foundLookup := false + for _, lookup := range lookups { + if lookup.ItemType == azureshared.AuthorizationRoleDefinition { + foundLookup = true + break + } + } + if !foundLookup { + t.Error("Expected GetLookups to include AuthorizationRoleDefinition") + } + }) + + t.Run("PotentialLinks", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + + potentialLinks := wrapper.PotentialLinks() + if len(potentialLinks) != 2 { + t.Errorf("Expected 2 potential links, got: %d", len(potentialLinks)) + } + if !potentialLinks[azureshared.ResourcesSubscription] { + t.Error("Expected PotentialLinks to include ResourcesSubscription") + } + if !potentialLinks[azureshared.ResourcesResourceGroup] { + t.Error("Expected PotentialLinks to include ResourcesResourceGroup") + } + }) + + t.Run("IAMPermissions", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + + permissions := wrapper.IAMPermissions() + if len(permissions) != 1 { + t.Errorf("Expected 1 permission, got: %d", len(permissions)) + } + + expectedPermission := "Microsoft.Authorization/roleDefinitions/read" + if permissions[0] != expectedPermission { + t.Errorf("Expected permission %s, got: %s", expectedPermission, permissions[0]) + } + }) + + t.Run("PredefinedRole", func(t *testing.T) { + mockClient := mocks.NewMockRoleDefinitionsClient(ctrl) + wrapper := manual.NewAuthorizationRoleDefinition(mockClient, subscriptionID) + + if roleInterface, ok := any(wrapper).(interface{ PredefinedRole() string }); ok { + role := roleInterface.PredefinedRole() + if role != "Reader" { + t.Errorf("Expected PredefinedRole to be 'Reader', got %s", role) + } + } else { + t.Error("Wrapper does not implement PredefinedRole method") + } + }) +} + +// MockRoleDefinitionsPager is a mock for RoleDefinitionsPager +type MockRoleDefinitionsPager struct { + ctrl *gomock.Controller + recorder *MockRoleDefinitionsPagerMockRecorder +} + +type MockRoleDefinitionsPagerMockRecorder struct { + mock *MockRoleDefinitionsPager +} + +func NewMockRoleDefinitionsPager(ctrl *gomock.Controller) *MockRoleDefinitionsPager { + mock := &MockRoleDefinitionsPager{ctrl: ctrl} + mock.recorder = &MockRoleDefinitionsPagerMockRecorder{mock} + return mock +} + +func (m *MockRoleDefinitionsPager) EXPECT() *MockRoleDefinitionsPagerMockRecorder { + return m.recorder +} + +func (m *MockRoleDefinitionsPager) More() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "More") + ret0, _ := ret[0].(bool) + return ret0 +} + +func (mr *MockRoleDefinitionsPagerMockRecorder) More() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "More", reflect.TypeFor[func() bool]()) +} + +func (m *MockRoleDefinitionsPager) NextPage(ctx context.Context) (armauthorization.RoleDefinitionsClientListResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextPage", ctx) + ret0, _ := ret[0].(armauthorization.RoleDefinitionsClientListResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (mr *MockRoleDefinitionsPagerMockRecorder) NextPage(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextPage", reflect.TypeFor[func(ctx context.Context) (armauthorization.RoleDefinitionsClientListResponse, error)](), ctx) +} + +// createAzureRoleDefinition creates a mock Azure role definition for testing +func createAzureRoleDefinition(roleDefinitionID, roleName string) *armauthorization.RoleDefinition { + return &armauthorization.RoleDefinition{ + Name: new(roleDefinitionID), + Type: new("Microsoft.Authorization/roleDefinitions"), + ID: new("/subscriptions/test-subscription/providers/Microsoft.Authorization/roleDefinitions/" + roleDefinitionID), + Properties: &armauthorization.RoleDefinitionProperties{ + RoleName: new(roleName), + RoleType: new("BuiltInRole"), + Description: new("Test role definition for " + roleName), + AssignableScopes: []*string{ + new("/subscriptions/test-subscription"), + }, + Permissions: []*armauthorization.Permission{ + { + Actions: []*string{ + new("*/read"), + }, + }, + }, + }, + } +} diff --git a/sources/azure/manual/batch-private-endpoint-connection.go b/sources/azure/manual/batch-private-endpoint-connection.go new file mode 100644 index 00000000..f37d2721 --- /dev/null +++ b/sources/azure/manual/batch-private-endpoint-connection.go @@ -0,0 +1,256 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v4" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var BatchPrivateEndpointConnectionLookupByName = shared.NewItemTypeLookup("name", azureshared.BatchBatchPrivateEndpointConnection) + +type batchPrivateEndpointConnectionWrapper struct { + client clients.BatchPrivateEndpointConnectionClient + + *azureshared.MultiResourceGroupBase +} + +// NewBatchPrivateEndpointConnection returns a SearchableWrapper for Azure Batch private endpoint connections. +func NewBatchPrivateEndpointConnection(client clients.BatchPrivateEndpointConnectionClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &batchPrivateEndpointConnectionWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_COMPUTE_APPLICATION, + azureshared.BatchBatchPrivateEndpointConnection, + ), + } +} + +func (b batchPrivateEndpointConnectionWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: accountName and privateEndpointConnectionName", + Scope: scope, + ItemType: b.Type(), + } + } + accountName := queryParts[0] + connectionName := queryParts[1] + + if accountName == "" { + return nil, azureshared.QueryError(errors.New("accountName cannot be empty"), scope, b.Type()) + } + if connectionName == "" { + return nil, azureshared.QueryError(errors.New("privateEndpointConnectionName cannot be empty"), scope, b.Type()) + } + + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + resp, err := b.client.Get(ctx, rgScope.ResourceGroup, accountName, connectionName) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + + item, sdpErr := b.azurePrivateEndpointConnectionToSDPItem(&resp.PrivateEndpointConnection, accountName, connectionName, scope) + if sdpErr != nil { + return nil, sdpErr + } + return item, nil +} + +func (b batchPrivateEndpointConnectionWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + BatchAccountLookupByName, + BatchPrivateEndpointConnectionLookupByName, + } +} + +func (b batchPrivateEndpointConnectionWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: accountName", + Scope: scope, + ItemType: b.Type(), + } + } + accountName := queryParts[0] + + if accountName == "" { + return nil, azureshared.QueryError(errors.New("accountName cannot be empty"), scope, b.Type()) + } + + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + pager := b.client.ListByBatchAccount(ctx, rgScope.ResourceGroup, accountName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + + for _, conn := range page.Value { + if conn == nil || conn.Name == nil { + continue + } + + item, sdpErr := b.azurePrivateEndpointConnectionToSDPItem(conn, accountName, *conn.Name, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (b batchPrivateEndpointConnectionWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: accountName"), scope, b.Type())) + return + } + accountName := queryParts[0] + + if accountName == "" { + stream.SendError(azureshared.QueryError(errors.New("accountName cannot be empty"), scope, b.Type())) + return + } + + rgScope, err := b.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, b.Type())) + return + } + pager := b.client.ListByBatchAccount(ctx, rgScope.ResourceGroup, accountName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, b.Type())) + return + } + for _, conn := range page.Value { + if conn == nil || conn.Name == nil { + continue + } + item, sdpErr := b.azurePrivateEndpointConnectionToSDPItem(conn, accountName, *conn.Name, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (b batchPrivateEndpointConnectionWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + BatchAccountLookupByName, + }, + } +} + +func (b batchPrivateEndpointConnectionWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + azureshared.BatchBatchAccount: true, + azureshared.NetworkPrivateEndpoint: true, + } +} + +func (b batchPrivateEndpointConnectionWrapper) azurePrivateEndpointConnectionToSDPItem(conn *armbatch.PrivateEndpointConnection, accountName, connectionName, scope string) (*sdp.Item, *sdp.QueryError) { + attributes, err := shared.ToAttributesWithExclude(conn, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(accountName, connectionName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, b.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.BatchBatchPrivateEndpointConnection.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(conn.Tags), + } + + // Health from provisioning state + if conn.Properties != nil && conn.Properties.ProvisioningState != nil { + switch *conn.Properties.ProvisioningState { + case armbatch.PrivateEndpointConnectionProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armbatch.PrivateEndpointConnectionProvisioningStateCreating, + armbatch.PrivateEndpointConnectionProvisioningStateUpdating, + armbatch.PrivateEndpointConnectionProvisioningStateDeleting: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armbatch.PrivateEndpointConnectionProvisioningStateFailed, + armbatch.PrivateEndpointConnectionProvisioningStateCancelled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link to parent Batch Account + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.BatchBatchAccount.String(), + Method: sdp.QueryMethod_GET, + Query: accountName, + Scope: scope, + }, + }) + + // Link to Network Private Endpoint when present (may be in different resource group) + if conn.Properties != nil && conn.Properties.PrivateEndpoint != nil && conn.Properties.PrivateEndpoint.ID != nil { + peID := *conn.Properties.PrivateEndpoint.ID + peName := azureshared.ExtractResourceName(peID) + if peName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(peID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkPrivateEndpoint.String(), + Method: sdp.QueryMethod_GET, + Query: peName, + Scope: linkedScope, + }, + }) + } + } + + return sdpItem, nil +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftbatch +func (b batchPrivateEndpointConnectionWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Batch/batchAccounts/privateEndpointConnections/read", + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles/compute#azure-batch-account-reader +func (b batchPrivateEndpointConnectionWrapper) PredefinedRole() string { + return "Azure Batch Account Reader" +} diff --git a/sources/azure/manual/batch-private-endpoint-connection_test.go b/sources/azure/manual/batch-private-endpoint-connection_test.go new file mode 100644 index 00000000..aaba0ec8 --- /dev/null +++ b/sources/azure/manual/batch-private-endpoint-connection_test.go @@ -0,0 +1,471 @@ +package manual_test + +import ( + "context" + "errors" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v4" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +type mockBatchPrivateEndpointConnectionPager struct { + pages []armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse + index int +} + +func (m *mockBatchPrivateEndpointConnectionPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockBatchPrivateEndpointConnectionPager) NextPage(ctx context.Context) (armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse, error) { + if m.index >= len(m.pages) { + return armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +type testBatchPrivateEndpointConnectionClient struct { + *mocks.MockBatchPrivateEndpointConnectionClient + pager clients.BatchPrivateEndpointConnectionPager +} + +func (t *testBatchPrivateEndpointConnectionClient) ListByBatchAccount(ctx context.Context, resourceGroupName, accountName string) clients.BatchPrivateEndpointConnectionPager { + return t.pager +} + +func TestBatchPrivateEndpointConnection(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + accountName := "test-batch-account" + connectionName := "test-pec" + + t.Run("Get", func(t *testing.T) { + conn := createAzureBatchPrivateEndpointConnection(connectionName, "") + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, accountName, connectionName).Return( + armbatch.PrivateEndpointConnectionClientGetResponse{ + PrivateEndpointConnection: *conn, + }, nil) + + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(accountName, connectionName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.BatchBatchPrivateEndpointConnection.String() { + t.Errorf("Expected type %s, got %s", azureshared.BatchBatchPrivateEndpointConnection, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != shared.CompositeLookupKey(accountName, connectionName) { + t.Errorf("Expected unique attribute value %s, got %s", shared.CompositeLookupKey(accountName, connectionName), sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) < 1 { + t.Fatalf("Expected at least 1 linked query, got: %d", len(linkedQueries)) + } + + foundBatchAccount := false + for _, lq := range linkedQueries { + if lq.GetQuery().GetType() == azureshared.BatchBatchAccount.String() { + foundBatchAccount = true + if lq.GetQuery().GetMethod() != sdp.QueryMethod_GET { + t.Errorf("Expected BatchAccount link method GET, got %v", lq.GetQuery().GetMethod()) + } + if lq.GetQuery().GetQuery() != accountName { + t.Errorf("Expected BatchAccount query %s, got %s", accountName, lq.GetQuery().GetQuery()) + } + } + } + if !foundBatchAccount { + t.Error("Expected linked query to BatchAccount") + } + }) + }) + + t.Run("Get_WithPrivateEndpointLink", func(t *testing.T) { + peID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Network/privateEndpoints/test-pe" + conn := createAzureBatchPrivateEndpointConnection(connectionName, peID) + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, accountName, connectionName).Return( + armbatch.PrivateEndpointConnectionClientGetResponse{ + PrivateEndpointConnection: *conn, + }, nil) + + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(accountName, connectionName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + foundPrivateEndpoint := false + for _, lq := range sdpItem.GetLinkedItemQueries() { + if lq.GetQuery().GetType() == azureshared.NetworkPrivateEndpoint.String() { + foundPrivateEndpoint = true + if lq.GetQuery().GetQuery() != "test-pe" { + t.Errorf("Expected NetworkPrivateEndpoint query 'test-pe', got %s", lq.GetQuery().GetQuery()) + } + break + } + } + if !foundPrivateEndpoint { + t.Error("Expected linked query to NetworkPrivateEndpoint when PrivateEndpoint ID is set") + } + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], accountName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyAccountName", func(t *testing.T) { + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey("", connectionName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when accountName is empty, but got nil") + } + }) + + t.Run("GetWithEmptyConnectionName", func(t *testing.T) { + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(accountName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when connectionName is empty, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + conn1 := createAzureBatchPrivateEndpointConnection("pec-1", "") + conn2 := createAzureBatchPrivateEndpointConnection("pec-2", "") + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockPager := &mockBatchPrivateEndpointConnectionPager{ + pages: []armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse{ + { + ListPrivateEndpointConnectionsResult: armbatch.ListPrivateEndpointConnectionsResult{ + Value: []*armbatch.PrivateEndpointConnection{conn1, conn2}, + }, + }, + }, + } + + testClient := &testBatchPrivateEndpointConnectionClient{ + MockBatchPrivateEndpointConnectionClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], accountName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + if item.GetType() != azureshared.BatchBatchPrivateEndpointConnection.String() { + t.Errorf("Expected type %s, got %s", azureshared.BatchBatchPrivateEndpointConnection, item.GetType()) + } + } + }) + + t.Run("SearchStream", func(t *testing.T) { + conn1 := createAzureBatchPrivateEndpointConnection("pec-1", "") + conn2 := createAzureBatchPrivateEndpointConnection("pec-2", "") + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockPager := &mockBatchPrivateEndpointConnectionPager{ + pages: []armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse{ + { + ListPrivateEndpointConnectionsResult: armbatch.ListPrivateEndpointConnectionsResult{ + Value: []*armbatch.PrivateEndpointConnection{conn1, conn2}, + }, + }, + }, + } + + testClient := &testBatchPrivateEndpointConnectionClient{ + MockBatchPrivateEndpointConnectionClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + var items []*sdp.Item + var errs []error + + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + } + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], accountName, true, stream) + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("Search_NilNameSkipped", func(t *testing.T) { + validConn := createAzureBatchPrivateEndpointConnection("valid-pec", "") + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockPager := &mockBatchPrivateEndpointConnectionPager{ + pages: []armbatch.PrivateEndpointConnectionClientListByBatchAccountResponse{ + { + ListPrivateEndpointConnectionsResult: armbatch.ListPrivateEndpointConnectionsResult{ + Value: []*armbatch.PrivateEndpointConnection{ + {Name: nil}, + validConn, + }, + }, + }, + }, + } + + testClient := &testBatchPrivateEndpointConnectionClient{ + MockBatchPrivateEndpointConnectionClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], accountName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name skipped), got: %d", len(sdpItems)) + } + if sdpItems[0].UniqueAttributeValue() != shared.CompositeLookupKey(accountName, "valid-pec") { + t.Errorf("Expected unique value %s, got %s", shared.CompositeLookupKey(accountName, "valid-pec"), sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("Search_InvalidQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("SearchWithEmptyAccountName", func(t *testing.T) { + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when accountName is empty, but got nil") + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("private endpoint connection not found") + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, accountName, "nonexistent-pec").Return( + armbatch.PrivateEndpointConnectionClientGetResponse{}, expectedErr) + + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(accountName, "nonexistent-pec") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent private endpoint connection, but got nil") + } + }) + + t.Run("PotentialLinks", func(t *testing.T) { + wrapper := manual.NewBatchPrivateEndpointConnection(nil, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + links := wrapper.PotentialLinks() + if !links[azureshared.BatchBatchAccount] { + t.Error("Expected BatchAccount in PotentialLinks") + } + if !links[azureshared.NetworkPrivateEndpoint] { + t.Error("Expected NetworkPrivateEndpoint in PotentialLinks") + } + }) + + t.Run("HealthMapping", func(t *testing.T) { + tests := []struct { + name string + state armbatch.PrivateEndpointConnectionProvisioningState + expectedHeath sdp.Health + }{ + {"Succeeded", armbatch.PrivateEndpointConnectionProvisioningStateSucceeded, sdp.Health_HEALTH_OK}, + {"Creating", armbatch.PrivateEndpointConnectionProvisioningStateCreating, sdp.Health_HEALTH_PENDING}, + {"Updating", armbatch.PrivateEndpointConnectionProvisioningStateUpdating, sdp.Health_HEALTH_PENDING}, + {"Deleting", armbatch.PrivateEndpointConnectionProvisioningStateDeleting, sdp.Health_HEALTH_PENDING}, + {"Failed", armbatch.PrivateEndpointConnectionProvisioningStateFailed, sdp.Health_HEALTH_ERROR}, + {"Cancelled", armbatch.PrivateEndpointConnectionProvisioningStateCancelled, sdp.Health_HEALTH_ERROR}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conn := createAzureBatchPrivateEndpointConnectionWithState(connectionName, tt.state) + + mockClient := mocks.NewMockBatchPrivateEndpointConnectionClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, accountName, connectionName).Return( + armbatch.PrivateEndpointConnectionClientGetResponse{ + PrivateEndpointConnection: *conn, + }, nil) + + testClient := &testBatchPrivateEndpointConnectionClient{MockBatchPrivateEndpointConnectionClient: mockClient} + wrapper := manual.NewBatchPrivateEndpointConnection(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(accountName, connectionName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != tt.expectedHeath { + t.Errorf("Expected health %v, got %v", tt.expectedHeath, sdpItem.GetHealth()) + } + }) + } + }) +} + +func createAzureBatchPrivateEndpointConnection(connectionName, privateEndpointID string) *armbatch.PrivateEndpointConnection { + succeeded := armbatch.PrivateEndpointConnectionProvisioningStateSucceeded + conn := &armbatch.PrivateEndpointConnection{ + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Batch/batchAccounts/test-batch-account/privateEndpointConnections/" + connectionName), + Name: new(connectionName), + Type: new("Microsoft.Batch/batchAccounts/privateEndpointConnections"), + Properties: &armbatch.PrivateEndpointConnectionProperties{ + ProvisioningState: &succeeded, + PrivateLinkServiceConnectionState: &armbatch.PrivateLinkServiceConnectionState{ + Status: new(armbatch.PrivateLinkServiceConnectionStatusApproved), + }, + }, + Tags: map[string]*string{ + "env": new("test"), + }, + } + if privateEndpointID != "" { + conn.Properties.PrivateEndpoint = &armbatch.PrivateEndpoint{ + ID: new(privateEndpointID), + } + } + return conn +} + +func createAzureBatchPrivateEndpointConnectionWithState(connectionName string, state armbatch.PrivateEndpointConnectionProvisioningState) *armbatch.PrivateEndpointConnection { + conn := &armbatch.PrivateEndpointConnection{ + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Batch/batchAccounts/test-batch-account/privateEndpointConnections/" + connectionName), + Name: new(connectionName), + Type: new("Microsoft.Batch/batchAccounts/privateEndpointConnections"), + Properties: &armbatch.PrivateEndpointConnectionProperties{ + ProvisioningState: &state, + }, + Tags: map[string]*string{ + "env": new("test"), + }, + } + return conn +} diff --git a/sources/azure/manual/dbforpostgresql-flexible-server-administrator.go b/sources/azure/manual/dbforpostgresql-flexible-server-administrator.go new file mode 100644 index 00000000..f7968937 --- /dev/null +++ b/sources/azure/manual/dbforpostgresql-flexible-server-administrator.go @@ -0,0 +1,229 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var DBforPostgreSQLFlexibleServerAdministratorLookupByName = shared.NewItemTypeLookup("name", azureshared.DBforPostgreSQLFlexibleServerAdministrator) + +type dbforPostgreSQLFlexibleServerAdministratorWrapper struct { + client clients.DBforPostgreSQLFlexibleServerAdministratorClient + + *azureshared.MultiResourceGroupBase +} + +func NewDBforPostgreSQLFlexibleServerAdministrator(client clients.DBforPostgreSQLFlexibleServerAdministratorClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &dbforPostgreSQLFlexibleServerAdministratorWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, + azureshared.DBforPostgreSQLFlexibleServerAdministrator, + ), + } +} + +// Get retrieves a single administrator by server name and object ID +// ref: https://learn.microsoft.com/en-us/rest/api/postgresql/administrators-microsoft-entra/get +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: serverName and objectId", + Scope: scope, + ItemType: s.Type(), + } + } + serverName := queryParts[0] + objectID := queryParts[1] + + if serverName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "serverName cannot be empty", + Scope: scope, + ItemType: s.Type(), + } + } + if objectID == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "objectId cannot be empty", + Scope: scope, + ItemType: s.Type(), + } + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, objectID) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + return s.azureAdministratorToSDPItem(&resp.AdministratorMicrosoftEntra, serverName, scope) +} + +// Search retrieves all administrators for a given server +// ref: https://learn.microsoft.com/en-us/rest/api/postgresql/administrators-microsoft-entra/list-by-server +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: serverName", + Scope: scope, + ItemType: s.Type(), + } + } + serverName := queryParts[0] + if serverName == "" { + return nil, azureshared.QueryError(errors.New("serverName cannot be empty"), scope, s.Type()) + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + for _, admin := range page.Value { + if admin.Name == nil { + continue + } + + item, sdpErr := s.azureAdministratorToSDPItem(admin, serverName, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: serverName"), scope, s.Type())) + return + } + serverName := queryParts[0] + if serverName == "" { + stream.SendError(azureshared.QueryError(errors.New("serverName cannot be empty"), scope, s.Type())) + return + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + for _, admin := range page.Value { + if admin.Name == nil { + continue + } + item, sdpErr := s.azureAdministratorToSDPItem(admin, serverName, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + DBforPostgreSQLFlexibleServerLookupByName, + DBforPostgreSQLFlexibleServerAdministratorLookupByName, + } +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + DBforPostgreSQLFlexibleServerLookupByName, + }, + } +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) azureAdministratorToSDPItem(admin *armpostgresqlflexibleservers.AdministratorMicrosoftEntra, serverName, scope string) (*sdp.Item, *sdp.QueryError) { + if admin.Name == nil { + return nil, azureshared.QueryError(errors.New("administrator name (objectId) is nil"), scope, s.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(admin) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + objectID := *admin.Name + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(serverName, objectID)) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + sdpItem := &sdp.Item{ + Type: s.Type(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + } + + // Link to the parent PostgreSQL Flexible Server + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.DBforPostgreSQLFlexibleServer.String(), + Method: sdp.QueryMethod_GET, + Query: serverName, + Scope: scope, + }, + }) + + return sdpItem, nil +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.DBforPostgreSQLFlexibleServer, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/permissions/databases#microsoftdbforpostgresql +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.DBforPostgreSQL/flexibleServers/administrators/read", + } +} + +func (s dbforPostgreSQLFlexibleServerAdministratorWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/dbforpostgresql-flexible-server-administrator_test.go b/sources/azure/manual/dbforpostgresql-flexible-server-administrator_test.go new file mode 100644 index 00000000..2ffbf003 --- /dev/null +++ b/sources/azure/manual/dbforpostgresql-flexible-server-administrator_test.go @@ -0,0 +1,402 @@ +package manual_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +// mockAdministratorPager is a simple mock implementation of DBforPostgreSQLFlexibleServerAdministratorPager +type mockAdministratorPager struct { + pages []armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse + index int +} + +func (m *mockAdministratorPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockAdministratorPager) NextPage(ctx context.Context) (armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse, error) { + if m.index >= len(m.pages) { + return armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +// errorAdministratorPager is a mock pager that always returns an error +type errorAdministratorPager struct{} + +func (e *errorAdministratorPager) More() bool { + return true +} + +func (e *errorAdministratorPager) NextPage(ctx context.Context) (armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse, error) { + return armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse{}, errors.New("pager error") +} + +// testAdministratorClient wraps the mock to implement the correct interface +type testAdministratorClient struct { + *mocks.MockDBforPostgreSQLFlexibleServerAdministratorClient + pager clients.DBforPostgreSQLFlexibleServerAdministratorPager +} + +func (t *testAdministratorClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.DBforPostgreSQLFlexibleServerAdministratorPager { + return t.pager +} + +func TestDBforPostgreSQLFlexibleServerAdministrator(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + serverName := "test-server" + objectID := "00000000-0000-0000-0000-000000000001" + + t.Run("Get", func(t *testing.T) { + admin := createAzureAdministrator(objectID) + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, objectID).Return( + armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse{ + AdministratorMicrosoftEntra: *admin, + }, nil) + + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, objectID) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerAdministrator.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerAdministrator, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUniqueValue := shared.CompositeLookupKey(serverName, objectID) + if sdpItem.UniqueAttributeValue() != expectedUniqueValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueValue, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) != 1 { + t.Fatalf("Expected 1 linked query, got: %d", len(linkedQueries)) + } + + queryTests := shared.QueryTests{ + { + ExpectedType: azureshared.DBforPostgreSQLFlexibleServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: serverName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey("", objectID) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when providing empty server name, but got nil") + } + }) + + t.Run("GetWithEmptyObjectId", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when providing empty objectId, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + admin1 := createAzureAdministrator("00000000-0000-0000-0000-000000000001") + admin2 := createAzureAdministrator("00000000-0000-0000-0000-000000000002") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + mockPager := &mockAdministratorPager{ + pages: []armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse{ + { + AdministratorMicrosoftEntraList: armpostgresqlflexibleservers.AdministratorMicrosoftEntraList{ + Value: []*armpostgresqlflexibleservers.AdministratorMicrosoftEntra{admin1, admin2}, + }, + }, + }, + } + + testClient := &testAdministratorClient{ + MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + if item.GetType() != azureshared.DBforPostgreSQLFlexibleServerAdministrator.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerAdministrator, item.GetType()) + } + } + }) + + t.Run("SearchWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when providing empty server name, but got nil") + } + }) + + t.Run("SearchWithNoQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("SearchStream", func(t *testing.T) { + admin1 := createAzureAdministrator("00000000-0000-0000-0000-000000000001") + admin2 := createAzureAdministrator("00000000-0000-0000-0000-000000000002") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + mockPager := &mockAdministratorPager{ + pages: []armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse{ + { + AdministratorMicrosoftEntraList: armpostgresqlflexibleservers.AdministratorMicrosoftEntraList{ + Value: []*armpostgresqlflexibleservers.AdministratorMicrosoftEntra{admin1, admin2}, + }, + }, + }, + } + + testClient := &testAdministratorClient{ + MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], serverName, true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("administrator not found") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, "nonexistent").Return( + armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse{}, expectedErr) + + testClient := &testAdministratorClient{MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient} + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "nonexistent") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent administrator, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + errorPager := &errorAdministratorPager{} + + testClient := &testAdministratorClient{ + MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient, + pager: errorPager, + } + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + _, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err == nil { + t.Error("Expected error from pager when NextPage returns an error, but got nil") + } + }) + + t.Run("Search_AdminWithNilName", func(t *testing.T) { + validAdmin := createAzureAdministrator("00000000-0000-0000-0000-000000000001") + nilNameAdmin := &armpostgresqlflexibleservers.AdministratorMicrosoftEntra{ + Name: nil, + } + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl) + mockPager := &mockAdministratorPager{ + pages: []armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientListByServerResponse{ + { + AdministratorMicrosoftEntraList: armpostgresqlflexibleservers.AdministratorMicrosoftEntraList{ + Value: []*armpostgresqlflexibleservers.AdministratorMicrosoftEntra{nilNameAdmin, validAdmin}, + }, + }, + }, + } + + testClient := &testAdministratorClient{ + MockDBforPostgreSQLFlexibleServerAdministratorClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewDBforPostgreSQLFlexibleServerAdministrator(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name should be skipped), got: %d", len(sdpItems)) + } + + expectedUniqueValue := shared.CompositeLookupKey(serverName, "00000000-0000-0000-0000-000000000001") + if sdpItems[0].UniqueAttributeValue() != expectedUniqueValue { + t.Errorf("Expected unique value %s, got %s", expectedUniqueValue, sdpItems[0].UniqueAttributeValue()) + } + }) +} + +// createAzureAdministrator creates a mock Azure administrator for testing +func createAzureAdministrator(objectID string) *armpostgresqlflexibleservers.AdministratorMicrosoftEntra { + principalType := armpostgresqlflexibleservers.PrincipalTypeUser + return &armpostgresqlflexibleservers.AdministratorMicrosoftEntra{ + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.DBforPostgreSQL/flexibleServers/test-server/administrators/" + objectID), + Name: new(objectID), + Type: new("Microsoft.DBforPostgreSQL/flexibleServers/administrators"), + Properties: &armpostgresqlflexibleservers.AdministratorMicrosoftEntraProperties{ + ObjectID: new(objectID), + PrincipalName: new("admin@example.com"), + PrincipalType: &principalType, + TenantID: new("tenant-id"), + }, + } +} diff --git a/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint.go b/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint.go new file mode 100644 index 00000000..b1a114d6 --- /dev/null +++ b/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint.go @@ -0,0 +1,276 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var DBforPostgreSQLFlexibleServerVirtualEndpointLookupByName = shared.NewItemTypeLookup("name", azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint) + +type dbforPostgreSQLFlexibleServerVirtualEndpointWrapper struct { + client clients.DBforPostgreSQLFlexibleServerVirtualEndpointClient + + *azureshared.MultiResourceGroupBase +} + +func NewDBforPostgreSQLFlexibleServerVirtualEndpoint(client clients.DBforPostgreSQLFlexibleServerVirtualEndpointClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &dbforPostgreSQLFlexibleServerVirtualEndpointWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, + azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint, + ), + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/virtual-endpoints/get?view=rest-postgresql-2025-08-01 +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: serverName and virtualEndpointName", + Scope: scope, + ItemType: s.Type(), + } + } + serverName := queryParts[0] + virtualEndpointName := queryParts[1] + if serverName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "serverName cannot be empty", + Scope: scope, + ItemType: s.Type(), + } + } + if virtualEndpointName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "virtualEndpointName cannot be empty", + Scope: scope, + ItemType: s.Type(), + } + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + resp, err := s.client.Get(ctx, rgScope.ResourceGroup, serverName, virtualEndpointName) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + return s.azureVirtualEndpointToSDPItem(&resp.VirtualEndpoint, serverName, virtualEndpointName, scope) +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) azureVirtualEndpointToSDPItem(virtualEndpoint *armpostgresqlflexibleservers.VirtualEndpoint, serverName, virtualEndpointName, scope string) (*sdp.Item, *sdp.QueryError) { + if virtualEndpoint.Name == nil { + return nil, azureshared.QueryError(errors.New("virtual endpoint name is nil"), scope, s.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(virtualEndpoint) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(serverName, virtualEndpointName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + Tags: nil, + } + + // Link to parent PostgreSQL Flexible Server + if virtualEndpoint.ID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*virtualEndpoint.ID, []string{"flexibleServers"}) + if len(params) > 0 { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.DBforPostgreSQLFlexibleServer.String(), + Method: sdp.QueryMethod_GET, + Query: params[0], + Scope: scope, + }, + }) + } + } else { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.DBforPostgreSQLFlexibleServer.String(), + Method: sdp.QueryMethod_GET, + Query: serverName, + Scope: scope, + }, + }) + } + + // Link to member servers (Members field contains server names that this virtual endpoint can refer to) + if virtualEndpoint.Properties != nil && virtualEndpoint.Properties.Members != nil { + for _, memberServerName := range virtualEndpoint.Properties.Members { + if memberServerName != nil && *memberServerName != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.DBforPostgreSQLFlexibleServer.String(), + Method: sdp.QueryMethod_GET, + Query: *memberServerName, + Scope: scope, + }, + }) + } + } + } + + // Link to virtual endpoint DNS names (VirtualEndpoints field contains DNS names) + if virtualEndpoint.Properties != nil && virtualEndpoint.Properties.VirtualEndpoints != nil { + for _, dnsName := range virtualEndpoint.Properties.VirtualEndpoints { + if dnsName != nil && *dnsName != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *dnsName, + Scope: "global", + }, + }) + } + } + } + + return sdpItem, nil +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + DBforPostgreSQLFlexibleServerLookupByName, + DBforPostgreSQLFlexibleServerVirtualEndpointLookupByName, + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/postgresql/flexibleserver/virtual-endpoints/list-by-server?view=rest-postgresql-2025-08-01 +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: serverName", + Scope: scope, + ItemType: s.Type(), + } + } + serverName := queryParts[0] + if serverName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "serverName cannot be empty", + Scope: scope, + ItemType: s.Type(), + } + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, s.Type()) + } + for _, virtualEndpoint := range page.Value { + if virtualEndpoint.Name == nil { + continue + } + item, sdpErr := s.azureVirtualEndpointToSDPItem(virtualEndpoint, serverName, *virtualEndpoint.Name, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: serverName"), scope, s.Type())) + return + } + serverName := queryParts[0] + if serverName == "" { + stream.SendError(azureshared.QueryError(errors.New("serverName cannot be empty"), scope, s.Type())) + return + } + + rgScope, err := s.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + pager := s.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, s.Type())) + return + } + for _, virtualEndpoint := range page.Value { + if virtualEndpoint.Name == nil { + continue + } + item, sdpErr := s.azureVirtualEndpointToSDPItem(virtualEndpoint, serverName, *virtualEndpoint.Name, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + DBforPostgreSQLFlexibleServerLookupByName, + }, + } +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + azureshared.DBforPostgreSQLFlexibleServer: true, + stdlib.NetworkDNS: true, + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftdbforpostgresql +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.DBforPostgreSQL/flexibleServers/virtualEndpoints/read", + } +} + +func (s dbforPostgreSQLFlexibleServerVirtualEndpointWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint_test.go b/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint_test.go new file mode 100644 index 00000000..538bc4a9 --- /dev/null +++ b/sources/azure/manual/dbforpostgresql-flexible-server-virtual-endpoint_test.go @@ -0,0 +1,328 @@ +package manual_test + +import ( + "context" + "errors" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +type mockDBforPostgreSQLFlexibleServerVirtualEndpointPager struct { + pages []armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse + index int +} + +func (m *mockDBforPostgreSQLFlexibleServerVirtualEndpointPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockDBforPostgreSQLFlexibleServerVirtualEndpointPager) NextPage(ctx context.Context) (armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse, error) { + if m.index >= len(m.pages) { + return armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +type errorDBforPostgreSQLFlexibleServerVirtualEndpointPager struct{} + +func (e *errorDBforPostgreSQLFlexibleServerVirtualEndpointPager) More() bool { + return true +} + +func (e *errorDBforPostgreSQLFlexibleServerVirtualEndpointPager) NextPage(ctx context.Context) (armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse, error) { + return armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse{}, errors.New("pager error") +} + +type testDBforPostgreSQLFlexibleServerVirtualEndpointClient struct { + *mocks.MockDBforPostgreSQLFlexibleServerVirtualEndpointClient + pager clients.DBforPostgreSQLFlexibleServerVirtualEndpointPager +} + +func (t *testDBforPostgreSQLFlexibleServerVirtualEndpointClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.DBforPostgreSQLFlexibleServerVirtualEndpointPager { + return t.pager +} + +func TestDBforPostgreSQLFlexibleServerVirtualEndpoint(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + serverName := "test-server" + virtualEndpointName := "test-virtual-endpoint" + + t.Run("Get", func(t *testing.T) { + virtualEndpoint := createAzurePostgreSQLFlexibleServerVirtualEndpoint(serverName, virtualEndpointName) + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, virtualEndpointName).Return( + armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse{ + VirtualEndpoint: *virtualEndpoint, + }, nil) + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, virtualEndpointName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint.String() { + t.Errorf("Expected type %s, got %s", azureshared.DBforPostgreSQLFlexibleServerVirtualEndpoint, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(serverName, virtualEndpointName) + if sdpItem.UniqueAttributeValue() != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: azureshared.DBforPostgreSQLFlexibleServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: serverName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + ExpectedType: azureshared.DBforPostgreSQLFlexibleServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "member-server-1", + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + ExpectedType: stdlib.NetworkDNS.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: "test-endpoint.postgres.database.azure.com", + ExpectedScope: "global", + }, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) + if qErr == nil { + t.Error("Expected error when providing only serverName (1 query part), but got nil") + } + }) + + t.Run("GetWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey("", virtualEndpointName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when serverName is empty, but got nil") + } + }) + + t.Run("GetWithEmptyVirtualEndpointName", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when virtualEndpointName is empty, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + virtualEndpoint1 := createAzurePostgreSQLFlexibleServerVirtualEndpoint(serverName, "vep1") + virtualEndpoint2 := createAzurePostgreSQLFlexibleServerVirtualEndpoint(serverName, "vep2") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + pager := &mockDBforPostgreSQLFlexibleServerVirtualEndpointPager{ + pages: []armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse{ + { + VirtualEndpointsList: armpostgresqlflexibleservers.VirtualEndpointsList{ + Value: []*armpostgresqlflexibleservers.VirtualEndpoint{virtualEndpoint1, virtualEndpoint2}, + }, + }, + }, + } + + testClient := &testDBforPostgreSQLFlexibleServerVirtualEndpointClient{ + MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient, + pager: pager, + } + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + items, qErr := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if qErr != nil { + t.Fatalf("Expected no error from Search, got: %v", qErr) + } + if len(items) != 2 { + t.Errorf("Expected 2 items from Search, got %d", len(items)) + } + }) + + t.Run("SearchStream", func(t *testing.T) { + virtualEndpoint1 := createAzurePostgreSQLFlexibleServerVirtualEndpoint(serverName, "vep1") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + pager := &mockDBforPostgreSQLFlexibleServerVirtualEndpointPager{ + pages: []armpostgresqlflexibleservers.VirtualEndpointsClientListByServerResponse{ + { + VirtualEndpointsList: armpostgresqlflexibleservers.VirtualEndpointsList{ + Value: []*armpostgresqlflexibleservers.VirtualEndpoint{virtualEndpoint1}, + }, + }, + }, + } + + testClient := &testDBforPostgreSQLFlexibleServerVirtualEndpointClient{ + MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient, + pager: pager, + } + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + stream := discovery.NewRecordingQueryResultStream() + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], serverName, true, stream) + items := stream.GetItems() + errs := stream.GetErrors() + if len(errs) > 0 { + t.Fatalf("Expected no errors from SearchStream, got: %v", errs) + } + if len(items) != 1 { + t.Errorf("Expected 1 item from SearchStream, got %d", len(items)) + } + }) + + t.Run("SearchWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("SearchWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when serverName is empty, but got nil") + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("virtual endpoint not found") + + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, "nonexistent-vep").Return( + armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse{}, expectedErr) + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "nonexistent-vep") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent virtual endpoint, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + errorPager := &errorDBforPostgreSQLFlexibleServerVirtualEndpointPager{} + testClient := &testDBforPostgreSQLFlexibleServerVirtualEndpointClient{ + MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient, + pager: errorPager, + } + + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], serverName) + if qErr == nil { + t.Error("Expected error from Search when pager returns error, but got nil") + } + }) + + t.Run("PotentialLinks", func(t *testing.T) { + mockClient := mocks.NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl) + wrapper := manual.NewDBforPostgreSQLFlexibleServerVirtualEndpoint(&testDBforPostgreSQLFlexibleServerVirtualEndpointClient{MockDBforPostgreSQLFlexibleServerVirtualEndpointClient: mockClient}, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + potentialLinks := wrapper.PotentialLinks() + + expectedLinks := map[shared.ItemType]bool{ + azureshared.DBforPostgreSQLFlexibleServer: true, + stdlib.NetworkDNS: true, + } + + for linkType := range expectedLinks { + if !potentialLinks[linkType] { + t.Errorf("Expected PotentialLinks to include %s", linkType) + } + } + }) +} + +func createAzurePostgreSQLFlexibleServerVirtualEndpoint(serverName, virtualEndpointName string) *armpostgresqlflexibleservers.VirtualEndpoint { + virtualEndpointID := "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.DBforPostgreSQL/flexibleServers/" + serverName + "/virtualEndpoints/" + virtualEndpointName + endpointType := armpostgresqlflexibleservers.VirtualEndpointTypeReadWrite + return &armpostgresqlflexibleservers.VirtualEndpoint{ + Name: new(virtualEndpointName), + ID: new(virtualEndpointID), + Type: new("Microsoft.DBforPostgreSQL/flexibleServers/virtualEndpoints"), + Properties: &armpostgresqlflexibleservers.VirtualEndpointResourceProperties{ + EndpointType: &endpointType, + Members: []*string{new("member-server-1")}, + VirtualEndpoints: []*string{ + new("test-endpoint.postgres.database.azure.com"), + }, + }, + } +} diff --git a/sources/azure/manual/elastic-san-volume.go b/sources/azure/manual/elastic-san-volume.go new file mode 100644 index 00000000..ce5485d0 --- /dev/null +++ b/sources/azure/manual/elastic-san-volume.go @@ -0,0 +1,327 @@ +package manual + +import ( + "context" + "errors" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var ElasticSanVolumeLookupByName = shared.NewItemTypeLookup("name", azureshared.ElasticSanVolume) + +type elasticSanVolumeWrapper struct { + client clients.ElasticSanVolumeClient + *azureshared.MultiResourceGroupBase +} + +func NewElasticSanVolume(client clients.ElasticSanVolumeClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &elasticSanVolumeWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_STORAGE, + azureshared.ElasticSanVolume, + ), + } +} + +func (e elasticSanVolumeWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 3 { + return nil, azureshared.QueryError(errors.New("Get requires 3 query parts: elasticSanName, volumeGroupName, and volumeName"), scope, e.Type()) + } + elasticSanName := queryParts[0] + if elasticSanName == "" { + return nil, azureshared.QueryError(errors.New("elasticSanName cannot be empty"), scope, e.Type()) + } + volumeGroupName := queryParts[1] + if volumeGroupName == "" { + return nil, azureshared.QueryError(errors.New("volumeGroupName cannot be empty"), scope, e.Type()) + } + volumeName := queryParts[2] + if volumeName == "" { + return nil, azureshared.QueryError(errors.New("volumeName cannot be empty"), scope, e.Type()) + } + + rgScope, err := e.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + resp, err := e.client.Get(ctx, rgScope.ResourceGroup, elasticSanName, volumeGroupName, volumeName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + return e.azureVolumeToSDPItem(&resp.Volume, elasticSanName, volumeGroupName, volumeName, scope) +} + +func (e elasticSanVolumeWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + ElasticSanLookupByName, + ElasticSanVolumeGroupLookupByName, + ElasticSanVolumeLookupByName, + } +} + +func (e elasticSanVolumeWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, azureshared.QueryError(errors.New("Search requires 2 query parts: elasticSanName and volumeGroupName"), scope, e.Type()) + } + elasticSanName := queryParts[0] + if elasticSanName == "" { + return nil, azureshared.QueryError(errors.New("elasticSanName cannot be empty"), scope, e.Type()) + } + volumeGroupName := queryParts[1] + if volumeGroupName == "" { + return nil, azureshared.QueryError(errors.New("volumeGroupName cannot be empty"), scope, e.Type()) + } + + rgScope, err := e.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + pager := e.client.NewListByVolumeGroupPager(rgScope.ResourceGroup, elasticSanName, volumeGroupName, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + for _, vol := range page.Value { + if vol.Name == nil { + continue + } + item, sdpErr := e.azureVolumeToSDPItem(vol, elasticSanName, volumeGroupName, *vol.Name, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + return items, nil +} + +func (e elasticSanVolumeWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 2 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 2 query parts: elasticSanName and volumeGroupName"), scope, e.Type())) + return + } + elasticSanName := queryParts[0] + if elasticSanName == "" { + stream.SendError(azureshared.QueryError(errors.New("elasticSanName cannot be empty"), scope, e.Type())) + return + } + volumeGroupName := queryParts[1] + if volumeGroupName == "" { + stream.SendError(azureshared.QueryError(errors.New("volumeGroupName cannot be empty"), scope, e.Type())) + return + } + + rgScope, err := e.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, e.Type())) + return + } + pager := e.client.NewListByVolumeGroupPager(rgScope.ResourceGroup, elasticSanName, volumeGroupName, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, e.Type())) + return + } + for _, vol := range page.Value { + if vol.Name == nil { + continue + } + item, sdpErr := e.azureVolumeToSDPItem(vol, elasticSanName, volumeGroupName, *vol.Name, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (e elasticSanVolumeWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + {ElasticSanLookupByName, ElasticSanVolumeGroupLookupByName}, + } +} + +func (e elasticSanVolumeWrapper) azureVolumeToSDPItem(vol *armelasticsan.Volume, elasticSanName, volumeGroupName, volumeName, scope string) (*sdp.Item, *sdp.QueryError) { + if vol.Name == nil { + return nil, azureshared.QueryError(errors.New("volume name is nil"), scope, e.Type()) + } + attributes, err := shared.ToAttributesWithExclude(vol, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(elasticSanName, volumeGroupName, volumeName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, e.Type()) + } + + item := &sdp.Item{ + Type: azureshared.ElasticSanVolume.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + LinkedItemQueries: []*sdp.LinkedItemQuery{}, + } + + // Link to parent Elastic SAN + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ElasticSan.String(), + Method: sdp.QueryMethod_GET, + Query: elasticSanName, + Scope: scope, + }, + }) + + // Link to parent Volume Group + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ElasticSanVolumeGroup.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(elasticSanName, volumeGroupName), + Scope: scope, + }, + }) + + if vol.Properties != nil { + // Link to source resource (snapshot or volume) via CreationData.SourceID + if vol.Properties.CreationData != nil && vol.Properties.CreationData.SourceID != nil && *vol.Properties.CreationData.SourceID != "" { + sourceID := *vol.Properties.CreationData.SourceID + // Determine the type based on the resource ID path + // Azure REST API uses /snapshots/ for Elastic SAN volume snapshots + if strings.Contains(sourceID, "/snapshots/") { + // It's a snapshot - extract elasticSanName, volumeGroupName, snapshotName + params := azureshared.ExtractPathParamsFromResourceID(sourceID, []string{"elasticSans", "volumegroups", "snapshots"}) + if len(params) >= 3 && params[0] != "" && params[1] != "" && params[2] != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(sourceID) + if linkedScope == "" { + linkedScope = scope + } + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ElasticSanVolumeSnapshot.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1], params[2]), + Scope: linkedScope, + }, + }) + } + } else if strings.Contains(sourceID, "/volumes/") { + // It's a volume - extract elasticSanName, volumeGroupName, volumeName + params := azureshared.ExtractPathParamsFromResourceID(sourceID, []string{"elasticSans", "volumegroups", "volumes"}) + if len(params) >= 3 && params[0] != "" && params[1] != "" && params[2] != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(sourceID) + if linkedScope == "" { + linkedScope = scope + } + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ElasticSanVolume.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1], params[2]), + Scope: linkedScope, + }, + }) + } + } + } + + // Link to managed-by resource via ManagedBy.ResourceID + if vol.Properties.ManagedBy != nil && vol.Properties.ManagedBy.ResourceID != nil && *vol.Properties.ManagedBy.ResourceID != "" { + managedByID := *vol.Properties.ManagedBy.ResourceID + // ManagedBy can reference different resource types (e.g., AKS clusters, VMs) + // We'll use the generic resource name extraction and link appropriately + linkedScope := azureshared.ExtractScopeFromResourceID(managedByID) + if linkedScope == "" { + linkedScope = scope + } + + // Detect the resource type based on the path + if strings.Contains(managedByID, "/virtualMachines/") { + vmName := azureshared.ExtractResourceName(managedByID) + if vmName != "" { + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ComputeVirtualMachine.String(), + Method: sdp.QueryMethod_GET, + Query: vmName, + Scope: linkedScope, + }, + }) + } + } + // Add other resource types as needed + } + + // Link to storage target DNS/hostname if available + if vol.Properties.StorageTarget != nil { + if vol.Properties.StorageTarget.TargetPortalHostname != nil && *vol.Properties.StorageTarget.TargetPortalHostname != "" { + item.LinkedItemQueries = append(item.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *vol.Properties.StorageTarget.TargetPortalHostname, + Scope: "global", + }, + }) + } + } + } + + // Health from provisioning state + if vol.Properties != nil && vol.Properties.ProvisioningState != nil { + switch *vol.Properties.ProvisioningState { + case armelasticsan.ProvisioningStatesSucceeded: + item.Health = sdp.Health_HEALTH_OK.Enum() + case armelasticsan.ProvisioningStatesCreating, armelasticsan.ProvisioningStatesUpdating, armelasticsan.ProvisioningStatesDeleting, + armelasticsan.ProvisioningStatesPending, armelasticsan.ProvisioningStatesRestoring: + item.Health = sdp.Health_HEALTH_PENDING.Enum() + case armelasticsan.ProvisioningStatesFailed, armelasticsan.ProvisioningStatesCanceled, + armelasticsan.ProvisioningStatesDeleted, armelasticsan.ProvisioningStatesInvalid: + item.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + item.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + return item, nil +} + +func (e elasticSanVolumeWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + azureshared.ElasticSan: true, + azureshared.ElasticSanVolumeGroup: true, + azureshared.ElasticSanVolumeSnapshot: true, + azureshared.ElasticSanVolume: true, + azureshared.ComputeVirtualMachine: true, + stdlib.NetworkDNS: true, + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftelasticsan +func (e elasticSanVolumeWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.ElasticSan/elasticSans/volumegroups/volumes/read", + } +} + +func (e elasticSanVolumeWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/elastic-san-volume_test.go b/sources/azure/manual/elastic-san-volume_test.go new file mode 100644 index 00000000..d2a26bc7 --- /dev/null +++ b/sources/azure/manual/elastic-san-volume_test.go @@ -0,0 +1,329 @@ +package manual_test + +import ( + "context" + "errors" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +// mockElasticSanVolumePager is a simple mock implementation of ElasticSanVolumePager +type mockElasticSanVolumePager struct { + pages []armelasticsan.VolumesClientListByVolumeGroupResponse + index int +} + +func (m *mockElasticSanVolumePager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockElasticSanVolumePager) NextPage(ctx context.Context) (armelasticsan.VolumesClientListByVolumeGroupResponse, error) { + if m.index >= len(m.pages) { + return armelasticsan.VolumesClientListByVolumeGroupResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +func createAzureElasticSanVolume(name string) *armelasticsan.Volume { + provisioningState := armelasticsan.ProvisioningStatesSucceeded + sizeGiB := int64(100) + return &armelasticsan.Volume{ + ID: new("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.ElasticSan/elasticSans/es/volumegroups/vg/volumes/" + name), + Name: new(name), + Type: new("Microsoft.ElasticSan/elasticSans/volumegroups/volumes"), + Properties: &armelasticsan.VolumeProperties{ + SizeGiB: &sizeGiB, + ProvisioningState: &provisioningState, + }, + } +} + +func createAzureElasticSanVolumeWithLinks(name string) *armelasticsan.Volume { + vol := createAzureElasticSanVolume(name) + vol.Properties.StorageTarget = &armelasticsan.IscsiTargetInfo{ + TargetPortalHostname: new("test-san.region.elasticsan.azure.net"), + TargetIqn: new("iqn.2022-05.net.azure.elasticsan:test"), + TargetPortalPort: new(int32(3260)), + } + vol.Properties.CreationData = &armelasticsan.SourceCreationData{ + CreateSource: new(armelasticsan.VolumeCreateOptionVolumeSnapshot), + SourceID: new("/subscriptions/sub/resourceGroups/rg/providers/Microsoft.ElasticSan/elasticSans/es/volumegroups/vg/snapshots/snap1"), + } + return vol +} + +func TestElasticSanVolume(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + elasticSanName := "test-elastic-san" + volumeGroupName := "test-volume-group" + volumeName := "test-volume" + + t.Run("Get", func(t *testing.T) { + vol := createAzureElasticSanVolume(volumeName) + + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, elasticSanName, volumeGroupName, volumeName, nil).Return( + armelasticsan.VolumesClientGetResponse{ + Volume: *vol, + }, nil) + + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName, volumeName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.ElasticSanVolume.String() { + t.Errorf("Expected type %s, got %s", azureshared.ElasticSanVolume.String(), sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUnique := shared.CompositeLookupKey(elasticSanName, volumeGroupName, volumeName) + if sdpItem.UniqueAttributeValue() != expectedUnique { + t.Errorf("Expected unique attribute value %s, got %s", expectedUnique, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + scope := subscriptionID + "." + resourceGroup + queryTests := shared.QueryTests{ + {ExpectedType: azureshared.ElasticSan.String(), ExpectedMethod: sdp.QueryMethod_GET, ExpectedQuery: elasticSanName, ExpectedScope: scope}, + {ExpectedType: azureshared.ElasticSanVolumeGroup.String(), ExpectedMethod: sdp.QueryMethod_GET, ExpectedQuery: shared.CompositeLookupKey(elasticSanName, volumeGroupName), ExpectedScope: scope}, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithLinks", func(t *testing.T) { + vol := createAzureElasticSanVolumeWithLinks(volumeName) + + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, elasticSanName, volumeGroupName, volumeName, nil).Return( + armelasticsan.VolumesClientGetResponse{ + Volume: *vol, + }, nil) + + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName, volumeName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + t.Run("StaticTests", func(t *testing.T) { + scope := subscriptionID + "." + resourceGroup + queryTests := shared.QueryTests{ + {ExpectedType: azureshared.ElasticSan.String(), ExpectedMethod: sdp.QueryMethod_GET, ExpectedQuery: elasticSanName, ExpectedScope: scope}, + {ExpectedType: azureshared.ElasticSanVolumeGroup.String(), ExpectedMethod: sdp.QueryMethod_GET, ExpectedQuery: shared.CompositeLookupKey(elasticSanName, volumeGroupName), ExpectedScope: scope}, + {ExpectedType: azureshared.ElasticSanVolumeSnapshot.String(), ExpectedMethod: sdp.QueryMethod_GET, ExpectedQuery: shared.CompositeLookupKey("es", "vg", "snap1"), ExpectedScope: "sub.rg"}, + {ExpectedType: stdlib.NetworkDNS.String(), ExpectedMethod: sdp.QueryMethod_SEARCH, ExpectedQuery: "test-san.region.elasticsan.azure.net", ExpectedScope: "global"}, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Only 2 query parts - missing volumeName + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyElasticSanName", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey("", volumeGroupName, volumeName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when elasticSanName is empty, but got nil") + } + }) + + t.Run("GetWithEmptyVolumeGroupName", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(elasticSanName, "", volumeName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when volumeGroupName is empty, but got nil") + } + }) + + t.Run("GetWithEmptyVolumeName", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when volumeName is empty, but got nil") + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, elasticSanName, volumeGroupName, "nonexistent", nil).Return( + armelasticsan.VolumesClientGetResponse{}, errors.New("volume not found")) + + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName, "nonexistent") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when resource not found, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + vol1 := createAzureElasticSanVolume("vol-1") + vol2 := createAzureElasticSanVolume("vol-2") + + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + mockPager := &mockElasticSanVolumePager{ + pages: []armelasticsan.VolumesClientListByVolumeGroupResponse{ + { + VolumeList: armelasticsan.VolumeList{ + Value: []*armelasticsan.Volume{vol1, vol2}, + }, + }, + }, + } + mockClient.EXPECT().NewListByVolumeGroupPager(resourceGroup, elasticSanName, volumeGroupName, nil).Return(mockPager) + + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName) + items, err := searchable.Search(ctx, wrapper.Scopes()[0], query, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + if len(items) != 2 { + t.Fatalf("Expected 2 items, got %d", len(items)) + } + for _, item := range items { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + } + }) + + t.Run("SearchWithEmptyElasticSanName", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + query := shared.CompositeLookupKey("", volumeGroupName) + _, err := searchable.Search(ctx, wrapper.Scopes()[0], query, true) + if err == nil { + t.Error("Expected error when elasticSanName is empty, but got nil") + } + }) + + t.Run("SearchWithEmptyVolumeGroupName", func(t *testing.T) { + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + query := shared.CompositeLookupKey(elasticSanName, "") + _, err := searchable.Search(ctx, wrapper.Scopes()[0], query, true) + if err == nil { + t.Error("Expected error when volumeGroupName is empty, but got nil") + } + }) + + t.Run("SearchStream", func(t *testing.T) { + vol := createAzureElasticSanVolume("stream-vol") + mockClient := mocks.NewMockElasticSanVolumeClient(ctrl) + mockPager := &mockElasticSanVolumePager{ + pages: []armelasticsan.VolumesClientListByVolumeGroupResponse{ + { + VolumeList: armelasticsan.VolumeList{ + Value: []*armelasticsan.Volume{vol}, + }, + }, + }, + } + mockClient.EXPECT().NewListByVolumeGroupPager(resourceGroup, elasticSanName, volumeGroupName, nil).Return(mockPager) + + wrapper := manual.NewElasticSanVolume(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + streamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + query := shared.CompositeLookupKey(elasticSanName, volumeGroupName) + stream := discovery.NewRecordingQueryResultStream() + streamable.SearchStream(ctx, wrapper.Scopes()[0], query, true, stream) + items := stream.GetItems() + if len(items) != 1 { + t.Fatalf("Expected 1 item from stream, got %d", len(items)) + } + if items[0].GetType() != azureshared.ElasticSanVolume.String() { + t.Errorf("Expected type %s, got %s", azureshared.ElasticSanVolume.String(), items[0].GetType()) + } + }) +} diff --git a/sources/azure/manual/managedidentity-federated-identity-credential.go b/sources/azure/manual/managedidentity-federated-identity-credential.go new file mode 100644 index 00000000..cbfea73c --- /dev/null +++ b/sources/azure/manual/managedidentity-federated-identity-credential.go @@ -0,0 +1,243 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var ManagedIdentityFederatedIdentityCredentialLookupByName = shared.NewItemTypeLookup("name", azureshared.ManagedIdentityFederatedIdentityCredential) + +type managedIdentityFederatedIdentityCredentialWrapper struct { + client clients.FederatedIdentityCredentialsClient + + *azureshared.MultiResourceGroupBase +} + +func NewManagedIdentityFederatedIdentityCredential(client clients.FederatedIdentityCredentialsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &managedIdentityFederatedIdentityCredentialWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_SECURITY, + azureshared.ManagedIdentityFederatedIdentityCredential, + ), + } +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: identityName and federatedCredentialName", + Scope: scope, + ItemType: m.Type(), + } + } + identityName := queryParts[0] + if identityName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "identityName cannot be empty", + Scope: scope, + ItemType: m.Type(), + } + } + federatedCredentialName := queryParts[1] + if federatedCredentialName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "federatedCredentialName cannot be empty", + Scope: scope, + ItemType: m.Type(), + } + } + + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + resp, err := m.client.Get(ctx, rgScope.ResourceGroup, identityName, federatedCredentialName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + return m.azureFederatedIdentityCredentialToSDPItem(&resp.FederatedIdentityCredential, identityName, federatedCredentialName, scope) +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + ManagedIdentityUserAssignedIdentityLookupByName, + ManagedIdentityFederatedIdentityCredentialLookupByName, + } +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: identityName", + Scope: scope, + ItemType: m.Type(), + } + } + identityName := queryParts[0] + if identityName == "" { + return nil, azureshared.QueryError(errors.New("identityName cannot be empty"), scope, m.Type()) + } + + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + pager := m.client.NewListPager(rgScope.ResourceGroup, identityName, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + for _, credential := range page.Value { + if credential.Name == nil { + continue + } + + item, sdpErr := m.azureFederatedIdentityCredentialToSDPItem(credential, identityName, *credential.Name, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: identityName"), scope, m.Type())) + return + } + identityName := queryParts[0] + if identityName == "" { + stream.SendError(azureshared.QueryError(errors.New("identityName cannot be empty"), scope, m.Type())) + return + } + + rgScope, err := m.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, m.Type())) + return + } + + pager := m.client.NewListPager(rgScope.ResourceGroup, identityName, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, m.Type())) + return + } + for _, credential := range page.Value { + if credential.Name == nil { + continue + } + item, sdpErr := m.azureFederatedIdentityCredentialToSDPItem(credential, identityName, *credential.Name, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + ManagedIdentityUserAssignedIdentityLookupByName, + }, + } +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) azureFederatedIdentityCredentialToSDPItem(credential *armmsi.FederatedIdentityCredential, identityName, credentialName, scope string) (*sdp.Item, *sdp.QueryError) { + if credential.Name == nil { + return nil, azureshared.QueryError(errors.New("credential name is nil"), scope, m.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(credential) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(identityName, credentialName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, m.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.ManagedIdentityFederatedIdentityCredential.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + } + + // Link back to the parent user assigned identity + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ManagedIdentityUserAssignedIdentity.String(), + Method: sdp.QueryMethod_GET, + Query: identityName, + Scope: scope, + }, + }) + + // Link to DNS hostname from Issuer URL (e.g., https://token.actions.githubusercontent.com) + // The Issuer is the URL of the external identity provider + if credential.Properties != nil && credential.Properties.Issuer != nil && *credential.Properties.Issuer != "" { + dnsName := azureshared.ExtractDNSFromURL(*credential.Properties.Issuer) + if dnsName != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: dnsName, + Scope: "global", + }, + }) + } + } + + return sdpItem, nil +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + azureshared.ManagedIdentityUserAssignedIdentity: true, + stdlib.NetworkDNS: true, + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/permissions/identity#microsoftmanagedidentity +func (m managedIdentityFederatedIdentityCredentialWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.ManagedIdentity/userAssignedIdentities/federatedIdentityCredentials/read", + } +} + +func (m managedIdentityFederatedIdentityCredentialWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/managedidentity-federated-identity-credential_test.go b/sources/azure/manual/managedidentity-federated-identity-credential_test.go new file mode 100644 index 00000000..ba894d04 --- /dev/null +++ b/sources/azure/manual/managedidentity-federated-identity-credential_test.go @@ -0,0 +1,402 @@ +package manual_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +// mockFederatedIdentityCredentialsPager is a simple mock implementation of FederatedIdentityCredentialsPager +type mockFederatedIdentityCredentialsPager struct { + pages []armmsi.FederatedIdentityCredentialsClientListResponse + index int +} + +func (m *mockFederatedIdentityCredentialsPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockFederatedIdentityCredentialsPager) NextPage(ctx context.Context) (armmsi.FederatedIdentityCredentialsClientListResponse, error) { + if m.index >= len(m.pages) { + return armmsi.FederatedIdentityCredentialsClientListResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +// errorFederatedIdentityCredentialsPager is a mock pager that always returns an error +type errorFederatedIdentityCredentialsPager struct{} + +func (e *errorFederatedIdentityCredentialsPager) More() bool { + return true +} + +func (e *errorFederatedIdentityCredentialsPager) NextPage(ctx context.Context) (armmsi.FederatedIdentityCredentialsClientListResponse, error) { + return armmsi.FederatedIdentityCredentialsClientListResponse{}, errors.New("pager error") +} + +// testFederatedIdentityCredentialsClient wraps the mock to implement the correct interface +type testFederatedIdentityCredentialsClient struct { + *mocks.MockFederatedIdentityCredentialsClient + pager clients.FederatedIdentityCredentialsPager +} + +func (t *testFederatedIdentityCredentialsClient) NewListPager(resourceGroupName string, resourceName string, options *armmsi.FederatedIdentityCredentialsClientListOptions) clients.FederatedIdentityCredentialsPager { + return t.pager +} + +func TestManagedIdentityFederatedIdentityCredential(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + identityName := "test-identity" + credentialName := "test-credential" + + t.Run("Get", func(t *testing.T) { + credential := createAzureFederatedIdentityCredential(credentialName) + + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, identityName, credentialName, nil).Return( + armmsi.FederatedIdentityCredentialsClientGetResponse{ + FederatedIdentityCredential: *credential, + }, nil) + + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(identityName, credentialName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.ManagedIdentityFederatedIdentityCredential.String() { + t.Errorf("Expected type %s, got %s", azureshared.ManagedIdentityFederatedIdentityCredential, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != shared.CompositeLookupKey(identityName, credentialName) { + t.Errorf("Expected unique attribute value %s, got %s", shared.CompositeLookupKey(identityName, credentialName), sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + linkedQueries := sdpItem.GetLinkedItemQueries() + if len(linkedQueries) != 2 { + t.Fatalf("Expected 2 linked queries, got: %d", len(linkedQueries)) + } + + queryTests := shared.QueryTests{ + { + ExpectedType: azureshared.ManagedIdentityUserAssignedIdentity.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: identityName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + ExpectedType: stdlib.NetworkDNS.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: "token.actions.githubusercontent.com", + ExpectedScope: "global", + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], identityName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyIdentityName", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey("", credentialName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting with empty identity name, but got nil") + } + }) + + t.Run("GetWithEmptyCredentialName", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(identityName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting with empty credential name, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + credential1 := createAzureFederatedIdentityCredential("credential-1") + credential2 := createAzureFederatedIdentityCredential("credential-2") + + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + mockPager := &mockFederatedIdentityCredentialsPager{ + pages: []armmsi.FederatedIdentityCredentialsClientListResponse{ + { + FederatedIdentityCredentialsListResult: armmsi.FederatedIdentityCredentialsListResult{ + Value: []*armmsi.FederatedIdentityCredential{credential1, credential2}, + }, + }, + }, + } + + testClient := &testFederatedIdentityCredentialsClient{ + MockFederatedIdentityCredentialsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], identityName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + if item.GetType() != azureshared.ManagedIdentityFederatedIdentityCredential.String() { + t.Errorf("Expected type %s, got %s", azureshared.ManagedIdentityFederatedIdentityCredential, item.GetType()) + } + } + }) + + t.Run("SearchStream", func(t *testing.T) { + credential1 := createAzureFederatedIdentityCredential("credential-1") + credential2 := createAzureFederatedIdentityCredential("credential-2") + + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + mockPager := &mockFederatedIdentityCredentialsPager{ + pages: []armmsi.FederatedIdentityCredentialsClientListResponse{ + { + FederatedIdentityCredentialsListResult: armmsi.FederatedIdentityCredentialsListResult{ + Value: []*armmsi.FederatedIdentityCredential{credential1, credential2}, + }, + }, + }, + } + + testClient := &testFederatedIdentityCredentialsClient{ + MockFederatedIdentityCredentialsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], identityName, true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("SearchWithEmptyIdentityName", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when providing empty identity name, but got nil") + } + }) + + t.Run("SearchWithNoQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("Search_CredentialWithNilName", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + mockPager := &mockFederatedIdentityCredentialsPager{ + pages: []armmsi.FederatedIdentityCredentialsClientListResponse{ + { + FederatedIdentityCredentialsListResult: armmsi.FederatedIdentityCredentialsListResult{ + Value: []*armmsi.FederatedIdentityCredential{ + {Name: nil}, + createAzureFederatedIdentityCredential("valid-credential"), + }, + }, + }, + }, + } + + testClient := &testFederatedIdentityCredentialsClient{ + MockFederatedIdentityCredentialsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], identityName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item, got: %d", len(sdpItems)) + } + + if sdpItems[0].UniqueAttributeValue() != shared.CompositeLookupKey(identityName, "valid-credential") { + t.Errorf("Expected credential unique value '%s', got %s", shared.CompositeLookupKey(identityName, "valid-credential"), sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("credential not found") + + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, identityName, "nonexistent", nil).Return( + armmsi.FederatedIdentityCredentialsClientGetResponse{}, expectedErr) + + testClient := &testFederatedIdentityCredentialsClient{MockFederatedIdentityCredentialsClient: mockClient} + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(identityName, "nonexistent") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent credential, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + mockClient := mocks.NewMockFederatedIdentityCredentialsClient(ctrl) + errorPager := &errorFederatedIdentityCredentialsPager{} + + testClient := &testFederatedIdentityCredentialsClient{ + MockFederatedIdentityCredentialsClient: mockClient, + pager: errorPager, + } + + wrapper := manual.NewManagedIdentityFederatedIdentityCredential(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + _, err := searchable.Search(ctx, wrapper.Scopes()[0], identityName, true) + if err == nil { + t.Error("Expected error from pager when NextPage returns an error, but got nil") + } + }) +} + +func createAzureFederatedIdentityCredential(name string) *armmsi.FederatedIdentityCredential { + return &armmsi.FederatedIdentityCredential{ + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.ManagedIdentity/userAssignedIdentities/test-identity/federatedIdentityCredentials/" + name), + Name: new(name), + Type: new("Microsoft.ManagedIdentity/userAssignedIdentities/federatedIdentityCredentials"), + Properties: &armmsi.FederatedIdentityCredentialProperties{ + Issuer: new("https://token.actions.githubusercontent.com"), + Subject: new("repo:example/repo:ref:refs/heads/main"), + Audiences: []*string{new("api://AzureADTokenExchange")}, + }, + } +} diff --git a/sources/azure/manual/network-ip-group.go b/sources/azure/manual/network-ip-group.go new file mode 100644 index 00000000..0bcb1fca --- /dev/null +++ b/sources/azure/manual/network-ip-group.go @@ -0,0 +1,236 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var NetworkIPGroupLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkIPGroup) + +type networkIPGroupWrapper struct { + client clients.IPGroupsClient + + *azureshared.MultiResourceGroupBase +} + +// NewNetworkIPGroup creates a new networkIPGroupWrapper instance. +func NewNetworkIPGroup(client clients.IPGroupsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { + return &networkIPGroupWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + azureshared.NetworkIPGroup, + ), + } +} + +// List retrieves all IP groups in a resource group. +// ref: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/ip-groups/list-by-resource-group +func (c networkIPGroupWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, ipGroup := range page.Value { + if ipGroup.Name == nil { + continue + } + item, sdpErr := c.azureIPGroupToSDPItem(ipGroup, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + return items, nil +} + +// ListStream streams all IP groups in a resource group. +func (c networkIPGroupWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + for _, ipGroup := range page.Value { + if ipGroup.Name == nil { + continue + } + item, sdpErr := c.azureIPGroupToSDPItem(ipGroup, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +// Get retrieves a single IP group by name. +// ref: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/ip-groups/get +func (c networkIPGroupWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, azureshared.QueryError(errors.New("query must be exactly one part (IP group name)"), scope, c.Type()) + } + ipGroupName := queryParts[0] + if ipGroupName == "" { + return nil, azureshared.QueryError(errors.New("IP group name cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + resp, err := c.client.Get(ctx, rgScope.ResourceGroup, ipGroupName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + return c.azureIPGroupToSDPItem(&resp.IPGroup, scope) +} + +func (c networkIPGroupWrapper) azureIPGroupToSDPItem(ipGroup *armnetwork.IPGroup, scope string) (*sdp.Item, *sdp.QueryError) { + if ipGroup.Name == nil { + return nil, azureshared.QueryError(errors.New("IP group name is nil"), scope, c.Type()) + } + attributes, err := shared.ToAttributesWithExclude(ipGroup, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.NetworkIPGroup.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(ipGroup.Tags), + LinkedItemQueries: []*sdp.LinkedItemQuery{}, + } + + // Health from provisioning state + if ipGroup.Properties != nil && ipGroup.Properties.ProvisioningState != nil { + switch *ipGroup.Properties.ProvisioningState { + case armnetwork.ProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armnetwork.ProvisioningStateCreating, armnetwork.ProvisioningStateUpdating, armnetwork.ProvisioningStateDeleting: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armnetwork.ProvisioningStateFailed, armnetwork.ProvisioningStateCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + if ipGroup.Properties != nil { + // Link to IP addresses + // IP Groups contain a list of IP addresses or prefixes + for _, ipAddr := range ipGroup.Properties.IPAddresses { + if ipAddr != nil && *ipAddr != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *ipAddr, + Scope: "global", + }, + }) + } + } + + // Link to Firewalls (read-only, references back to Azure Firewalls using this IP Group) + // Note: These are SubResource references containing just IDs + for _, firewall := range ipGroup.Properties.Firewalls { + if firewall != nil && firewall.ID != nil { + firewallName := azureshared.ExtractResourceName(*firewall.ID) + if firewallName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*firewall.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkFirewall.String(), + Method: sdp.QueryMethod_GET, + Query: firewallName, + Scope: linkedScope, + }, + }) + } + } + } + + // Link to Firewall Policies (read-only, references back to Firewall Policies using this IP Group) + for _, policy := range ipGroup.Properties.FirewallPolicies { + if policy != nil && policy.ID != nil { + policyName := azureshared.ExtractResourceName(*policy.ID) + if policyName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*policy.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkFirewallPolicy.String(), + Method: sdp.QueryMethod_GET, + Query: policyName, + Scope: linkedScope, + }, + }) + } + } + } + } + + return sdpItem, nil +} + +func (c networkIPGroupWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + NetworkIPGroupLookupByName, + } +} + +func (c networkIPGroupWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + stdlib.NetworkIP: true, + azureshared.NetworkFirewall: true, + azureshared.NetworkFirewallPolicy: true, + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftnetwork +func (c networkIPGroupWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Network/ipGroups/read", + } +} + +func (c networkIPGroupWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/network-ip-group_test.go b/sources/azure/manual/network-ip-group_test.go new file mode 100644 index 00000000..c3cf14b4 --- /dev/null +++ b/sources/azure/manual/network-ip-group_test.go @@ -0,0 +1,423 @@ +package manual_test + +import ( + "context" + "errors" + "slices" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +func TestNetworkIPGroup(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + + t.Run("Get", func(t *testing.T) { + ipGroupName := "test-ip-group" + ipGroup := createAzureIPGroup(ipGroupName) + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, ipGroupName, nil).Return( + armnetwork.IPGroupsClientGetResponse{ + IPGroup: *ipGroup, + }, nil) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], ipGroupName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkIPGroup.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkIPGroup, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != ipGroupName { + t.Errorf("Expected unique attribute value %s, got %s", ipGroupName, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetTags()["env"] != "test" { + t.Errorf("Expected tag 'env=test', got: %v", sdpItem.GetTags()["env"]) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "10.0.0.0/24", + ExpectedScope: "global", + }, + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "192.168.1.1", + ExpectedScope: "global", + }, + { + ExpectedType: azureshared.NetworkFirewall.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-firewall", + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + ExpectedType: azureshared.NetworkFirewallPolicy.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-firewall-policy", + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithEmptyName", func(t *testing.T) { + mockClient := mocks.NewMockIPGroupsClient(ctrl) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) + if qErr == nil { + t.Error("Expected error when IP group name is empty, but got nil") + } + }) + + t.Run("Get_IPGroupWithNilName", func(t *testing.T) { + provisioningState := armnetwork.ProvisioningStateSucceeded + ipGroupWithNilName := &armnetwork.IPGroup{ + Name: nil, + Location: new("eastus"), + Properties: &armnetwork.IPGroupPropertiesFormat{ + ProvisioningState: &provisioningState, + }, + } + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "test-ip-group", nil).Return( + armnetwork.IPGroupsClientGetResponse{ + IPGroup: *ipGroupWithNilName, + }, nil) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-ip-group", true) + if qErr == nil { + t.Error("Expected error when IP group has nil name, but got nil") + } + }) + + t.Run("List", func(t *testing.T) { + ipGroup1 := createAzureIPGroup("ip-group-1") + ipGroup2 := createAzureIPGroup("ip-group-2") + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockPager := newMockIPGroupsPager(ctrl, []*armnetwork.IPGroup{ipGroup1, ipGroup2}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + if item.GetType() != azureshared.NetworkIPGroup.String() { + t.Fatalf("Expected type %s, got: %s", azureshared.NetworkIPGroup, item.GetType()) + } + } + }) + + t.Run("List_WithNilName", func(t *testing.T) { + ipGroup1 := createAzureIPGroup("ip-group-1") + provisioningState := armnetwork.ProvisioningStateSucceeded + ipGroup2NilName := &armnetwork.IPGroup{ + Name: nil, + Location: new("eastus"), + Tags: map[string]*string{"env": new("test")}, + Properties: &armnetwork.IPGroupPropertiesFormat{ + ProvisioningState: &provisioningState, + }, + } + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockPager := newMockIPGroupsPager(ctrl, []*armnetwork.IPGroup{ipGroup1, ipGroup2NilName}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name skipped), got: %d", len(sdpItems)) + } + if sdpItems[0].UniqueAttributeValue() != "ip-group-1" { + t.Errorf("Expected item name 'ip-group-1', got: %s", sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("ListStream", func(t *testing.T) { + ipGroup1 := createAzureIPGroup("stream-ip-group-1") + ipGroup2 := createAzureIPGroup("stream-ip-group-2") + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockPager := newMockIPGroupsPager(ctrl, []*armnetwork.IPGroup{ipGroup1, ipGroup2}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + listStreamable, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + listStreamable.ListStream(ctx, wrapper.Scopes()[0], true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + expectedErr := errors.New("IP group not found") + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-ip-group", nil).Return( + armnetwork.IPGroupsClientGetResponse{}, expectedErr) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-ip-group", true) + if qErr == nil { + t.Error("Expected error when getting non-existent IP group, but got nil") + } + }) + + t.Run("InterfaceCompliance", func(t *testing.T) { + mockClient := mocks.NewMockIPGroupsClient(ctrl) + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + w := wrapper.(sources.Wrapper) + + permissions := w.IAMPermissions() + if len(permissions) == 0 { + t.Error("Expected IAMPermissions to return at least one permission") + } + expectedPermission := "Microsoft.Network/ipGroups/read" + if !slices.Contains(permissions, expectedPermission) { + t.Errorf("Expected IAMPermissions to include %s", expectedPermission) + } + + lookups := w.GetLookups() + foundLookup := false + for _, lookup := range lookups { + if lookup.ItemType == azureshared.NetworkIPGroup { + foundLookup = true + break + } + } + if !foundLookup { + t.Error("Expected GetLookups to include NetworkIPGroup") + } + + potentialLinks := w.PotentialLinks() + if !potentialLinks[stdlib.NetworkIP] { + t.Error("Expected PotentialLinks to include stdlib.NetworkIP") + } + if !potentialLinks[azureshared.NetworkFirewall] { + t.Error("Expected PotentialLinks to include NetworkFirewall") + } + if !potentialLinks[azureshared.NetworkFirewallPolicy] { + t.Error("Expected PotentialLinks to include NetworkFirewallPolicy") + } + }) + + t.Run("HealthStatus", func(t *testing.T) { + tests := []struct { + name string + provisioningState armnetwork.ProvisioningState + expectedHealth sdp.Health + }{ + {"Succeeded", armnetwork.ProvisioningStateSucceeded, sdp.Health_HEALTH_OK}, + {"Updating", armnetwork.ProvisioningStateUpdating, sdp.Health_HEALTH_PENDING}, + {"Deleting", armnetwork.ProvisioningStateDeleting, sdp.Health_HEALTH_PENDING}, + {"Failed", armnetwork.ProvisioningStateFailed, sdp.Health_HEALTH_ERROR}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + ipGroup := &armnetwork.IPGroup{ + ID: new("/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/ipGroups/test-ip-group"), + Name: new("test-ip-group"), + Type: new("Microsoft.Network/ipGroups"), + Location: new("eastus"), + Tags: map[string]*string{}, + Properties: &armnetwork.IPGroupPropertiesFormat{ + ProvisioningState: &tc.provisioningState, + }, + } + + mockClient := mocks.NewMockIPGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "test-ip-group", nil).Return( + armnetwork.IPGroupsClientGetResponse{ + IPGroup: *ipGroup, + }, nil) + + wrapper := manual.NewNetworkIPGroup(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "test-ip-group", true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != tc.expectedHealth { + t.Errorf("Expected health %v, got %v", tc.expectedHealth, sdpItem.GetHealth()) + } + }) + } + }) +} + +type mockIPGroupsPager struct { + ctrl *gomock.Controller + items []*armnetwork.IPGroup + index int + more bool +} + +func newMockIPGroupsPager(ctrl *gomock.Controller, items []*armnetwork.IPGroup) clients.IPGroupsPager { + return &mockIPGroupsPager{ + ctrl: ctrl, + items: items, + index: 0, + more: len(items) > 0, + } +} + +func (m *mockIPGroupsPager) More() bool { + return m.more +} + +func (m *mockIPGroupsPager) NextPage(ctx context.Context) (armnetwork.IPGroupsClientListByResourceGroupResponse, error) { + if m.index >= len(m.items) { + m.more = false + return armnetwork.IPGroupsClientListByResourceGroupResponse{ + IPGroupListResult: armnetwork.IPGroupListResult{ + Value: []*armnetwork.IPGroup{}, + }, + }, nil + } + item := m.items[m.index] + m.index++ + m.more = m.index < len(m.items) + return armnetwork.IPGroupsClientListByResourceGroupResponse{ + IPGroupListResult: armnetwork.IPGroupListResult{ + Value: []*armnetwork.IPGroup{item}, + }, + }, nil +} + +func createAzureIPGroup(name string) *armnetwork.IPGroup { + provisioningState := armnetwork.ProvisioningStateSucceeded + return &armnetwork.IPGroup{ + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/ipGroups/" + name), + Name: new(name), + Type: new("Microsoft.Network/ipGroups"), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + "project": new("testing"), + }, + Properties: &armnetwork.IPGroupPropertiesFormat{ + ProvisioningState: &provisioningState, + IPAddresses: []*string{ + new("10.0.0.0/24"), + new("192.168.1.1"), + }, + Firewalls: []*armnetwork.SubResource{ + { + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/azureFirewalls/test-firewall"), + }, + }, + FirewallPolicies: []*armnetwork.SubResource{ + { + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Network/firewallPolicies/test-firewall-policy"), + }, + }, + }, + } +} + +var _ clients.IPGroupsPager = (*mockIPGroupsPager)(nil) diff --git a/sources/azure/manual/network-local-network-gateway.go b/sources/azure/manual/network-local-network-gateway.go new file mode 100644 index 00000000..8185b1fd --- /dev/null +++ b/sources/azure/manual/network-local-network-gateway.go @@ -0,0 +1,284 @@ +package manual + +import ( + "context" + "errors" + "net" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var NetworkLocalNetworkGatewayLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkLocalNetworkGateway) + +type networkLocalNetworkGatewayWrapper struct { + client clients.LocalNetworkGatewaysClient + + *azureshared.MultiResourceGroupBase +} + +// NewNetworkLocalNetworkGateway creates a new networkLocalNetworkGatewayWrapper instance. +func NewNetworkLocalNetworkGateway(client clients.LocalNetworkGatewaysClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { + return &networkLocalNetworkGatewayWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + azureshared.NetworkLocalNetworkGateway, + ), + } +} + +// List retrieves all local network gateways in a scope. +// ref: https://learn.microsoft.com/en-us/rest/api/network-gateway/local-network-gateways/list +func (c networkLocalNetworkGatewayWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, gw := range page.Value { + if gw.Name == nil { + continue + } + item, sdpErr := c.azureLocalNetworkGatewayToSDPItem(gw, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +// ListStream streams all local network gateways in a scope. +func (c networkLocalNetworkGatewayWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + + for _, gw := range page.Value { + if gw.Name == nil { + continue + } + item, sdpErr := c.azureLocalNetworkGatewayToSDPItem(gw, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +// Get retrieves a single local network gateway by name. +// ref: https://learn.microsoft.com/en-us/rest/api/network-gateway/local-network-gateways/get +func (c networkLocalNetworkGatewayWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the local network gateway name"), scope, c.Type()) + } + gatewayName := queryParts[0] + if gatewayName == "" { + return nil, azureshared.QueryError(errors.New("localNetworkGatewayName cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + result, err := c.client.Get(ctx, rgScope.ResourceGroup, gatewayName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + return c.azureLocalNetworkGatewayToSDPItem(&result.LocalNetworkGateway, scope) +} + +func (c networkLocalNetworkGatewayWrapper) azureLocalNetworkGatewayToSDPItem(gw *armnetwork.LocalNetworkGateway, scope string) (*sdp.Item, *sdp.QueryError) { + if gw.Name == nil { + return nil, azureshared.QueryError(errors.New("local network gateway name is nil"), scope, c.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(gw, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.NetworkLocalNetworkGateway.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(gw.Tags), + LinkedItemQueries: []*sdp.LinkedItemQuery{}, + } + + // Health from provisioning state + if gw.Properties != nil && gw.Properties.ProvisioningState != nil { + switch *gw.Properties.ProvisioningState { + case armnetwork.ProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armnetwork.ProvisioningStateCreating, armnetwork.ProvisioningStateUpdating, armnetwork.ProvisioningStateDeleting: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armnetwork.ProvisioningStateFailed, armnetwork.ProvisioningStateCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Gateway IP address (on-premises VPN device IP) + if gw.Properties != nil && gw.Properties.GatewayIPAddress != nil && *gw.Properties.GatewayIPAddress != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *gw.Properties.GatewayIPAddress, + Scope: "global", + }, + }) + } + + // FQDN (if used instead of IP address for the on-premises device) + if gw.Properties != nil && gw.Properties.Fqdn != nil && *gw.Properties.Fqdn != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *gw.Properties.Fqdn, + Scope: "global", + }, + }) + } + + // BGP settings + if gw.Properties != nil && gw.Properties.BgpSettings != nil { + bgp := gw.Properties.BgpSettings + + // BgpPeeringAddress - can be IP or hostname + if bgp.BgpPeeringAddress != nil && *bgp.BgpPeeringAddress != "" { + if net.ParseIP(*bgp.BgpPeeringAddress) != nil { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *bgp.BgpPeeringAddress, + Scope: "global", + }, + }) + } else { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *bgp.BgpPeeringAddress, + Scope: "global", + }, + }) + } + } + + // BgpPeeringAddresses array + if bgp.BgpPeeringAddresses != nil { + for _, peeringAddr := range bgp.BgpPeeringAddresses { + if peeringAddr == nil { + continue + } + // DefaultBgpIPAddresses + for _, ipStr := range peeringAddr.DefaultBgpIPAddresses { + if ipStr != nil && *ipStr != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *ipStr, + Scope: "global", + }, + }) + } + } + // CustomBgpIPAddresses + for _, ipStr := range peeringAddr.CustomBgpIPAddresses { + if ipStr != nil && *ipStr != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *ipStr, + Scope: "global", + }, + }) + } + } + // TunnelIPAddresses + for _, ipStr := range peeringAddr.TunnelIPAddresses { + if ipStr != nil && *ipStr != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *ipStr, + Scope: "global", + }, + }) + } + } + } + } + } + + return sdpItem, nil +} + +func (c networkLocalNetworkGatewayWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + NetworkLocalNetworkGatewayLookupByName, + } +} + +func (c networkLocalNetworkGatewayWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + stdlib.NetworkIP: true, + stdlib.NetworkDNS: true, + } +} + +// IAMPermissions returns the Azure RBAC permissions required to read this resource. +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftnetwork +func (c networkLocalNetworkGatewayWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Network/localNetworkGateways/read", + } +} + +// PredefinedRole returns the Azure built-in role that grants the required permissions. +func (c networkLocalNetworkGatewayWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/network-local-network-gateway_test.go b/sources/azure/manual/network-local-network-gateway_test.go new file mode 100644 index 00000000..376a290f --- /dev/null +++ b/sources/azure/manual/network-local-network-gateway_test.go @@ -0,0 +1,424 @@ +package manual_test + +import ( + "context" + "errors" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +func TestNetworkLocalNetworkGateway(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + scope := subscriptionID + "." + resourceGroup + + t.Run("Get", func(t *testing.T) { + gatewayName := "test-local-gateway" + gw := createAzureLocalNetworkGateway(gatewayName) + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, gatewayName, nil).Return( + armnetwork.LocalNetworkGatewaysClientGetResponse{ + LocalNetworkGateway: *gw, + }, nil) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, gatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkLocalNetworkGateway.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkLocalNetworkGateway.String(), sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != gatewayName { + t.Errorf("Expected unique attribute value %s, got %s", gatewayName, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetTags()["env"] != "test" { + t.Errorf("Expected tag 'env=test', got: %v", sdpItem.GetTags()["env"]) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "203.0.113.1", + ExpectedScope: "global", + }, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_WithFqdn", func(t *testing.T) { + gatewayName := "test-local-gateway-fqdn" + gw := createAzureLocalNetworkGatewayWithFqdn(gatewayName) + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, gatewayName, nil).Return( + armnetwork.LocalNetworkGatewaysClientGetResponse{ + LocalNetworkGateway: *gw, + }, nil) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, gatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: stdlib.NetworkDNS.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: "vpn.example.com", + ExpectedScope: "global", + }, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_WithBgpSettings", func(t *testing.T) { + gatewayName := "test-local-gateway-bgp" + gw := createAzureLocalNetworkGatewayWithBgp(gatewayName) + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, gatewayName, nil).Return( + armnetwork.LocalNetworkGatewaysClientGetResponse{ + LocalNetworkGateway: *gw, + }, nil) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, scope, gatewayName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "203.0.113.1", + ExpectedScope: "global", + }, + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "10.0.0.1", + ExpectedScope: "global", + }, + } + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithEmptyName", func(t *testing.T) { + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, scope, "", true) + if qErr == nil { + t.Error("Expected error when getting gateway with empty name, but got nil") + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + gatewayName := "nonexistent-gateway" + expectedErr := errors.New("local network gateway not found") + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, gatewayName, nil).Return( + armnetwork.LocalNetworkGatewaysClientGetResponse{}, expectedErr) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, scope, gatewayName, true) + if qErr == nil { + t.Fatal("Expected error when gateway not found, got nil") + } + }) + + t.Run("List", func(t *testing.T) { + gw1 := createAzureLocalNetworkGateway("local-gateway-1") + gw2 := createAzureLocalNetworkGateway("local-gateway-2") + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockPager := newMockLocalNetworkGatewaysPager(ctrl, []*armnetwork.LocalNetworkGateway{gw1, gw2}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + items, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got %d", len(items)) + } + + for i, item := range items { + if item.GetType() != azureshared.NetworkLocalNetworkGateway.String() { + t.Errorf("Item %d: expected type %s, got %s", i, azureshared.NetworkLocalNetworkGateway.String(), item.GetType()) + } + if item.Validate() != nil { + t.Errorf("Item %d: validation error: %v", i, item.Validate()) + } + } + }) + + t.Run("ListStream", func(t *testing.T) { + gw1 := createAzureLocalNetworkGateway("local-gateway-1") + gw2 := createAzureLocalNetworkGateway("local-gateway-2") + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockPager := newMockLocalNetworkGatewaysPager(ctrl, []*armnetwork.LocalNetworkGateway{gw1, gw2}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listStream, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + var received []*sdp.Item + stream := &localNetworkGatewayCollectingStream{items: &received} + listStream.ListStream(ctx, scope, true, stream) + + if len(received) != 2 { + t.Fatalf("Expected 2 items from stream, got %d", len(received)) + } + }) + + t.Run("List_NilNameSkipped", func(t *testing.T) { + gw1 := createAzureLocalNetworkGateway("local-gateway-1") + gw2NilName := createAzureLocalNetworkGateway("local-gateway-2") + gw2NilName.Name = nil + + mockClient := mocks.NewMockLocalNetworkGatewaysClient(ctrl) + mockPager := newMockLocalNetworkGatewaysPager(ctrl, []*armnetwork.LocalNetworkGateway{gw1, gw2NilName}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkLocalNetworkGateway(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + items, err := listable.List(ctx, scope, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(items) != 1 { + t.Fatalf("Expected 1 item (nil name skipped), got %d", len(items)) + } + if items[0].UniqueAttributeValue() != "local-gateway-1" { + t.Errorf("Expected only local-gateway-1, got %s", items[0].UniqueAttributeValue()) + } + }) + + t.Run("GetLookups", func(t *testing.T) { + wrapper := manual.NewNetworkLocalNetworkGateway(nil, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + lookups := wrapper.GetLookups() + if len(lookups) == 0 { + t.Error("Expected GetLookups to return at least one lookup") + } + found := false + for _, l := range lookups { + if l.ItemType.String() == azureshared.NetworkLocalNetworkGateway.String() { + found = true + break + } + } + if !found { + t.Error("Expected GetLookups to include NetworkLocalNetworkGateway") + } + }) + + t.Run("PotentialLinks", func(t *testing.T) { + wrapper := manual.NewNetworkLocalNetworkGateway(nil, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + potentialLinks := wrapper.PotentialLinks() + for _, linkType := range []shared.ItemType{ + stdlib.NetworkIP, + stdlib.NetworkDNS, + } { + if !potentialLinks[linkType] { + t.Errorf("Expected PotentialLinks to include %s", linkType) + } + } + }) +} + +type localNetworkGatewayCollectingStream struct { + items *[]*sdp.Item +} + +func (c *localNetworkGatewayCollectingStream) SendItem(item *sdp.Item) { + *c.items = append(*c.items, item) +} + +func (c *localNetworkGatewayCollectingStream) SendError(err error) {} + +type mockLocalNetworkGatewaysPager struct { + ctrl *gomock.Controller + items []*armnetwork.LocalNetworkGateway + index int + more bool +} + +func newMockLocalNetworkGatewaysPager(ctrl *gomock.Controller, items []*armnetwork.LocalNetworkGateway) *mockLocalNetworkGatewaysPager { + return &mockLocalNetworkGatewaysPager{ + ctrl: ctrl, + items: items, + index: 0, + more: len(items) > 0, + } +} + +func (m *mockLocalNetworkGatewaysPager) More() bool { + return m.more +} + +func (m *mockLocalNetworkGatewaysPager) NextPage(ctx context.Context) (armnetwork.LocalNetworkGatewaysClientListResponse, error) { + if m.index >= len(m.items) { + m.more = false + return armnetwork.LocalNetworkGatewaysClientListResponse{ + LocalNetworkGatewayListResult: armnetwork.LocalNetworkGatewayListResult{ + Value: []*armnetwork.LocalNetworkGateway{}, + }, + }, nil + } + item := m.items[m.index] + m.index++ + m.more = m.index < len(m.items) + return armnetwork.LocalNetworkGatewaysClientListResponse{ + LocalNetworkGatewayListResult: armnetwork.LocalNetworkGatewayListResult{ + Value: []*armnetwork.LocalNetworkGateway{item}, + }, + }, nil +} + +func createAzureLocalNetworkGateway(name string) *armnetwork.LocalNetworkGateway { + provisioningState := armnetwork.ProvisioningStateSucceeded + gatewayIP := "203.0.113.1" + return &armnetwork.LocalNetworkGateway{ + ID: new("/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/localNetworkGateways/" + name), + Name: new(name), + Type: new("Microsoft.Network/localNetworkGateways"), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + "project": new("testing"), + }, + Properties: &armnetwork.LocalNetworkGatewayPropertiesFormat{ + ProvisioningState: &provisioningState, + GatewayIPAddress: &gatewayIP, + LocalNetworkAddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{ + new("10.1.0.0/16"), + new("10.2.0.0/16"), + }, + }, + }, + } +} + +func createAzureLocalNetworkGatewayWithFqdn(name string) *armnetwork.LocalNetworkGateway { + provisioningState := armnetwork.ProvisioningStateSucceeded + fqdn := "vpn.example.com" + return &armnetwork.LocalNetworkGateway{ + ID: new("/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/localNetworkGateways/" + name), + Name: new(name), + Type: new("Microsoft.Network/localNetworkGateways"), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + Properties: &armnetwork.LocalNetworkGatewayPropertiesFormat{ + ProvisioningState: &provisioningState, + Fqdn: &fqdn, + LocalNetworkAddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{ + new("10.1.0.0/16"), + }, + }, + }, + } +} + +func createAzureLocalNetworkGatewayWithBgp(name string) *armnetwork.LocalNetworkGateway { + provisioningState := armnetwork.ProvisioningStateSucceeded + gatewayIP := "203.0.113.1" + bgpPeeringAddress := "10.0.0.1" + asn := int64(65001) + return &armnetwork.LocalNetworkGateway{ + ID: new("/subscriptions/test-sub/resourceGroups/test-rg/providers/Microsoft.Network/localNetworkGateways/" + name), + Name: new(name), + Type: new("Microsoft.Network/localNetworkGateways"), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + Properties: &armnetwork.LocalNetworkGatewayPropertiesFormat{ + ProvisioningState: &provisioningState, + GatewayIPAddress: &gatewayIP, + BgpSettings: &armnetwork.BgpSettings{ + Asn: &asn, + BgpPeeringAddress: &bgpPeeringAddress, + }, + LocalNetworkAddressSpace: &armnetwork.AddressSpace{ + AddressPrefixes: []*string{ + new("10.1.0.0/16"), + }, + }, + }, + } +} diff --git a/sources/azure/manual/network-network-interface-ip-configuration.go b/sources/azure/manual/network-network-interface-ip-configuration.go new file mode 100644 index 00000000..52a0424b --- /dev/null +++ b/sources/azure/manual/network-network-interface-ip-configuration.go @@ -0,0 +1,459 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var NetworkNetworkInterfaceIPConfigurationLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkNetworkInterfaceIPConfiguration) + +type networkNetworkInterfaceIPConfigurationWrapper struct { + client clients.InterfaceIPConfigurationsClient + + *azureshared.MultiResourceGroupBase +} + +func NewNetworkNetworkInterfaceIPConfiguration(client clients.InterfaceIPConfigurationsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &networkNetworkInterfaceIPConfigurationWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + azureshared.NetworkNetworkInterfaceIPConfiguration, + ), + } +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: networkInterfaceName and ipConfigurationName", + Scope: scope, + ItemType: n.Type(), + } + } + networkInterfaceName := queryParts[0] + ipConfigurationName := queryParts[1] + + if networkInterfaceName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "networkInterfaceName cannot be empty", + Scope: scope, + ItemType: n.Type(), + } + } + if ipConfigurationName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "ipConfigurationName cannot be empty", + Scope: scope, + ItemType: n.Type(), + } + } + + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + resp, err := n.client.Get(ctx, rgScope.ResourceGroup, networkInterfaceName, ipConfigurationName) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + return n.azureIPConfigurationToSDPItem(&resp.InterfaceIPConfiguration, networkInterfaceName, scope) +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + NetworkNetworkInterfaceLookupByName, + NetworkNetworkInterfaceIPConfigurationLookupByName, + } +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: networkInterfaceName", + Scope: scope, + ItemType: n.Type(), + } + } + networkInterfaceName := queryParts[0] + + if networkInterfaceName == "" { + return nil, azureshared.QueryError(errors.New("networkInterfaceName cannot be empty"), scope, n.Type()) + } + + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + pager := n.client.List(ctx, rgScope.ResourceGroup, networkInterfaceName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + for _, ipConfig := range page.Value { + if ipConfig.Name == nil { + continue + } + + item, sdpErr := n.azureIPConfigurationToSDPItem(ipConfig, networkInterfaceName, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("SearchStream requires 1 query part: networkInterfaceName"), scope, n.Type())) + return + } + networkInterfaceName := queryParts[0] + + if networkInterfaceName == "" { + stream.SendError(azureshared.QueryError(errors.New("networkInterfaceName cannot be empty"), scope, n.Type())) + return + } + + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + + pager := n.client.List(ctx, rgScope.ResourceGroup, networkInterfaceName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, ipConfig := range page.Value { + if ipConfig.Name == nil { + continue + } + item, sdpErr := n.azureIPConfigurationToSDPItem(ipConfig, networkInterfaceName, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + NetworkNetworkInterfaceLookupByName, + }, + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/virtualnetwork/network-interface-ip-configurations/get +func (n networkNetworkInterfaceIPConfigurationWrapper) azureIPConfigurationToSDPItem(ipConfig *armnetwork.InterfaceIPConfiguration, networkInterfaceName, scope string) (*sdp.Item, *sdp.QueryError) { + if ipConfig.Name == nil { + return nil, azureshared.QueryError(errors.New("IP configuration name is nil"), scope, n.Type()) + } + + attributes, err := shared.ToAttributesWithExclude(ipConfig) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(networkInterfaceName, *ipConfig.Name)) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.NetworkNetworkInterfaceIPConfiguration.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + } + + // Health status based on provisioning state + if ipConfig.Properties != nil && ipConfig.Properties.ProvisioningState != nil { + switch *ipConfig.Properties.ProvisioningState { + case armnetwork.ProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armnetwork.ProvisioningStateUpdating, armnetwork.ProvisioningStateDeleting, armnetwork.ProvisioningStateCreating: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armnetwork.ProvisioningStateFailed, armnetwork.ProvisioningStateCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link back to parent NetworkInterface + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkNetworkInterface.String(), + Method: sdp.QueryMethod_GET, + Query: networkInterfaceName, + Scope: scope, + }, + }) + + if ipConfig.Properties != nil { + props := ipConfig.Properties + + // Subnet link + if props.Subnet != nil && props.Subnet.ID != nil { + subnetParams := azureshared.ExtractPathParamsFromResourceID(*props.Subnet.ID, []string{"virtualNetworks", "subnets"}) + if len(subnetParams) >= 2 { + vnetName, subnetName := subnetParams[0], subnetParams[1] + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*props.Subnet.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkSubnet.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(vnetName, subnetName), + Scope: linkedScope, + }, + }) + } + } + + // Public IP address link + if props.PublicIPAddress != nil && props.PublicIPAddress.ID != nil { + pipName := azureshared.ExtractResourceName(*props.PublicIPAddress.ID) + if pipName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*props.PublicIPAddress.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkPublicIPAddress.String(), + Method: sdp.QueryMethod_GET, + Query: pipName, + Scope: linkedScope, + }, + }) + } + } + + // Private IP address -> stdlib ip + if props.PrivateIPAddress != nil && *props.PrivateIPAddress != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *props.PrivateIPAddress, + Scope: "global", + }, + }) + } + + // Application security groups + if props.ApplicationSecurityGroups != nil { + for _, asg := range props.ApplicationSecurityGroups { + if asg != nil && asg.ID != nil { + asgName := azureshared.ExtractResourceName(*asg.ID) + if asgName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*asg.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkApplicationSecurityGroup.String(), + Method: sdp.QueryMethod_GET, + Query: asgName, + Scope: linkedScope, + }, + }) + } + } + } + } + + // Load balancer backend address pools + if props.LoadBalancerBackendAddressPools != nil { + for _, pool := range props.LoadBalancerBackendAddressPools { + if pool != nil && pool.ID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*pool.ID, []string{"loadBalancers", "backendAddressPools"}) + if len(params) >= 2 { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*pool.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkLoadBalancerBackendAddressPool.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1]), + Scope: linkedScope, + }, + }) + } + } + } + } + + // Load balancer inbound NAT rules + if props.LoadBalancerInboundNatRules != nil { + for _, rule := range props.LoadBalancerInboundNatRules { + if rule != nil && rule.ID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*rule.ID, []string{"loadBalancers", "inboundNatRules"}) + if len(params) >= 2 { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*rule.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkLoadBalancerInboundNatRule.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1]), + Scope: linkedScope, + }, + }) + } + } + } + } + + // Application gateway backend address pools + if props.ApplicationGatewayBackendAddressPools != nil { + for _, pool := range props.ApplicationGatewayBackendAddressPools { + if pool != nil && pool.ID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*pool.ID, []string{"applicationGateways", "backendAddressPools"}) + if len(params) >= 2 { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*pool.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkApplicationGatewayBackendAddressPool.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1]), + Scope: linkedScope, + }, + }) + } + } + } + } + + // Gateway load balancer (frontend IP config reference) + if props.GatewayLoadBalancer != nil && props.GatewayLoadBalancer.ID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*props.GatewayLoadBalancer.ID, []string{"loadBalancers", "frontendIPConfigurations"}) + if len(params) >= 2 { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*props.GatewayLoadBalancer.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkLoadBalancerFrontendIPConfiguration.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(params[0], params[1]), + Scope: linkedScope, + }, + }) + } + } + + // Virtual network taps + if props.VirtualNetworkTaps != nil { + for _, tap := range props.VirtualNetworkTaps { + if tap != nil && tap.ID != nil { + tapName := azureshared.ExtractResourceName(*tap.ID) + if tapName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*tap.ID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkVirtualNetworkTap.String(), + Method: sdp.QueryMethod_GET, + Query: tapName, + Scope: linkedScope, + }, + }) + } + } + } + } + + // PrivateLinkConnectionProperties - FQDNs + if props.PrivateLinkConnectionProperties != nil && props.PrivateLinkConnectionProperties.Fqdns != nil { + for _, fqdn := range props.PrivateLinkConnectionProperties.Fqdns { + if fqdn != nil && *fqdn != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *fqdn, + Scope: "global", + }, + }) + } + } + } + } + + return sdpItem, nil +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) PotentialLinks() map[shared.ItemType]bool { + return map[shared.ItemType]bool{ + azureshared.NetworkNetworkInterface: true, + azureshared.NetworkSubnet: true, + azureshared.NetworkPublicIPAddress: true, + azureshared.NetworkApplicationSecurityGroup: true, + azureshared.NetworkLoadBalancerBackendAddressPool: true, + azureshared.NetworkLoadBalancerInboundNatRule: true, + azureshared.NetworkApplicationGatewayBackendAddressPool: true, + azureshared.NetworkLoadBalancerFrontendIPConfiguration: true, + azureshared.NetworkVirtualNetworkTap: true, + stdlib.NetworkIP: true, + stdlib.NetworkDNS: true, + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftnetwork +func (n networkNetworkInterfaceIPConfigurationWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Network/networkInterfaces/ipConfigurations/read", + } +} + +func (n networkNetworkInterfaceIPConfigurationWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/network-network-interface-ip-configuration_test.go b/sources/azure/manual/network-network-interface-ip-configuration_test.go new file mode 100644 index 00000000..a516371b --- /dev/null +++ b/sources/azure/manual/network-network-interface-ip-configuration_test.go @@ -0,0 +1,605 @@ +package manual_test + +import ( + "context" + "errors" + "reflect" + "slices" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +// MockInterfaceIPConfigurationsPager is a simple mock for InterfaceIPConfigurationsPager +type MockInterfaceIPConfigurationsPager struct { + ctrl *gomock.Controller + recorder *MockInterfaceIPConfigurationsPagerMockRecorder +} + +type MockInterfaceIPConfigurationsPagerMockRecorder struct { + mock *MockInterfaceIPConfigurationsPager +} + +func NewMockInterfaceIPConfigurationsPager(ctrl *gomock.Controller) *MockInterfaceIPConfigurationsPager { + mock := &MockInterfaceIPConfigurationsPager{ctrl: ctrl} + mock.recorder = &MockInterfaceIPConfigurationsPagerMockRecorder{mock} + return mock +} + +func (m *MockInterfaceIPConfigurationsPager) EXPECT() *MockInterfaceIPConfigurationsPagerMockRecorder { + return m.recorder +} + +func (m *MockInterfaceIPConfigurationsPager) More() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "More") + ret0, _ := ret[0].(bool) + return ret0 +} + +func (mr *MockInterfaceIPConfigurationsPagerMockRecorder) More() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "More", reflect.TypeFor[func() bool]()) +} + +func (m *MockInterfaceIPConfigurationsPager) NextPage(ctx context.Context) (armnetwork.InterfaceIPConfigurationsClientListResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextPage", ctx) + ret0, _ := ret[0].(armnetwork.InterfaceIPConfigurationsClientListResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (mr *MockInterfaceIPConfigurationsPagerMockRecorder) NextPage(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextPage", reflect.TypeFor[func(ctx context.Context) (armnetwork.InterfaceIPConfigurationsClientListResponse, error)](), ctx) +} + +// testInterfaceIPConfigurationsClient wraps the mock to implement the correct interface +type testInterfaceIPConfigurationsClient struct { + *mocks.MockInterfaceIPConfigurationsClient + pager clients.InterfaceIPConfigurationsPager +} + +func (t *testInterfaceIPConfigurationsClient) List(ctx context.Context, resourceGroupName, networkInterfaceName string) clients.InterfaceIPConfigurationsPager { + return t.pager +} + +func TestNetworkNetworkInterfaceIPConfiguration(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + networkInterfaceName := "test-nic" + ipConfigName := "ipconfig1" + + t.Run("Get", func(t *testing.T) { + ipConfig := createAzureIPConfiguration(subscriptionID, resourceGroup, networkInterfaceName, ipConfigName) + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, ipConfigName).Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{ + InterfaceIPConfiguration: *ipConfig, + }, nil) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkNetworkInterfaceIPConfiguration.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkNetworkInterfaceIPConfiguration, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUniqueValue := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + if sdpItem.UniqueAttributeValue() != expectedUniqueValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueValue, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if sdpItem.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", sdpItem.Validate()) + } + + // Verify health status is OK for Succeeded provisioning state + if sdpItem.GetHealth() != sdp.Health_HEALTH_OK { + t.Errorf("Expected health OK, got %v", sdpItem.GetHealth()) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // Parent NetworkInterface link + ExpectedType: azureshared.NetworkNetworkInterface.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: networkInterfaceName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // Subnet link + ExpectedType: azureshared.NetworkSubnet.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey("test-vnet", "test-subnet"), + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // Public IP address link + ExpectedType: azureshared.NetworkPublicIPAddress.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-pip", + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // Private IP address link (stdlib) + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "10.0.0.4", + ExpectedScope: "global", + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Test with only network interface name (missing ipConfigName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], networkInterfaceName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyNetworkInterfaceName", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test directly on wrapper to get the QueryError + _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0], "", ipConfigName) + if qErr == nil { + t.Fatal("Expected error when providing empty network interface name, but got nil") + } + if qErr.GetErrorString() != "networkInterfaceName cannot be empty" { + t.Errorf("Expected specific error message, got: %s", qErr.GetErrorString()) + } + }) + + t.Run("GetWithEmptyIPConfigName", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test directly on wrapper to get the QueryError + _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0], networkInterfaceName, "") + if qErr == nil { + t.Fatal("Expected error when providing empty IP config name, but got nil") + } + if qErr.GetErrorString() != "ipConfigurationName cannot be empty" { + t.Errorf("Expected specific error message, got: %s", qErr.GetErrorString()) + } + }) + + t.Run("Search", func(t *testing.T) { + ipConfig1 := createAzureIPConfiguration(subscriptionID, resourceGroup, networkInterfaceName, "ipconfig1") + ipConfig2 := createAzureIPConfiguration(subscriptionID, resourceGroup, networkInterfaceName, "ipconfig2") + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockPager := NewMockInterfaceIPConfigurationsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.InterfaceIPConfigurationsClientListResponse{ + InterfaceIPConfigurationListResult: armnetwork.InterfaceIPConfigurationListResult{ + Value: []*armnetwork.InterfaceIPConfiguration{ipConfig1, ipConfig2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + testClient := &testInterfaceIPConfigurationsClient{ + MockInterfaceIPConfigurationsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], networkInterfaceName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + + if item.GetType() != azureshared.NetworkNetworkInterfaceIPConfiguration.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkNetworkInterfaceIPConfiguration, item.GetType()) + } + } + }) + + t.Run("SearchWithEmptyNetworkInterfaceName", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + testClient := &testInterfaceIPConfigurationsClient{MockInterfaceIPConfigurationsClient: mockClient} + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search directly with empty network interface name + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when providing empty network interface name, but got nil") + } + }) + + t.Run("SearchWithNoQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + testClient := &testInterfaceIPConfigurationsClient{MockInterfaceIPConfigurationsClient: mockClient} + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search directly with no query parts + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("Search_IPConfigWithNilName", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockPager := NewMockInterfaceIPConfigurationsPager(ctrl) + + ipConfigValid := createAzureIPConfiguration(subscriptionID, resourceGroup, networkInterfaceName, "ipconfig-valid") + ipConfigNilName := &armnetwork.InterfaceIPConfiguration{ + Name: nil, + } + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.InterfaceIPConfigurationsClientListResponse{ + InterfaceIPConfigurationListResult: armnetwork.InterfaceIPConfigurationListResult{ + Value: []*armnetwork.InterfaceIPConfiguration{ipConfigNilName, ipConfigValid}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + testClient := &testInterfaceIPConfigurationsClient{ + MockInterfaceIPConfigurationsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], networkInterfaceName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Should only return 1 item (the one with a valid name) + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item, got: %d", len(sdpItems)) + } + + expectedUniqueValue := shared.CompositeLookupKey(networkInterfaceName, "ipconfig-valid") + if sdpItems[0].UniqueAttributeValue() != expectedUniqueValue { + t.Errorf("Expected unique value %s, got %s", expectedUniqueValue, sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("IP configuration not found") + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, "nonexistent-ipconfig").Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{}, expectedErr) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, "nonexistent-ipconfig") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent IP configuration, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + expectedErr := errors.New("failed to list IP configurations") + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockPager := NewMockInterfaceIPConfigurationsPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.InterfaceIPConfigurationsClientListResponse{}, expectedErr), + ) + + testClient := &testInterfaceIPConfigurationsClient{ + MockInterfaceIPConfigurationsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + _, err := searchable.Search(ctx, wrapper.Scopes()[0], networkInterfaceName, true) + if err == nil { + t.Error("Expected error when listing IP configurations fails, but got nil") + } + }) + + t.Run("InterfaceCompliance", func(t *testing.T) { + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Cast to sources.Wrapper to access interface methods + w := wrapper.(sources.Wrapper) + + // Verify IAMPermissions + permissions := w.IAMPermissions() + if len(permissions) == 0 { + t.Error("Expected IAMPermissions to return at least one permission") + } + expectedPermission := "Microsoft.Network/networkInterfaces/ipConfigurations/read" + if !slices.Contains(permissions, expectedPermission) { + t.Errorf("Expected IAMPermissions to include %s, got %v", expectedPermission, permissions) + } + + // Verify PotentialLinks + potentialLinks := w.PotentialLinks() + if len(potentialLinks) == 0 { + t.Error("Expected PotentialLinks to return at least one link") + } + if !potentialLinks[azureshared.NetworkNetworkInterface] { + t.Error("Expected PotentialLinks to include NetworkNetworkInterface") + } + if !potentialLinks[azureshared.NetworkSubnet] { + t.Error("Expected PotentialLinks to include NetworkSubnet") + } + if !potentialLinks[stdlib.NetworkIP] { + t.Error("Expected PotentialLinks to include NetworkIP") + } + + // Verify SearchLookups + searchLookups := wrapper.SearchLookups() + if len(searchLookups) == 0 { + t.Error("Expected SearchLookups to return at least one lookup") + } + + // Verify GetLookups + getLookups := wrapper.GetLookups() + if len(getLookups) != 2 { + t.Errorf("Expected GetLookups to return 2 lookups (parent + child), got %d", len(getLookups)) + } + }) + + t.Run("HealthStatus_Pending", func(t *testing.T) { + ipConfig := createAzureIPConfigurationWithProvisioningState(subscriptionID, resourceGroup, networkInterfaceName, ipConfigName, armnetwork.ProvisioningStateUpdating) + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, ipConfigName).Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{ + InterfaceIPConfiguration: *ipConfig, + }, nil) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != sdp.Health_HEALTH_PENDING { + t.Errorf("Expected health PENDING, got %v", sdpItem.GetHealth()) + } + }) + + t.Run("HealthStatus_Error", func(t *testing.T) { + ipConfig := createAzureIPConfigurationWithProvisioningState(subscriptionID, resourceGroup, networkInterfaceName, ipConfigName, armnetwork.ProvisioningStateFailed) + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, ipConfigName).Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{ + InterfaceIPConfiguration: *ipConfig, + }, nil) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != sdp.Health_HEALTH_ERROR { + t.Errorf("Expected health ERROR, got %v", sdpItem.GetHealth()) + } + }) + + t.Run("GetWithApplicationSecurityGroups", func(t *testing.T) { + ipConfig := createAzureIPConfigurationWithASG(subscriptionID, resourceGroup, networkInterfaceName, ipConfigName, "test-asg") + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, ipConfigName).Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{ + InterfaceIPConfiguration: *ipConfig, + }, nil) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify ASG link exists among the linked queries + foundASG := false + for _, lq := range sdpItem.GetLinkedItemQueries() { + if lq.GetQuery().GetType() == azureshared.NetworkApplicationSecurityGroup.String() && + lq.GetQuery().GetMethod() == sdp.QueryMethod_GET && + lq.GetQuery().GetQuery() == "test-asg" && + lq.GetQuery().GetScope() == subscriptionID+"."+resourceGroup { + foundASG = true + break + } + } + if !foundASG { + t.Error("Expected to find ASG link in linked item queries") + } + }) + + t.Run("GetWithFQDNs", func(t *testing.T) { + ipConfig := createAzureIPConfigurationWithFQDNs(subscriptionID, resourceGroup, networkInterfaceName, ipConfigName, []string{"test.privatelink.blob.core.windows.net", "example.internal"}) + + mockClient := mocks.NewMockInterfaceIPConfigurationsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, networkInterfaceName, ipConfigName).Return( + armnetwork.InterfaceIPConfigurationsClientGetResponse{ + InterfaceIPConfiguration: *ipConfig, + }, nil) + + wrapper := manual.NewNetworkNetworkInterfaceIPConfiguration(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(networkInterfaceName, ipConfigName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify DNS links exist among the linked queries + expectedFQDNs := []string{"test.privatelink.blob.core.windows.net", "example.internal"} + for _, fqdn := range expectedFQDNs { + found := false + for _, lq := range sdpItem.GetLinkedItemQueries() { + if lq.GetQuery().GetType() == stdlib.NetworkDNS.String() && + lq.GetQuery().GetMethod() == sdp.QueryMethod_SEARCH && + lq.GetQuery().GetQuery() == fqdn && + lq.GetQuery().GetScope() == "global" { + found = true + break + } + } + if !found { + t.Errorf("Expected to find DNS link for FQDN %s in linked item queries", fqdn) + } + } + }) +} + +// createAzureIPConfiguration creates a mock Azure IP configuration for testing +func createAzureIPConfiguration(subscriptionID, resourceGroup, nicName, ipConfigName string) *armnetwork.InterfaceIPConfiguration { + subnetID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Network/virtualNetworks/test-vnet/subnets/test-subnet" + pipID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Network/publicIPAddresses/test-pip" + provisioningState := armnetwork.ProvisioningStateSucceeded + + return &armnetwork.InterfaceIPConfiguration{ + ID: new("/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Network/networkInterfaces/" + nicName + "/ipConfigurations/" + ipConfigName), + Name: new(ipConfigName), + Type: new("Microsoft.Network/networkInterfaces/ipConfigurations"), + Properties: &armnetwork.InterfaceIPConfigurationPropertiesFormat{ + ProvisioningState: &provisioningState, + PrivateIPAddress: new("10.0.0.4"), + PrivateIPAllocationMethod: new(armnetwork.IPAllocationMethodDynamic), + Primary: new(true), + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PublicIPAddress: &armnetwork.PublicIPAddress{ + ID: new(pipID), + }, + }, + } +} + +// createAzureIPConfigurationWithProvisioningState creates a mock IP config with a specific provisioning state +func createAzureIPConfigurationWithProvisioningState(subscriptionID, resourceGroup, nicName, ipConfigName string, state armnetwork.ProvisioningState) *armnetwork.InterfaceIPConfiguration { + ipConfig := createAzureIPConfiguration(subscriptionID, resourceGroup, nicName, ipConfigName) + ipConfig.Properties.ProvisioningState = &state + return ipConfig +} + +// createAzureIPConfigurationWithASG creates a mock IP config with application security groups +func createAzureIPConfigurationWithASG(subscriptionID, resourceGroup, nicName, ipConfigName, asgName string) *armnetwork.InterfaceIPConfiguration { + ipConfig := createAzureIPConfiguration(subscriptionID, resourceGroup, nicName, ipConfigName) + asgID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Network/applicationSecurityGroups/" + asgName + ipConfig.Properties.ApplicationSecurityGroups = []*armnetwork.ApplicationSecurityGroup{ + { + ID: new(asgID), + }, + } + return ipConfig +} + +// createAzureIPConfigurationWithFQDNs creates a mock IP config with PrivateLinkConnectionProperties FQDNs +func createAzureIPConfigurationWithFQDNs(subscriptionID, resourceGroup, nicName, ipConfigName string, fqdns []string) *armnetwork.InterfaceIPConfiguration { + ipConfig := createAzureIPConfiguration(subscriptionID, resourceGroup, nicName, ipConfigName) + fqdnPtrs := make([]*string, len(fqdns)) + for i := range fqdns { + fqdnPtrs[i] = new(fqdns[i]) + } + ipConfig.Properties.PrivateLinkConnectionProperties = &armnetwork.InterfaceIPConfigurationPrivateLinkConnectionProperties{ + Fqdns: fqdnPtrs, + } + return ipConfig +} diff --git a/sources/azure/manual/network-network-watcher.go b/sources/azure/manual/network-network-watcher.go new file mode 100644 index 00000000..95bd6a7d --- /dev/null +++ b/sources/azure/manual/network-network-watcher.go @@ -0,0 +1,185 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var NetworkNetworkWatcherLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkNetworkWatcher) + +type networkNetworkWatcherWrapper struct { + client clients.NetworkWatchersClient + + *azureshared.MultiResourceGroupBase +} + +// NewNetworkNetworkWatcher creates a new NetworkNetworkWatcher adapter (ListableWrapper: top-level resource). +func NewNetworkNetworkWatcher(client clients.NetworkWatchersClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { + return &networkNetworkWatcherWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + azureshared.NetworkNetworkWatcher, + ), + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/network-watcher/network-watchers/list +func (c networkNetworkWatcherWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, watcher := range page.Value { + if watcher.Name == nil { + continue + } + item, sdpErr := c.azureNetworkWatcherToSDPItem(watcher, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (c networkNetworkWatcherWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.NewListPager(rgScope.ResourceGroup, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + + for _, watcher := range page.Value { + if watcher.Name == nil { + continue + } + var sdpErr *sdp.QueryError + var item *sdp.Item + item, sdpErr = c.azureNetworkWatcherToSDPItem(watcher, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/network-watcher/network-watchers/get +func (c networkNetworkWatcherWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the network watcher name"), scope, c.Type()) + } + networkWatcherName := queryParts[0] + if networkWatcherName == "" { + return nil, azureshared.QueryError(errors.New("networkWatcherName cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + result, err := c.client.Get(ctx, rgScope.ResourceGroup, networkWatcherName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + return c.azureNetworkWatcherToSDPItem(&result.Watcher, scope) +} + +func (c networkNetworkWatcherWrapper) azureNetworkWatcherToSDPItem(watcher *armnetwork.Watcher, scope string) (*sdp.Item, *sdp.QueryError) { + if watcher.Name == nil { + return nil, azureshared.QueryError(errors.New("network watcher name is nil"), scope, c.Type()) + } + attributes, err := shared.ToAttributesWithExclude(watcher, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.NetworkNetworkWatcher.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(watcher.Tags), + } + + // Map provisioning state to health + if watcher.Properties != nil && watcher.Properties.ProvisioningState != nil { + switch *watcher.Properties.ProvisioningState { + case armnetwork.ProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armnetwork.ProvisioningStateUpdating, armnetwork.ProvisioningStateDeleting, armnetwork.ProvisioningStateCreating: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armnetwork.ProvisioningStateFailed, armnetwork.ProvisioningStateCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link to child FlowLogs via SEARCH + // FlowLogs are child resources of NetworkWatcher, so we link via SEARCH with the network watcher name + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkFlowLog.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *watcher.Name, + Scope: scope, + }, + }) + + return sdpItem, nil +} + +func (c networkNetworkWatcherWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + NetworkNetworkWatcherLookupByName, + } +} + +func (c networkNetworkWatcherWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.NetworkFlowLog, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftnetwork +func (c networkNetworkWatcherWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Network/networkWatchers/read", + } +} + +func (c networkNetworkWatcherWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/network-network-watcher_test.go b/sources/azure/manual/network-network-watcher_test.go new file mode 100644 index 00000000..5cc235fb --- /dev/null +++ b/sources/azure/manual/network-network-watcher_test.go @@ -0,0 +1,327 @@ +package manual_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +func TestNetworkNetworkWatcher(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + + t.Run("Get", func(t *testing.T) { + resourceName := "test-network-watcher" + resource := createNetworkWatcher(resourceName) + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, resourceName, nil).Return( + armnetwork.WatchersClientGetResponse{ + Watcher: *resource, + }, nil) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], resourceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkNetworkWatcher.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkNetworkWatcher, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != resourceName { + t.Errorf("Expected unique attribute value %s, got %s", resourceName, sdpItem.UniqueAttributeValue()) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: azureshared.NetworkFlowLog.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: resourceName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_ProvisioningStateSucceeded", func(t *testing.T) { + resourceName := "test-network-watcher-succeeded" + resource := createNetworkWatcherWithProvisioningState(resourceName, armnetwork.ProvisioningStateSucceeded) + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, resourceName, nil).Return( + armnetwork.WatchersClientGetResponse{ + Watcher: *resource, + }, nil) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], resourceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != sdp.Health_HEALTH_OK { + t.Errorf("Expected health HEALTH_OK, got %s", sdpItem.GetHealth()) + } + }) + + t.Run("Get_ProvisioningStateFailed", func(t *testing.T) { + resourceName := "test-network-watcher-failed" + resource := createNetworkWatcherWithProvisioningState(resourceName, armnetwork.ProvisioningStateFailed) + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, resourceName, nil).Return( + armnetwork.WatchersClientGetResponse{ + Watcher: *resource, + }, nil) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], resourceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != sdp.Health_HEALTH_ERROR { + t.Errorf("Expected health HEALTH_ERROR, got %s", sdpItem.GetHealth()) + } + }) + + t.Run("List", func(t *testing.T) { + resource1 := createNetworkWatcher("test-network-watcher-1") + resource2 := createNetworkWatcher("test-network-watcher-2") + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockPager := newMockNetworkWatchersPager(ctrl, []*armnetwork.Watcher{resource1, resource2}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + } + }) + + t.Run("List_SkipNilName", func(t *testing.T) { + resource1 := createNetworkWatcher("test-network-watcher-1") + resource2 := &armnetwork.Watcher{ + Name: nil, // nil name should be skipped + } + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockPager := newMockNetworkWatchersPager(ctrl, []*armnetwork.Watcher{resource1, resource2}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (skipping nil name), got: %d", len(sdpItems)) + } + }) + + t.Run("ListStream", func(t *testing.T) { + resource1 := createNetworkWatcher("test-network-watcher-1") + resource2 := createNetworkWatcher("test-network-watcher-2") + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockPager := newMockNetworkWatchersPager(ctrl, []*armnetwork.Watcher{resource1, resource2}) + + mockClient.EXPECT().NewListPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + listStreamable, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + listStreamable.ListStream(ctx, wrapper.Scopes()[0], true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + expectedErr := errors.New("resource not found") + + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent", nil).Return( + armnetwork.WatchersClientGetResponse{}, expectedErr) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent", true) + if qErr == nil { + t.Error("Expected error when getting non-existent resource, but got nil") + } + }) + + t.Run("GetWithEmptyName", func(t *testing.T) { + mockClient := mocks.NewMockNetworkWatchersClient(ctrl) + + wrapper := manual.NewNetworkNetworkWatcher(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) + if qErr == nil { + t.Error("Expected error when getting resource with empty name, but got nil") + } + }) +} + +func createNetworkWatcher(name string) *armnetwork.Watcher { + provisioningState := armnetwork.ProvisioningStateSucceeded + return &armnetwork.Watcher{ + ID: new(string), + Name: &name, + Type: new(string), + Location: new(string), + Tags: map[string]*string{ + "env": new(string), + }, + Properties: &armnetwork.WatcherPropertiesFormat{ + ProvisioningState: &provisioningState, + }, + } +} + +func createNetworkWatcherWithProvisioningState(name string, state armnetwork.ProvisioningState) *armnetwork.Watcher { + return &armnetwork.Watcher{ + ID: new(string), + Name: &name, + Type: new(string), + Location: new(string), + Tags: map[string]*string{ + "env": new(string), + }, + Properties: &armnetwork.WatcherPropertiesFormat{ + ProvisioningState: &state, + }, + } +} + +type mockNetworkWatchersPager struct { + ctrl *gomock.Controller + items []*armnetwork.Watcher + index int + more bool +} + +func newMockNetworkWatchersPager(ctrl *gomock.Controller, items []*armnetwork.Watcher) clients.NetworkWatchersPager { + return &mockNetworkWatchersPager{ + ctrl: ctrl, + items: items, + index: 0, + more: len(items) > 0, + } +} + +func (m *mockNetworkWatchersPager) More() bool { + return m.more +} + +func (m *mockNetworkWatchersPager) NextPage(ctx context.Context) (armnetwork.WatchersClientListResponse, error) { + if m.index >= len(m.items) { + m.more = false + return armnetwork.WatchersClientListResponse{ + WatcherListResult: armnetwork.WatcherListResult{ + Value: []*armnetwork.Watcher{}, + }, + }, nil + } + + item := m.items[m.index] + m.index++ + m.more = m.index < len(m.items) + + return armnetwork.WatchersClientListResponse{ + WatcherListResult: armnetwork.WatcherListResult{ + Value: []*armnetwork.Watcher{item}, + }, + }, nil +} diff --git a/sources/azure/manual/network-private-link-service.go b/sources/azure/manual/network-private-link-service.go new file mode 100644 index 00000000..0cd92611 --- /dev/null +++ b/sources/azure/manual/network-private-link-service.go @@ -0,0 +1,369 @@ +package manual + +import ( + "context" + "errors" + "strings" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +var NetworkPrivateLinkServiceLookupByName = shared.NewItemTypeLookup("name", azureshared.NetworkPrivateLinkService) + +type networkPrivateLinkServiceWrapper struct { + client clients.PrivateLinkServicesClient + + *azureshared.MultiResourceGroupBase +} + +func NewNetworkPrivateLinkService(client clients.PrivateLinkServicesClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { + return &networkPrivateLinkServiceWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_NETWORK, + azureshared.NetworkPrivateLinkService, + ), + } +} + +func (n networkPrivateLinkServiceWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + pager := n.client.List(rgScope.ResourceGroup) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + for _, pls := range page.Value { + if pls.Name == nil { + continue + } + item, sdpErr := n.azurePrivateLinkServiceToSDPItem(pls, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + return items, nil +} + +func (n networkPrivateLinkServiceWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + pager := n.client.List(rgScope.ResourceGroup) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, n.Type())) + return + } + for _, pls := range page.Value { + if pls.Name == nil { + continue + } + item, sdpErr := n.azurePrivateLinkServiceToSDPItem(pls, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (n networkPrivateLinkServiceWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) != 1 { + return nil, azureshared.QueryError(errors.New("query must be a private link service name"), scope, n.Type()) + } + serviceName := queryParts[0] + if serviceName == "" { + return nil, azureshared.QueryError(errors.New("private link service name cannot be empty"), scope, n.Type()) + } + + rgScope, err := n.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + resp, err := n.client.Get(ctx, rgScope.ResourceGroup, serviceName) + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + return n.azurePrivateLinkServiceToSDPItem(&resp.PrivateLinkService, scope) +} + +func (n networkPrivateLinkServiceWrapper) azurePrivateLinkServiceToSDPItem(pls *armnetwork.PrivateLinkService, scope string) (*sdp.Item, *sdp.QueryError) { + if pls.Name == nil { + return nil, azureshared.QueryError(errors.New("private link service name is nil"), scope, n.Type()) + } + attributes, err := shared.ToAttributesWithExclude(pls, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, n.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.NetworkPrivateLinkService.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(pls.Tags), + } + + // Health status from ProvisioningState + if pls.Properties != nil && pls.Properties.ProvisioningState != nil { + switch *pls.Properties.ProvisioningState { + case armnetwork.ProvisioningStateSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armnetwork.ProvisioningStateCreating, armnetwork.ProvisioningStateUpdating, armnetwork.ProvisioningStateDeleting: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armnetwork.ProvisioningStateFailed, armnetwork.ProvisioningStateCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link to Custom Location when ExtendedLocation.Name is a custom location resource ID + if pls.ExtendedLocation != nil && pls.ExtendedLocation.Name != nil { + customLocationID := *pls.ExtendedLocation.Name + if strings.Contains(customLocationID, "customLocations") { + customLocationName := azureshared.ExtractResourceName(customLocationID) + if customLocationName != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(customLocationID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.ExtendedLocationCustomLocation.String(), + Method: sdp.QueryMethod_GET, + Query: customLocationName, + Scope: linkedScope, + }, + }) + } + } + } + + if pls.Properties != nil { + // Link to IPConfigurations[].Properties.Subnet and PrivateIPAddress + if pls.Properties.IPConfigurations != nil { + for _, ipConfig := range pls.Properties.IPConfigurations { + if ipConfig == nil || ipConfig.Properties == nil { + continue + } + // Link to Subnet and VirtualNetwork + if ipConfig.Properties.Subnet != nil && ipConfig.Properties.Subnet.ID != nil { + subnetParams := azureshared.ExtractPathParamsFromResourceID(*ipConfig.Properties.Subnet.ID, []string{"virtualNetworks", "subnets"}) + if len(subnetParams) >= 2 { + vnetName, subnetName := subnetParams[0], subnetParams[1] + linkedScope := azureshared.ExtractScopeFromResourceID(*ipConfig.Properties.Subnet.ID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkSubnet.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(vnetName, subnetName), + Scope: linkedScope, + }, + }) + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkVirtualNetwork.String(), + Method: sdp.QueryMethod_GET, + Query: vnetName, + Scope: linkedScope, + }, + }) + } + } + // Link to PrivateIPAddress + if ipConfig.Properties.PrivateIPAddress != nil && *ipConfig.Properties.PrivateIPAddress != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *ipConfig.Properties.PrivateIPAddress, + Scope: "global", + }, + }) + } + } + } + + // Link to LoadBalancerFrontendIPConfigurations + if pls.Properties.LoadBalancerFrontendIPConfigurations != nil { + for _, lbFrontendIPConfig := range pls.Properties.LoadBalancerFrontendIPConfigurations { + if lbFrontendIPConfig == nil || lbFrontendIPConfig.ID == nil { + continue + } + params := azureshared.ExtractPathParamsFromResourceID(*lbFrontendIPConfig.ID, []string{"loadBalancers", "frontendIPConfigurations"}) + if len(params) >= 2 { + lbName, frontendIPConfigName := params[0], params[1] + linkedScope := azureshared.ExtractScopeFromResourceID(*lbFrontendIPConfig.ID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkLoadBalancerFrontendIPConfiguration.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(lbName, frontendIPConfigName), + Scope: linkedScope, + }, + }) + // Also link to the parent LoadBalancer + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkLoadBalancer.String(), + Method: sdp.QueryMethod_GET, + Query: lbName, + Scope: linkedScope, + }, + }) + } + } + } + + // Link to NetworkInterfaces (read-only array) + if pls.Properties.NetworkInterfaces != nil { + for _, iface := range pls.Properties.NetworkInterfaces { + if iface == nil || iface.ID == nil { + continue + } + nicName := azureshared.ExtractResourceName(*iface.ID) + if nicName != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(*iface.ID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkNetworkInterface.String(), + Method: sdp.QueryMethod_GET, + Query: nicName, + Scope: linkedScope, + }, + }) + } + } + } + + // Link to PrivateEndpointConnections[].PrivateEndpoint + if pls.Properties.PrivateEndpointConnections != nil { + for _, peConn := range pls.Properties.PrivateEndpointConnections { + if peConn == nil || peConn.Properties == nil || peConn.Properties.PrivateEndpoint == nil || peConn.Properties.PrivateEndpoint.ID == nil { + continue + } + peName := azureshared.ExtractResourceName(*peConn.Properties.PrivateEndpoint.ID) + if peName != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(*peConn.Properties.PrivateEndpoint.ID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.NetworkPrivateEndpoint.String(), + Method: sdp.QueryMethod_GET, + Query: peName, + Scope: linkedScope, + }, + }) + } + } + } + + // Link to Fqdns as DNS names + if pls.Properties.Fqdns != nil { + for _, fqdn := range pls.Properties.Fqdns { + if fqdn != nil && *fqdn != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *fqdn, + Scope: "global", + }, + }) + } + } + } + + // Link to DestinationIPAddress + if pls.Properties.DestinationIPAddress != nil && *pls.Properties.DestinationIPAddress != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkIP.String(), + Method: sdp.QueryMethod_GET, + Query: *pls.Properties.DestinationIPAddress, + Scope: "global", + }, + }) + } + + // Link to Alias (read-only DNS-resolvable name for the private link service) + if pls.Properties.Alias != nil && *pls.Properties.Alias != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: stdlib.NetworkDNS.String(), + Method: sdp.QueryMethod_SEARCH, + Query: *pls.Properties.Alias, + Scope: "global", + }, + }) + } + } + + return sdpItem, nil +} + +func (n networkPrivateLinkServiceWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + NetworkPrivateLinkServiceLookupByName, + } +} + +func (n networkPrivateLinkServiceWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.NetworkSubnet, + azureshared.NetworkVirtualNetwork, + azureshared.NetworkLoadBalancerFrontendIPConfiguration, + azureshared.NetworkLoadBalancer, + azureshared.NetworkNetworkInterface, + azureshared.NetworkPrivateEndpoint, + azureshared.ExtendedLocationCustomLocation, + stdlib.NetworkIP, + stdlib.NetworkDNS, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftnetwork +func (n networkPrivateLinkServiceWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Network/privateLinkServices/read", + } +} + +func (n networkPrivateLinkServiceWrapper) PredefinedRole() string { + return "Network Contributor" +} diff --git a/sources/azure/manual/network-private-link-service_test.go b/sources/azure/manual/network-private-link-service_test.go new file mode 100644 index 00000000..829654e7 --- /dev/null +++ b/sources/azure/manual/network-private-link-service_test.go @@ -0,0 +1,456 @@ +package manual_test + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" + "github.com/overmindtech/cli/sources/stdlib" +) + +func TestNetworkPrivateLinkService(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + + t.Run("Get", func(t *testing.T) { + plsName := "test-pls" + pls := createAzurePrivateLinkService(plsName, subscriptionID, resourceGroup) + + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, plsName).Return( + armnetwork.PrivateLinkServicesClientGetResponse{ + PrivateLinkService: *pls, + }, nil) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], plsName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.NetworkPrivateLinkService.String() { + t.Errorf("Expected type %s, got %s", azureshared.NetworkPrivateLinkService, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != plsName { + t.Errorf("Expected unique attribute value %s, got %s", plsName, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetTags()["env"] != "test" { + t.Errorf("Expected tag 'env=test', got: %v", sdpItem.GetTags()["env"]) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + ExpectedType: azureshared.NetworkSubnet.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey("test-vnet", "test-subnet"), + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: azureshared.NetworkVirtualNetwork.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-vnet", + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "10.0.0.100", + ExpectedScope: "global", + }, + { + ExpectedType: azureshared.NetworkLoadBalancerFrontendIPConfiguration.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey("test-lb", "test-frontend-ip"), + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: azureshared.NetworkLoadBalancer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-lb", + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: azureshared.NetworkNetworkInterface.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-nic", + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: azureshared.NetworkPrivateEndpoint.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-pe", + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + { + ExpectedType: stdlib.NetworkDNS.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: "pls.example.com", + ExpectedScope: "global", + }, + { + ExpectedType: stdlib.NetworkIP.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "10.0.0.200", + ExpectedScope: "global", + }, + { + ExpectedType: stdlib.NetworkDNS.String(), + ExpectedMethod: sdp.QueryMethod_SEARCH, + ExpectedQuery: "test-pls.abc123.westus2.azure.privatelinkservice", + ExpectedScope: "global", + }, + { + ExpectedType: azureshared.ExtendedLocationCustomLocation.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-custom-location", + ExpectedScope: fmt.Sprintf("%s.%s", subscriptionID, resourceGroup), + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_EmptyName", func(t *testing.T) { + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) + if qErr == nil { + t.Error("Expected error when getting private link service with empty name, but got nil") + } + }) + + t.Run("List", func(t *testing.T) { + pls1 := createAzurePrivateLinkService("test-pls-1", subscriptionID, resourceGroup) + pls2 := createAzurePrivateLinkService("test-pls-2", subscriptionID, resourceGroup) + + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + mockPager := NewMockPrivateLinkServicesPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.PrivateLinkServicesClientListResponse{ + PrivateLinkServiceListResult: armnetwork.PrivateLinkServiceListResult{ + Value: []*armnetwork.PrivateLinkService{pls1, pls2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().List(resourceGroup).Return(mockPager) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + if item.GetType() != azureshared.NetworkPrivateLinkService.String() { + t.Fatalf("Expected type %s, got: %s", azureshared.NetworkPrivateLinkService, item.GetType()) + } + } + }) + + t.Run("ListStream", func(t *testing.T) { + pls1 := createAzurePrivateLinkService("test-pls-1", subscriptionID, resourceGroup) + pls2 := createAzurePrivateLinkService("test-pls-2", subscriptionID, resourceGroup) + + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + mockPager := NewMockPrivateLinkServicesPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.PrivateLinkServicesClientListResponse{ + PrivateLinkServiceListResult: armnetwork.PrivateLinkServiceListResult{ + Value: []*armnetwork.PrivateLinkService{pls1, pls2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().List(resourceGroup).Return(mockPager) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + var errs []error + + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + listStreamable, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + listStreamable.ListStream(ctx, wrapper.Scopes()[0], true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("List_WithNilName", func(t *testing.T) { + pls1 := createAzurePrivateLinkService("test-pls-1", subscriptionID, resourceGroup) + pls2 := &armnetwork.PrivateLinkService{ + Name: nil, + Location: new("eastus"), + Tags: map[string]*string{"env": new("test")}, + Properties: &armnetwork.PrivateLinkServiceProperties{ + ProvisioningState: to.Ptr(armnetwork.ProvisioningStateSucceeded), + }, + } + + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + mockPager := NewMockPrivateLinkServicesPager(ctrl) + + gomock.InOrder( + mockPager.EXPECT().More().Return(true), + mockPager.EXPECT().NextPage(ctx).Return( + armnetwork.PrivateLinkServicesClientListResponse{ + PrivateLinkServiceListResult: armnetwork.PrivateLinkServiceListResult{ + Value: []*armnetwork.PrivateLinkService{pls1, pls2}, + }, + }, nil), + mockPager.EXPECT().More().Return(false), + ) + + mockClient.EXPECT().List(resourceGroup).Return(mockPager) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name skipped), got: %d", len(sdpItems)) + } + if sdpItems[0].UniqueAttributeValue() != "test-pls-1" { + t.Errorf("Expected item name 'test-pls-1', got: %s", sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + expectedErr := errors.New("private link service not found") + + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-pls").Return( + armnetwork.PrivateLinkServicesClientGetResponse{}, expectedErr) + + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-pls", true) + if qErr == nil { + t.Fatal("Expected error when getting nonexistent private link service, got nil") + } + }) + + t.Run("PotentialLinks", func(t *testing.T) { + mockClient := mocks.NewMockPrivateLinkServicesClient(ctrl) + wrapper := manual.NewNetworkPrivateLinkService(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + w := wrapper.(sources.Wrapper) + potentialLinks := w.PotentialLinks() + if len(potentialLinks) == 0 { + t.Error("Expected PotentialLinks to return at least one link type") + } + if !potentialLinks[azureshared.NetworkSubnet] { + t.Error("Expected PotentialLinks to include NetworkSubnet") + } + if !potentialLinks[azureshared.NetworkVirtualNetwork] { + t.Error("Expected PotentialLinks to include NetworkVirtualNetwork") + } + if !potentialLinks[azureshared.NetworkLoadBalancer] { + t.Error("Expected PotentialLinks to include NetworkLoadBalancer") + } + if !potentialLinks[azureshared.NetworkLoadBalancerFrontendIPConfiguration] { + t.Error("Expected PotentialLinks to include NetworkLoadBalancerFrontendIPConfiguration") + } + if !potentialLinks[azureshared.NetworkNetworkInterface] { + t.Error("Expected PotentialLinks to include NetworkNetworkInterface") + } + if !potentialLinks[azureshared.NetworkPrivateEndpoint] { + t.Error("Expected PotentialLinks to include NetworkPrivateEndpoint") + } + if !potentialLinks[stdlib.NetworkIP] { + t.Error("Expected PotentialLinks to include stdlib.NetworkIP") + } + if !potentialLinks[stdlib.NetworkDNS] { + t.Error("Expected PotentialLinks to include stdlib.NetworkDNS") + } + }) +} + +// MockPrivateLinkServicesPager is a mock for PrivateLinkServicesPager +type MockPrivateLinkServicesPager struct { + ctrl *gomock.Controller + recorder *MockPrivateLinkServicesPagerMockRecorder +} + +type MockPrivateLinkServicesPagerMockRecorder struct { + mock *MockPrivateLinkServicesPager +} + +func NewMockPrivateLinkServicesPager(ctrl *gomock.Controller) *MockPrivateLinkServicesPager { + mock := &MockPrivateLinkServicesPager{ctrl: ctrl} + mock.recorder = &MockPrivateLinkServicesPagerMockRecorder{mock} + return mock +} + +func (m *MockPrivateLinkServicesPager) EXPECT() *MockPrivateLinkServicesPagerMockRecorder { + return m.recorder +} + +func (m *MockPrivateLinkServicesPager) More() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "More") + ret0, _ := ret[0].(bool) + return ret0 +} + +func (mr *MockPrivateLinkServicesPagerMockRecorder) More() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "More", reflect.TypeFor[func() bool]()) +} + +func (m *MockPrivateLinkServicesPager) NextPage(ctx context.Context) (armnetwork.PrivateLinkServicesClientListResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NextPage", ctx) + ret0, _ := ret[0].(armnetwork.PrivateLinkServicesClientListResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +func (mr *MockPrivateLinkServicesPagerMockRecorder) NextPage(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NextPage", reflect.TypeFor[func(ctx context.Context) (armnetwork.PrivateLinkServicesClientListResponse, error)](), ctx) +} + +func createAzurePrivateLinkService(plsName, subscriptionID, resourceGroup string) *armnetwork.PrivateLinkService { + subnetID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/virtualNetworks/test-vnet/subnets/test-subnet", subscriptionID, resourceGroup) + lbFrontendIPID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/loadBalancers/test-lb/frontendIPConfigurations/test-frontend-ip", subscriptionID, resourceGroup) + nicID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/networkInterfaces/test-nic", subscriptionID, resourceGroup) + peID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.Network/privateEndpoints/test-pe", subscriptionID, resourceGroup) + customLocationID := fmt.Sprintf("/subscriptions/%s/resourceGroups/%s/providers/Microsoft.ExtendedLocation/customLocations/test-custom-location", subscriptionID, resourceGroup) + + return &armnetwork.PrivateLinkService{ + Name: new(plsName), + Location: new("eastus"), + ExtendedLocation: &armnetwork.ExtendedLocation{ + Name: new(customLocationID), + }, + Tags: map[string]*string{ + "env": new("test"), + "project": new("testing"), + }, + Properties: &armnetwork.PrivateLinkServiceProperties{ + ProvisioningState: to.Ptr(armnetwork.ProvisioningStateSucceeded), + IPConfigurations: []*armnetwork.PrivateLinkServiceIPConfiguration{ + { + Properties: &armnetwork.PrivateLinkServiceIPConfigurationProperties{ + Subnet: &armnetwork.Subnet{ + ID: new(subnetID), + }, + PrivateIPAddress: new("10.0.0.100"), + }, + }, + }, + LoadBalancerFrontendIPConfigurations: []*armnetwork.FrontendIPConfiguration{ + { + ID: new(lbFrontendIPID), + }, + }, + NetworkInterfaces: []*armnetwork.Interface{ + { + ID: new(nicID), + }, + }, + PrivateEndpointConnections: []*armnetwork.PrivateEndpointConnection{ + { + Properties: &armnetwork.PrivateEndpointConnectionProperties{ + PrivateEndpoint: &armnetwork.PrivateEndpoint{ + ID: new(peID), + }, + }, + }, + }, + Fqdns: []*string{ + new("pls.example.com"), + }, + DestinationIPAddress: new("10.0.0.200"), + Alias: new("test-pls.abc123.westus2.azure.privatelinkservice"), + }, + } +} diff --git a/sources/azure/manual/operational-insights-workspace.go b/sources/azure/manual/operational-insights-workspace.go new file mode 100644 index 00000000..0df39092 --- /dev/null +++ b/sources/azure/manual/operational-insights-workspace.go @@ -0,0 +1,224 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var OperationalInsightsWorkspaceLookupByName = shared.NewItemTypeLookup("name", azureshared.OperationalInsightsWorkspace) + +type operationalInsightsWorkspaceWrapper struct { + client clients.OperationalInsightsWorkspaceClient + + *azureshared.MultiResourceGroupBase +} + +func NewOperationalInsightsWorkspace(client clients.OperationalInsightsWorkspaceClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.ListableWrapper { + return &operationalInsightsWorkspaceWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_OBSERVABILITY, + azureshared.OperationalInsightsWorkspace, + ), + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/loganalytics/workspaces/list-by-resource-group +func (c operationalInsightsWorkspaceWrapper) List(ctx context.Context, scope string) ([]*sdp.Item, *sdp.QueryError) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, workspace := range page.Value { + if workspace.Name == nil { + continue + } + item, sdpErr := c.azureWorkspaceToSDPItem(workspace, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (c operationalInsightsWorkspaceWrapper) ListStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string) { + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.NewListByResourceGroupPager(rgScope.ResourceGroup, nil) + + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + + for _, workspace := range page.Value { + if workspace.Name == nil { + continue + } + var sdpErr *sdp.QueryError + var item *sdp.Item + item, sdpErr = c.azureWorkspaceToSDPItem(workspace, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +// ref: https://learn.microsoft.com/en-us/rest/api/loganalytics/workspaces/get +func (c operationalInsightsWorkspaceWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, azureshared.QueryError(errors.New("queryParts must be at least 1 and be the workspace name"), scope, c.Type()) + } + workspaceName := queryParts[0] + if workspaceName == "" { + return nil, azureshared.QueryError(errors.New("workspaceName cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + result, err := c.client.Get(ctx, rgScope.ResourceGroup, workspaceName, nil) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + return c.azureWorkspaceToSDPItem(&result.Workspace, scope) +} + +func (c operationalInsightsWorkspaceWrapper) azureWorkspaceToSDPItem(workspace *armoperationalinsights.Workspace, scope string) (*sdp.Item, *sdp.QueryError) { + if workspace.Name == nil { + return nil, azureshared.QueryError(errors.New("workspace name is nil"), scope, c.Type()) + } + attributes, err := shared.ToAttributesWithExclude(workspace, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.OperationalInsightsWorkspace.String(), + UniqueAttribute: "name", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(workspace.Tags), + } + + // Health status mapping based on provisioning state + if workspace.Properties != nil && workspace.Properties.ProvisioningState != nil { + switch *workspace.Properties.ProvisioningState { + case armoperationalinsights.WorkspaceEntityStatusSucceeded: + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + case armoperationalinsights.WorkspaceEntityStatusCreating, + armoperationalinsights.WorkspaceEntityStatusUpdating, + armoperationalinsights.WorkspaceEntityStatusDeleting, + armoperationalinsights.WorkspaceEntityStatusProvisioningAccount: + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case armoperationalinsights.WorkspaceEntityStatusFailed, + armoperationalinsights.WorkspaceEntityStatusCanceled: + sdpItem.Health = sdp.Health_HEALTH_ERROR.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link to Private Link Scope Scoped Resources + // PrivateLinkScopedResources[].ResourceID refers to Azure Monitor Private Link Scope + // scoped resources (microsoft.insights/privateLinkScopes/scopedResources) + if workspace.Properties != nil && workspace.Properties.PrivateLinkScopedResources != nil { + for _, plsr := range workspace.Properties.PrivateLinkScopedResources { + if plsr != nil && plsr.ResourceID != nil { + params := azureshared.ExtractPathParamsFromResourceID(*plsr.ResourceID, []string{"privateLinkScopes", "scopedResources"}) + if len(params) >= 2 && params[0] != "" && params[1] != "" { + scopeName, scopedResourceName := params[0], params[1] + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*plsr.ResourceID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.InsightsPrivateLinkScopeScopedResource.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(scopeName, scopedResourceName), + Scope: linkedScope, + }, + }) + } + } + } + } + + // Link to Cluster (Dedicated Log Analytics cluster) + if workspace.Properties != nil && workspace.Properties.Features != nil && workspace.Properties.Features.ClusterResourceID != nil { + clusterName := azureshared.ExtractResourceName(*workspace.Properties.Features.ClusterResourceID) + if clusterName != "" { + linkedScope := scope + if extractedScope := azureshared.ExtractScopeFromResourceID(*workspace.Properties.Features.ClusterResourceID); extractedScope != "" { + linkedScope = extractedScope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.OperationalInsightsCluster.String(), + Method: sdp.QueryMethod_GET, + Query: clusterName, + Scope: linkedScope, + }, + }) + } + } + + return sdpItem, nil +} + +func (c operationalInsightsWorkspaceWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + OperationalInsightsWorkspaceLookupByName, + } +} + +func (c operationalInsightsWorkspaceWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.InsightsPrivateLinkScopeScopedResource, + azureshared.OperationalInsightsCluster, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftoperationalinsights +func (c operationalInsightsWorkspaceWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.OperationalInsights/workspaces/read", + } +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/built-in-roles/monitor#log-analytics-reader +func (c operationalInsightsWorkspaceWrapper) PredefinedRole() string { + return "Log Analytics Reader" +} diff --git a/sources/azure/manual/operational-insights-workspace_test.go b/sources/azure/manual/operational-insights-workspace_test.go new file mode 100644 index 00000000..84b4e0c7 --- /dev/null +++ b/sources/azure/manual/operational-insights-workspace_test.go @@ -0,0 +1,491 @@ +package manual_test + +import ( + "context" + "errors" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +func TestOperationalInsightsWorkspace(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + + t.Run("Get", func(t *testing.T) { + workspaceName := "test-workspace" + workspace := createAzureWorkspace(workspaceName, subscriptionID, resourceGroup) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, workspaceName, nil).Return( + armoperationalinsights.WorkspacesClientGetResponse{ + Workspace: *workspace, + }, nil) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], workspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.OperationalInsightsWorkspace.String() { + t.Errorf("Expected type %s, got %s", azureshared.OperationalInsightsWorkspace, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "name" { + t.Errorf("Expected unique attribute 'name', got %s", sdpItem.GetUniqueAttribute()) + } + + if sdpItem.UniqueAttributeValue() != workspaceName { + t.Errorf("Expected unique attribute value %s, got %s", workspaceName, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetTags()["env"] != "test" { + t.Errorf("Expected tag 'env=test', got: %v", sdpItem.GetTags()["env"]) + } + + // Verify health status based on provisioning state + if sdpItem.GetHealth() != sdp.Health_HEALTH_OK { + t.Errorf("Expected health OK, got %s", sdpItem.GetHealth()) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // Properties.PrivateLinkScopedResources[0].ResourceID + ExpectedType: azureshared.InsightsPrivateLinkScopeScopedResource.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey("test-pls", "test-scoped-resource"), + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // Properties.Features.ClusterResourceID + ExpectedType: azureshared.OperationalInsightsCluster.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "test-cluster", + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithCrossResourceGroupLinks", func(t *testing.T) { + workspaceName := "test-workspace-cross-rg" + workspace := createAzureWorkspaceWithCrossResourceGroupLinks(workspaceName, subscriptionID) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, workspaceName, nil).Return( + armoperationalinsights.WorkspacesClientGetResponse{ + Workspace: *workspace, + }, nil) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], workspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + // Verify that links use the correct scope from different resource groups + foundClusterLink := false + foundPLSScopedResourceLink := false + expectedScope := subscriptionID + ".other-rg" + + for _, link := range sdpItem.GetLinkedItemQueries() { + if link.GetQuery().GetType() == azureshared.OperationalInsightsCluster.String() { + foundClusterLink = true + if link.GetQuery().GetScope() != expectedScope { + t.Errorf("Expected Cluster scope %s, got %s", expectedScope, link.GetQuery().GetScope()) + } + } + if link.GetQuery().GetType() == azureshared.InsightsPrivateLinkScopeScopedResource.String() { + foundPLSScopedResourceLink = true + if link.GetQuery().GetScope() != expectedScope { + t.Errorf("Expected Private Link Scope Scoped Resource scope %s, got %s", expectedScope, link.GetQuery().GetScope()) + } + expectedQuery := shared.CompositeLookupKey("test-pls-cross", "test-scoped-resource-cross") + if link.GetQuery().GetQuery() != expectedQuery { + t.Errorf("Expected query %s, got %s", expectedQuery, link.GetQuery().GetQuery()) + } + } + } + + if !foundClusterLink { + t.Error("Expected to find Operational Insights Cluster link") + } + if !foundPLSScopedResourceLink { + t.Error("Expected to find Private Link Scope Scoped Resource link") + } + }) + + t.Run("GetWithoutLinks", func(t *testing.T) { + workspaceName := "test-workspace-no-links" + workspace := createAzureWorkspaceWithoutLinks(workspaceName) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, workspaceName, nil).Return( + armoperationalinsights.WorkspacesClientGetResponse{ + Workspace: *workspace, + }, nil) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], workspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if len(sdpItem.GetLinkedItemQueries()) != 0 { + t.Errorf("Expected no linked queries, got %d", len(sdpItem.GetLinkedItemQueries())) + } + }) + + t.Run("GetWithDifferentHealthStates", func(t *testing.T) { + healthTests := []struct { + state armoperationalinsights.WorkspaceEntityStatus + expectedHealth sdp.Health + }{ + {armoperationalinsights.WorkspaceEntityStatusSucceeded, sdp.Health_HEALTH_OK}, + {armoperationalinsights.WorkspaceEntityStatusCreating, sdp.Health_HEALTH_PENDING}, + {armoperationalinsights.WorkspaceEntityStatusUpdating, sdp.Health_HEALTH_PENDING}, + {armoperationalinsights.WorkspaceEntityStatusDeleting, sdp.Health_HEALTH_PENDING}, + {armoperationalinsights.WorkspaceEntityStatusProvisioningAccount, sdp.Health_HEALTH_PENDING}, + {armoperationalinsights.WorkspaceEntityStatusFailed, sdp.Health_HEALTH_ERROR}, + {armoperationalinsights.WorkspaceEntityStatusCanceled, sdp.Health_HEALTH_ERROR}, + } + + for _, ht := range healthTests { + t.Run(string(ht.state), func(t *testing.T) { + workspaceName := "test-workspace-" + string(ht.state) + workspace := createAzureWorkspaceWithProvisioningState(workspaceName, ht.state) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, workspaceName, nil).Return( + armoperationalinsights.WorkspacesClientGetResponse{ + Workspace: *workspace, + }, nil) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], workspaceName, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetHealth() != ht.expectedHealth { + t.Errorf("Expected health %s for state %s, got %s", ht.expectedHealth, ht.state, sdpItem.GetHealth()) + } + }) + } + }) + + t.Run("List", func(t *testing.T) { + workspace1 := createAzureWorkspace("test-workspace-1", subscriptionID, resourceGroup) + workspace2 := createAzureWorkspace("test-workspace-2", subscriptionID, resourceGroup) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockPager := newMockOperationalInsightsWorkspacePager(ctrl, []*armoperationalinsights.Workspace{workspace1, workspace2}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if item.Validate() != nil { + t.Fatalf("Expected no validation error, got: %v", item.Validate()) + } + + if item.GetTags()["env"] != "test" { + t.Fatalf("Expected tag 'env=test', got: %s", item.GetTags()["env"]) + } + } + }) + + t.Run("ListStream", func(t *testing.T) { + workspace1 := createAzureWorkspace("test-workspace-1", subscriptionID, resourceGroup) + workspace2 := createAzureWorkspace("test-workspace-2", subscriptionID, resourceGroup) + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockPager := newMockOperationalInsightsWorkspacePager(ctrl, []*armoperationalinsights.Workspace{workspace1, workspace2}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + listStreamable, ok := adapter.(discovery.ListStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support ListStream operation") + } + + listStreamable.ListStream(ctx, wrapper.Scopes()[0], true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + + // Verify adapter doesn't support SearchStream + _, ok = adapter.(discovery.SearchStreamableAdapter) + if ok { + t.Fatalf("Adapter should not support SearchStream operation") + } + }) + + t.Run("ListWithNilName", func(t *testing.T) { + workspace1 := createAzureWorkspace("test-workspace-1", subscriptionID, resourceGroup) + workspaceNilName := &armoperationalinsights.Workspace{ + Name: nil, // nil name should be skipped + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + } + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockPager := newMockOperationalInsightsWorkspacePager(ctrl, []*armoperationalinsights.Workspace{workspace1, workspaceNilName}) + + mockClient.EXPECT().NewListByResourceGroupPager(resourceGroup, nil).Return(mockPager) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + listable, ok := adapter.(discovery.ListableAdapter) + if !ok { + t.Fatalf("Adapter does not support List operation") + } + + sdpItems, err := listable.List(ctx, wrapper.Scopes()[0], true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Should only return 1 item (the one with a name) + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name skipped), got: %d", len(sdpItems)) + } + }) + + t.Run("ErrorHandling", func(t *testing.T) { + expectedErr := errors.New("workspace not found") + + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, "nonexistent-workspace", nil).Return( + armoperationalinsights.WorkspacesClientGetResponse{}, expectedErr) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "nonexistent-workspace", true) + if qErr == nil { + t.Error("Expected error when getting non-existent workspace, but got nil") + } + }) + + t.Run("GetWithEmptyName", func(t *testing.T) { + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], "", true) + if qErr == nil { + t.Error("Expected error when getting workspace with empty name, but got nil") + } + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockOperationalInsightsWorkspaceClient(ctrl) + + wrapper := manual.NewOperationalInsightsWorkspace(mockClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + // Test the wrapper's Get method directly with insufficient query parts + _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when getting workspace with insufficient query parts, but got nil") + } + }) +} + +// createAzureWorkspace creates a mock Azure Log Analytics Workspace for testing +func createAzureWorkspace(workspaceName, subscriptionID, resourceGroup string) *armoperationalinsights.Workspace { + succeededState := armoperationalinsights.WorkspaceEntityStatusSucceeded + retentionDays := int32(30) + return &armoperationalinsights.Workspace{ + Name: new(workspaceName), + Location: new("eastus"), + ID: new("/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.OperationalInsights/workspaces/" + workspaceName), + Type: new("Microsoft.OperationalInsights/workspaces"), + Tags: map[string]*string{ + "env": new("test"), + "project": new("testing"), + }, + Properties: &armoperationalinsights.WorkspaceProperties{ + ProvisioningState: &succeededState, + RetentionInDays: &retentionDays, + Features: &armoperationalinsights.WorkspaceFeatures{ + ClusterResourceID: new("/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.OperationalInsights/clusters/test-cluster"), + }, + PrivateLinkScopedResources: []*armoperationalinsights.PrivateLinkScopedResource{ + { + // Note: ResourceID refers to microsoft.insights/privateLinkScopes/scopedResources + ResourceID: new("/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/microsoft.insights/privateLinkScopes/test-pls/scopedResources/test-scoped-resource"), + ScopeID: new("test-scope-id"), + }, + }, + }, + } +} + +// createAzureWorkspaceWithCrossResourceGroupLinks creates a mock Workspace with links to resources in different resource groups +func createAzureWorkspaceWithCrossResourceGroupLinks(workspaceName, subscriptionID string) *armoperationalinsights.Workspace { + succeededState := armoperationalinsights.WorkspaceEntityStatusSucceeded + return &armoperationalinsights.Workspace{ + Name: new(workspaceName), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + Properties: &armoperationalinsights.WorkspaceProperties{ + ProvisioningState: &succeededState, + Features: &armoperationalinsights.WorkspaceFeatures{ + ClusterResourceID: new("/subscriptions/" + subscriptionID + "/resourceGroups/other-rg/providers/Microsoft.OperationalInsights/clusters/test-cluster-cross-rg"), + }, + PrivateLinkScopedResources: []*armoperationalinsights.PrivateLinkScopedResource{ + { + ResourceID: new("/subscriptions/" + subscriptionID + "/resourceGroups/other-rg/providers/microsoft.insights/privateLinkScopes/test-pls-cross/scopedResources/test-scoped-resource-cross"), + ScopeID: new("test-scope-id"), + }, + }, + }, + } +} + +// createAzureWorkspaceWithoutLinks creates a mock Workspace without any linked resources +func createAzureWorkspaceWithoutLinks(workspaceName string) *armoperationalinsights.Workspace { + succeededState := armoperationalinsights.WorkspaceEntityStatusSucceeded + return &armoperationalinsights.Workspace{ + Name: new(workspaceName), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + Properties: &armoperationalinsights.WorkspaceProperties{ + ProvisioningState: &succeededState, + // No PrivateLinkScopedResources + }, + } +} + +// createAzureWorkspaceWithProvisioningState creates a mock Workspace with a specific provisioning state +func createAzureWorkspaceWithProvisioningState(workspaceName string, state armoperationalinsights.WorkspaceEntityStatus) *armoperationalinsights.Workspace { + return &armoperationalinsights.Workspace{ + Name: new(workspaceName), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + Properties: &armoperationalinsights.WorkspaceProperties{ + ProvisioningState: &state, + }, + } +} + +// mockOperationalInsightsWorkspacePager is a simple mock implementation of the Pager interface for testing +type mockOperationalInsightsWorkspacePager struct { + ctrl *gomock.Controller + items []*armoperationalinsights.Workspace + index int + more bool +} + +func newMockOperationalInsightsWorkspacePager(ctrl *gomock.Controller, items []*armoperationalinsights.Workspace) clients.OperationalInsightsWorkspacePager { + return &mockOperationalInsightsWorkspacePager{ + ctrl: ctrl, + items: items, + index: 0, + more: len(items) > 0, + } +} + +func (m *mockOperationalInsightsWorkspacePager) More() bool { + return m.more +} + +func (m *mockOperationalInsightsWorkspacePager) NextPage(ctx context.Context) (armoperationalinsights.WorkspacesClientListByResourceGroupResponse, error) { + if m.index >= len(m.items) { + m.more = false + return armoperationalinsights.WorkspacesClientListByResourceGroupResponse{ + WorkspaceListResult: armoperationalinsights.WorkspaceListResult{ + Value: []*armoperationalinsights.Workspace{}, + }, + }, nil + } + + item := m.items[m.index] + m.index++ + m.more = m.index < len(m.items) + + return armoperationalinsights.WorkspacesClientListByResourceGroupResponse{ + WorkspaceListResult: armoperationalinsights.WorkspaceListResult{ + Value: []*armoperationalinsights.Workspace{item}, + }, + }, nil +} diff --git a/sources/azure/manual/sql-server-failover-group.go b/sources/azure/manual/sql-server-failover-group.go new file mode 100644 index 00000000..b189c779 --- /dev/null +++ b/sources/azure/manual/sql-server-failover-group.go @@ -0,0 +1,313 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var SQLServerFailoverGroupLookupByName = shared.NewItemTypeLookup("name", azureshared.SQLServerFailoverGroup) + +type sqlServerFailoverGroupWrapper struct { + client clients.SqlFailoverGroupsClient + + *azureshared.MultiResourceGroupBase +} + +func NewSqlServerFailoverGroup(client clients.SqlFailoverGroupsClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &sqlServerFailoverGroupWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, + azureshared.SQLServerFailoverGroup, + ), + } +} + +// Get retrieves a specific failover group by server name and failover group name +// ref: https://learn.microsoft.com/en-us/rest/api/sql/failover-groups/get +func (c sqlServerFailoverGroupWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: serverName and failoverGroupName", + Scope: scope, + ItemType: c.Type(), + } + } + serverName := queryParts[0] + if serverName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "serverName cannot be empty", + Scope: scope, + ItemType: c.Type(), + } + } + failoverGroupName := queryParts[1] + if failoverGroupName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "failoverGroupName cannot be empty", + Scope: scope, + ItemType: c.Type(), + } + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + resp, err := c.client.Get(ctx, rgScope.ResourceGroup, serverName, failoverGroupName) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + return c.azureFailoverGroupToSDPItem(&resp.FailoverGroup, serverName, scope) +} + +// Search retrieves all failover groups for a given server +// ref: https://learn.microsoft.com/en-us/rest/api/sql/failover-groups/list-by-server +func (c sqlServerFailoverGroupWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: serverName", + Scope: scope, + ItemType: c.Type(), + } + } + serverName := queryParts[0] + if serverName == "" { + return nil, azureshared.QueryError(errors.New("serverName cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, failoverGroup := range page.Value { + if failoverGroup.Name == nil { + continue + } + item, sdpErr := c.azureFailoverGroupToSDPItem(failoverGroup, serverName, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +// SearchStream streams all failover groups for a given server +func (c sqlServerFailoverGroupWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: serverName"), scope, c.Type())) + return + } + serverName := queryParts[0] + if serverName == "" { + stream.SendError(azureshared.QueryError(errors.New("serverName cannot be empty"), scope, c.Type())) + return + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.ListByServer(ctx, rgScope.ResourceGroup, serverName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + for _, failoverGroup := range page.Value { + if failoverGroup.Name == nil { + continue + } + item, sdpErr := c.azureFailoverGroupToSDPItem(failoverGroup, serverName, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (c sqlServerFailoverGroupWrapper) azureFailoverGroupToSDPItem(failoverGroup *armsql.FailoverGroup, serverName, scope string) (*sdp.Item, *sdp.QueryError) { + if failoverGroup.Name == nil { + return nil, azureshared.QueryError(errors.New("failover group name is nil"), scope, c.Type()) + } + failoverGroupName := *failoverGroup.Name + + attributes, err := shared.ToAttributesWithExclude(failoverGroup, "tags") + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(serverName, failoverGroupName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.SQLServerFailoverGroup.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + Tags: azureshared.ConvertAzureTags(failoverGroup.Tags), + } + + // Health mapping based on replication state + if failoverGroup.Properties != nil && failoverGroup.Properties.ReplicationState != nil { + switch *failoverGroup.Properties.ReplicationState { + case "CATCH_UP", "PENDING", "SEEDING": + sdpItem.Health = sdp.Health_HEALTH_PENDING.Enum() + case "SUSPENDED": + sdpItem.Health = sdp.Health_HEALTH_WARNING.Enum() + case "": + sdpItem.Health = sdp.Health_HEALTH_OK.Enum() + default: + sdpItem.Health = sdp.Health_HEALTH_UNKNOWN.Enum() + } + } + + // Link back to the parent SQL Server + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.SQLServer.String(), + Method: sdp.QueryMethod_GET, + Query: serverName, + Scope: scope, + }, + }) + + if failoverGroup.Properties != nil { + // Link to partner servers + if failoverGroup.Properties.PartnerServers != nil { + for _, partner := range failoverGroup.Properties.PartnerServers { + if partner != nil && partner.ID != nil && *partner.ID != "" { + partnerServerName := azureshared.ExtractResourceName(*partner.ID) + if partnerServerName != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(*partner.ID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.SQLServer.String(), + Method: sdp.QueryMethod_GET, + Query: partnerServerName, + Scope: linkedScope, + }, + }) + } + } + } + } + + // Link to databases in the failover group + if failoverGroup.Properties.Databases != nil { + for _, databaseID := range failoverGroup.Properties.Databases { + if databaseID != nil && *databaseID != "" { + // Extract server name and database name from the database resource ID + params := azureshared.ExtractPathParamsFromResourceID(*databaseID, []string{"servers", "databases"}) + if len(params) >= 2 { + dbServerName := params[0] + dbName := params[1] + linkedScope := azureshared.ExtractScopeFromResourceID(*databaseID) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.SQLDatabase.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(dbServerName, dbName), + Scope: linkedScope, + }, + }) + } + } + } + } + + // Link to read-only endpoint target server if specified + if failoverGroup.Properties.ReadOnlyEndpoint != nil && failoverGroup.Properties.ReadOnlyEndpoint.TargetServer != nil && *failoverGroup.Properties.ReadOnlyEndpoint.TargetServer != "" { + // TargetServer is a resource ID + targetServerName := azureshared.ExtractResourceName(*failoverGroup.Properties.ReadOnlyEndpoint.TargetServer) + if targetServerName != "" { + linkedScope := azureshared.ExtractScopeFromResourceID(*failoverGroup.Properties.ReadOnlyEndpoint.TargetServer) + if linkedScope == "" { + linkedScope = scope + } + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.SQLServer.String(), + Method: sdp.QueryMethod_GET, + Query: targetServerName, + Scope: linkedScope, + }, + }) + } + } + } + + return sdpItem, nil +} + +func (c sqlServerFailoverGroupWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + SQLServerLookupByName, + SQLServerFailoverGroupLookupByName, + } +} + +func (c sqlServerFailoverGroupWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + SQLServerLookupByName, + }, + } +} + +func (c sqlServerFailoverGroupWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.SQLServer, + azureshared.SQLDatabase, + ) +} + +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftsql +func (c sqlServerFailoverGroupWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Sql/servers/failoverGroups/read", + } +} + +func (c sqlServerFailoverGroupWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/sql-server-failover-group_test.go b/sources/azure/manual/sql-server-failover-group_test.go new file mode 100644 index 00000000..68f0eeb9 --- /dev/null +++ b/sources/azure/manual/sql-server-failover-group_test.go @@ -0,0 +1,470 @@ +package manual_test + +import ( + "context" + "errors" + "slices" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +// mockSqlFailoverGroupsPager is a simple mock implementation of SqlFailoverGroupsPager +type mockSqlFailoverGroupsPager struct { + pages []armsql.FailoverGroupsClientListByServerResponse + index int +} + +func (m *mockSqlFailoverGroupsPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockSqlFailoverGroupsPager) NextPage(ctx context.Context) (armsql.FailoverGroupsClientListByServerResponse, error) { + if m.index >= len(m.pages) { + return armsql.FailoverGroupsClientListByServerResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +// errorSqlFailoverGroupsPager is a mock pager that always returns an error +type errorSqlFailoverGroupsPager struct{} + +func (e *errorSqlFailoverGroupsPager) More() bool { + return true +} + +func (e *errorSqlFailoverGroupsPager) NextPage(ctx context.Context) (armsql.FailoverGroupsClientListByServerResponse, error) { + return armsql.FailoverGroupsClientListByServerResponse{}, errors.New("pager error") +} + +// testSqlFailoverGroupsClient wraps the mock to implement the correct interface +type testSqlFailoverGroupsClient struct { + *mocks.MockSqlFailoverGroupsClient + pager clients.SqlFailoverGroupsPager +} + +func (t *testSqlFailoverGroupsClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.SqlFailoverGroupsPager { + return t.pager +} + +func TestSqlServerFailoverGroup(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + serverName := "test-server" + failoverGroupName := "test-failover-group" + + t.Run("Get", func(t *testing.T) { + failoverGroup := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, failoverGroupName) + + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, failoverGroupName).Return( + armsql.FailoverGroupsClientGetResponse{ + FailoverGroup: *failoverGroup, + }, nil) + + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, failoverGroupName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.SQLServerFailoverGroup.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerFailoverGroup, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(serverName, failoverGroupName) + if sdpItem.UniqueAttributeValue() != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // SQLServer link (parent) + ExpectedType: azureshared.SQLServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: serverName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // Partner server link + ExpectedType: azureshared.SQLServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: "partner-server", + ExpectedScope: subscriptionID + ".partner-rg", + }, + { + // Database link + ExpectedType: azureshared.SQLDatabase.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey(serverName, "test-database"), + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_WithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Only provide serverName without failoverGroupName + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Provide empty server name and valid failover group name + // Call wrapper.Get directly to get *sdp.QueryError + _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0], "", failoverGroupName) + if qErr == nil { + t.Fatal("Expected error when serverName is empty, but got nil") + } + if qErr.GetErrorString() != "serverName cannot be empty" { + t.Errorf("Expected error string 'serverName cannot be empty', got: %s", qErr.GetErrorString()) + } + }) + + t.Run("GetWithEmptyFailoverGroupName", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Provide valid server name and empty failover group name + // Call wrapper.Get directly to get *sdp.QueryError + _, qErr := wrapper.Get(ctx, wrapper.Scopes()[0], serverName, "") + if qErr == nil { + t.Fatal("Expected error when failoverGroupName is empty, but got nil") + } + if qErr.GetErrorString() != "failoverGroupName cannot be empty" { + t.Errorf("Expected error string 'failoverGroupName cannot be empty', got: %s", qErr.GetErrorString()) + } + }) + + t.Run("Search", func(t *testing.T) { + failoverGroup1 := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, "failover-group-1") + failoverGroup2 := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, "failover-group-2") + + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + mockPager := &mockSqlFailoverGroupsPager{ + pages: []armsql.FailoverGroupsClientListByServerResponse{ + { + FailoverGroupListResult: armsql.FailoverGroupListResult{ + Value: []*armsql.FailoverGroup{failoverGroup1, failoverGroup2}, + }, + }, + }, + } + + testClient := &testSqlFailoverGroupsClient{ + MockSqlFailoverGroupsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + if item.GetType() != azureshared.SQLServerFailoverGroup.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerFailoverGroup, item.GetType()) + } + } + }) + + t.Run("SearchStream", func(t *testing.T) { + failoverGroup1 := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, "failover-group-1") + failoverGroup2 := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, "failover-group-2") + + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + mockPager := &mockSqlFailoverGroupsPager{ + pages: []armsql.FailoverGroupsClientListByServerResponse{ + { + FailoverGroupListResult: armsql.FailoverGroupListResult{ + Value: []*armsql.FailoverGroup{failoverGroup1, failoverGroup2}, + }, + }, + }, + } + + testClient := &testSqlFailoverGroupsClient{ + MockSqlFailoverGroupsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], serverName, true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("SearchWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search directly with empty server name + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when serverName is empty, but got nil") + } + }) + + t.Run("Search_InvalidQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search directly with no query parts + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("Search_WithNilName", func(t *testing.T) { + failoverGroup1 := createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, "failover-group-1") + failoverGroup2 := &armsql.FailoverGroup{ + Name: nil, // FailoverGroup with nil name should be skipped + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Sql/servers/test-server/failoverGroups/failover-group-2"), + Tags: map[string]*string{ + "env": new("test"), + }, + } + + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + mockPager := &mockSqlFailoverGroupsPager{ + pages: []armsql.FailoverGroupsClientListByServerResponse{ + { + FailoverGroupListResult: armsql.FailoverGroupListResult{ + Value: []*armsql.FailoverGroup{failoverGroup1, failoverGroup2}, + }, + }, + }, + } + + testClient := &testSqlFailoverGroupsClient{ + MockSqlFailoverGroupsClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Should only return 1 item (failover group with nil name is skipped) + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name filtered out), got: %d", len(sdpItems)) + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("failover group not found") + + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, "nonexistent-failover-group").Return( + armsql.FailoverGroupsClientGetResponse{}, expectedErr) + + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "nonexistent-failover-group") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent failover group, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + errorPager := &errorSqlFailoverGroupsPager{} + + testClient := &testSqlFailoverGroupsClient{ + MockSqlFailoverGroupsClient: mockClient, + pager: errorPager, + } + + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + _, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err == nil { + t.Error("Expected error from pager when NextPage returns an error, but got nil") + } + }) + + t.Run("InterfaceCompliance", func(t *testing.T) { + mockClient := mocks.NewMockSqlFailoverGroupsClient(ctrl) + testClient := &testSqlFailoverGroupsClient{MockSqlFailoverGroupsClient: mockClient} + wrapper := manual.NewSqlServerFailoverGroup(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + w := wrapper.(sources.Wrapper) + + // Verify IAMPermissions + permissions := w.IAMPermissions() + if len(permissions) == 0 { + t.Error("Expected IAMPermissions to return at least one permission") + } + expectedPermission := "Microsoft.Sql/servers/failoverGroups/read" + found := slices.Contains(permissions, expectedPermission) + if !found { + t.Errorf("Expected IAMPermissions to include %s", expectedPermission) + } + + // Verify PotentialLinks + potentialLinks := w.PotentialLinks() + if len(potentialLinks) == 0 { + t.Error("Expected PotentialLinks to return at least one link") + } + if !potentialLinks[azureshared.SQLServer] { + t.Error("Expected PotentialLinks to include SQLServer") + } + if !potentialLinks[azureshared.SQLDatabase] { + t.Error("Expected PotentialLinks to include SQLDatabase") + } + }) +} + +// createAzureSqlServerFailoverGroup creates a mock Azure SQL Server Failover Group for testing +func createAzureSqlServerFailoverGroup(subscriptionID, resourceGroup, serverName, failoverGroupName string) *armsql.FailoverGroup { + failoverGroupID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Sql/servers/" + serverName + "/failoverGroups/" + failoverGroupName + partnerServerID := "/subscriptions/" + subscriptionID + "/resourceGroups/partner-rg/providers/Microsoft.Sql/servers/partner-server" + databaseID := "/subscriptions/" + subscriptionID + "/resourceGroups/" + resourceGroup + "/providers/Microsoft.Sql/servers/" + serverName + "/databases/test-database" + + replicationState := "" + + return &armsql.FailoverGroup{ + Name: new(failoverGroupName), + Location: new("eastus"), + Tags: map[string]*string{ + "env": new("test"), + }, + ID: new(failoverGroupID), + Properties: &armsql.FailoverGroupProperties{ + ReplicationState: &replicationState, + PartnerServers: []*armsql.PartnerInfo{ + { + ID: new(partnerServerID), + Location: new("westus"), + }, + }, + Databases: []*string{ + new(databaseID), + }, + ReadWriteEndpoint: &armsql.FailoverGroupReadWriteEndpoint{ + FailoverPolicy: new(armsql.ReadWriteEndpointFailoverPolicyAutomatic), + }, + }, + } +} diff --git a/sources/azure/manual/sql-server-key.go b/sources/azure/manual/sql-server-key.go new file mode 100644 index 00000000..40f0fd00 --- /dev/null +++ b/sources/azure/manual/sql-server-key.go @@ -0,0 +1,249 @@ +package manual + +import ( + "context" + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/shared" +) + +var SQLServerKeyLookupByName = shared.NewItemTypeLookup("name", azureshared.SQLServerKey) + +type sqlServerKeyWrapper struct { + client clients.SqlServerKeysClient + + *azureshared.MultiResourceGroupBase +} + +func NewSqlServerKey(client clients.SqlServerKeysClient, resourceGroupScopes []azureshared.ResourceGroupScope) sources.SearchableWrapper { + return &sqlServerKeyWrapper{ + client: client, + MultiResourceGroupBase: azureshared.NewMultiResourceGroupBase( + resourceGroupScopes, + sdp.AdapterCategory_ADAPTER_CATEGORY_DATABASE, + azureshared.SQLServerKey, + ), + } +} + +// Get retrieves a single SQL Server Key by serverName and keyName +// ref: https://learn.microsoft.com/en-us/rest/api/sql/server-keys/get +func (c sqlServerKeyWrapper) Get(ctx context.Context, scope string, queryParts ...string) (*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 2 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Get requires 2 query parts: serverName and keyName", + Scope: scope, + ItemType: c.Type(), + } + } + serverName := queryParts[0] + keyName := queryParts[1] + + if serverName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "serverName cannot be empty", + Scope: scope, + ItemType: c.Type(), + } + } + if keyName == "" { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "keyName cannot be empty", + Scope: scope, + ItemType: c.Type(), + } + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + resp, err := c.client.Get(ctx, rgScope.ResourceGroup, serverName, keyName) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + return c.azureSqlServerKeyToSDPItem(&resp.ServerKey, serverName, scope) +} + +func (c sqlServerKeyWrapper) azureSqlServerKeyToSDPItem(serverKey *armsql.ServerKey, serverName, scope string) (*sdp.Item, *sdp.QueryError) { + if serverKey.Name == nil { + return nil, azureshared.QueryError(errors.New("server key name is nil"), scope, c.Type()) + } + keyName := *serverKey.Name + + attributes, err := shared.ToAttributesWithExclude(serverKey) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + err = attributes.Set("uniqueAttr", shared.CompositeLookupKey(serverName, keyName)) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + + sdpItem := &sdp.Item{ + Type: azureshared.SQLServerKey.String(), + UniqueAttribute: "uniqueAttr", + Attributes: attributes, + Scope: scope, + } + + // Link back to parent SQL Server + if serverName != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.SQLServer.String(), + Method: sdp.QueryMethod_GET, + Query: serverName, + Scope: scope, + }, + }) + } + + // Link to Key Vault Key if this is an Azure Key Vault type key + // The URI field contains the Key Vault key URI for AzureKeyVault server key types + // URI format: https://{vaultName}.vault.azure.net/keys/{keyName}/{version} + if serverKey.Properties != nil && serverKey.Properties.URI != nil && *serverKey.Properties.URI != "" { + keyURI := *serverKey.Properties.URI + vaultName := azureshared.ExtractVaultNameFromURI(keyURI) + keyVaultKeyName := azureshared.ExtractKeyNameFromURI(keyURI) + if vaultName != "" && keyVaultKeyName != "" { + sdpItem.LinkedItemQueries = append(sdpItem.LinkedItemQueries, &sdp.LinkedItemQuery{ + Query: &sdp.Query{ + Type: azureshared.KeyVaultKey.String(), + Method: sdp.QueryMethod_GET, + Query: shared.CompositeLookupKey(vaultName, keyVaultKeyName), + Scope: scope, + }, + }) + } + } + + return sdpItem, nil +} + +// Search retrieves all SQL Server Keys for a given server +// ref: https://learn.microsoft.com/en-us/rest/api/sql/server-keys/list-by-server +func (c sqlServerKeyWrapper) Search(ctx context.Context, scope string, queryParts ...string) ([]*sdp.Item, *sdp.QueryError) { + if len(queryParts) < 1 { + return nil, &sdp.QueryError{ + ErrorType: sdp.QueryError_OTHER, + ErrorString: "Search requires 1 query part: serverName", + Scope: scope, + ItemType: c.Type(), + } + } + serverName := queryParts[0] + if serverName == "" { + return nil, azureshared.QueryError(errors.New("serverName cannot be empty"), scope, c.Type()) + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + pager := c.client.NewListByServerPager(rgScope.ResourceGroup, serverName) + + var items []*sdp.Item + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + return nil, azureshared.QueryError(err, scope, c.Type()) + } + for _, serverKey := range page.Value { + if serverKey.Name == nil { + continue + } + item, sdpErr := c.azureSqlServerKeyToSDPItem(serverKey, serverName, scope) + if sdpErr != nil { + return nil, sdpErr + } + items = append(items, item) + } + } + + return items, nil +} + +func (c sqlServerKeyWrapper) SearchStream(ctx context.Context, stream discovery.QueryResultStream, cache sdpcache.Cache, cacheKey sdpcache.CacheKey, scope string, queryParts ...string) { + if len(queryParts) < 1 { + stream.SendError(azureshared.QueryError(errors.New("Search requires 1 query part: serverName"), scope, c.Type())) + return + } + serverName := queryParts[0] + if serverName == "" { + stream.SendError(azureshared.QueryError(errors.New("serverName cannot be empty"), scope, c.Type())) + return + } + + rgScope, err := c.ResourceGroupScopeFromScope(scope) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + pager := c.client.NewListByServerPager(rgScope.ResourceGroup, serverName) + for pager.More() { + page, err := pager.NextPage(ctx) + if err != nil { + stream.SendError(azureshared.QueryError(err, scope, c.Type())) + return + } + for _, serverKey := range page.Value { + if serverKey.Name == nil { + continue + } + item, sdpErr := c.azureSqlServerKeyToSDPItem(serverKey, serverName, scope) + if sdpErr != nil { + stream.SendError(sdpErr) + continue + } + cache.StoreItem(ctx, item, shared.DefaultCacheDuration, cacheKey) + stream.SendItem(item) + } + } +} + +func (c sqlServerKeyWrapper) GetLookups() sources.ItemTypeLookups { + return sources.ItemTypeLookups{ + SQLServerLookupByName, + SQLServerKeyLookupByName, + } +} + +func (c sqlServerKeyWrapper) SearchLookups() []sources.ItemTypeLookups { + return []sources.ItemTypeLookups{ + { + SQLServerLookupByName, + }, + } +} + +func (c sqlServerKeyWrapper) PotentialLinks() map[shared.ItemType]bool { + return shared.NewItemTypesSet( + azureshared.SQLServer, + azureshared.KeyVaultKey, + ) +} + +// IAMPermissions returns the required Azure RBAC permissions for reading SQL Server Keys +// ref: https://learn.microsoft.com/en-us/azure/role-based-access-control/resource-provider-operations#microsoftsql +func (c sqlServerKeyWrapper) IAMPermissions() []string { + return []string{ + "Microsoft.Sql/servers/keys/read", + } +} + +func (c sqlServerKeyWrapper) PredefinedRole() string { + return "Reader" +} diff --git a/sources/azure/manual/sql-server-key_test.go b/sources/azure/manual/sql-server-key_test.go new file mode 100644 index 00000000..0d9337dc --- /dev/null +++ b/sources/azure/manual/sql-server-key_test.go @@ -0,0 +1,504 @@ +package manual_test + +import ( + "context" + "errors" + "slices" + "sync" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + "go.uber.org/mock/gomock" + + "github.com/overmindtech/cli/go/discovery" + "github.com/overmindtech/cli/go/sdp-go" + "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/sources" + "github.com/overmindtech/cli/sources/azure/clients" + "github.com/overmindtech/cli/sources/azure/manual" + azureshared "github.com/overmindtech/cli/sources/azure/shared" + "github.com/overmindtech/cli/sources/azure/shared/mocks" + "github.com/overmindtech/cli/sources/shared" +) + +// mockSqlServerKeysPager is a simple mock implementation of SqlServerKeysPager +type mockSqlServerKeysPager struct { + pages []armsql.ServerKeysClientListByServerResponse + index int +} + +func (m *mockSqlServerKeysPager) More() bool { + return m.index < len(m.pages) +} + +func (m *mockSqlServerKeysPager) NextPage(ctx context.Context) (armsql.ServerKeysClientListByServerResponse, error) { + if m.index >= len(m.pages) { + return armsql.ServerKeysClientListByServerResponse{}, errors.New("no more pages") + } + page := m.pages[m.index] + m.index++ + return page, nil +} + +// errorSqlServerKeysPager is a mock pager that always returns an error +type errorSqlServerKeysPager struct{} + +func (e *errorSqlServerKeysPager) More() bool { + return true +} + +func (e *errorSqlServerKeysPager) NextPage(ctx context.Context) (armsql.ServerKeysClientListByServerResponse, error) { + return armsql.ServerKeysClientListByServerResponse{}, errors.New("pager error") +} + +// testSqlServerKeysClient wraps the mock to implement the correct interface +type testSqlServerKeysClient struct { + *mocks.MockSqlServerKeysClient + pager clients.SqlServerKeysPager +} + +func (t *testSqlServerKeysClient) NewListByServerPager(resourceGroupName, serverName string) clients.SqlServerKeysPager { + return t.pager +} + +func TestSqlServerKey(t *testing.T) { + ctx := context.Background() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + subscriptionID := "test-subscription" + resourceGroup := "test-rg" + serverName := "test-server" + keyName := "test-key" + + t.Run("Get", func(t *testing.T) { + serverKey := createAzureSqlServerKey(serverName, keyName, "") + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, keyName).Return( + armsql.ServerKeysClientGetResponse{ + ServerKey: *serverKey, + }, nil) + + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Get requires serverName and keyName as query parts + query := shared.CompositeLookupKey(serverName, keyName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + if sdpItem.GetType() != azureshared.SQLServerKey.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerKey, sdpItem.GetType()) + } + + if sdpItem.GetUniqueAttribute() != "uniqueAttr" { + t.Errorf("Expected unique attribute 'uniqueAttr', got %s", sdpItem.GetUniqueAttribute()) + } + + expectedUniqueAttrValue := shared.CompositeLookupKey(serverName, keyName) + if sdpItem.UniqueAttributeValue() != expectedUniqueAttrValue { + t.Errorf("Expected unique attribute value %s, got %s", expectedUniqueAttrValue, sdpItem.UniqueAttributeValue()) + } + + if sdpItem.GetScope() != subscriptionID+"."+resourceGroup { + t.Errorf("Expected scope %s, got %s", subscriptionID+"."+resourceGroup, sdpItem.GetScope()) + } + + // Validate the item + if err := sdpItem.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // SQLServer link (parent) + ExpectedType: azureshared.SQLServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: serverName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("Get_WithKeyVaultKey", func(t *testing.T) { + keyVaultKeyURI := "https://my-vault.vault.azure.net/keys/my-key/12345" + serverKey := createAzureSqlServerKey(serverName, keyName, keyVaultKeyURI) + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, keyName).Return( + armsql.ServerKeysClientGetResponse{ + ServerKey: *serverKey, + }, nil) + + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, keyName) + sdpItem, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr != nil { + t.Fatalf("Expected no error, got: %v", qErr) + } + + t.Run("StaticTests", func(t *testing.T) { + queryTests := shared.QueryTests{ + { + // SQLServer link (parent) + ExpectedType: azureshared.SQLServer.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: serverName, + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + { + // KeyVaultKey link + ExpectedType: azureshared.KeyVaultKey.String(), + ExpectedMethod: sdp.QueryMethod_GET, + ExpectedQuery: shared.CompositeLookupKey("my-vault", "my-key"), + ExpectedScope: subscriptionID + "." + resourceGroup, + }, + } + + shared.RunStaticTests(t, adapter, sdpItem, queryTests) + }) + }) + + t.Run("GetWithInsufficientQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Test with insufficient query parts (only server name) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], serverName, true) + if qErr == nil { + t.Error("Expected error when providing insufficient query parts, but got nil") + } + }) + + t.Run("GetWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Test with empty server name + query := shared.CompositeLookupKey("", keyName) + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when providing empty server name, but got nil") + } + }) + + t.Run("GetWithEmptyKeyName", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + // Test with empty key name + query := shared.CompositeLookupKey(serverName, "") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when providing empty key name, but got nil") + } + }) + + t.Run("Search", func(t *testing.T) { + serverKey1 := createAzureSqlServerKey(serverName, "key-1", "") + serverKey2 := createAzureSqlServerKey(serverName, "key-2", "") + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockPager := &mockSqlServerKeysPager{ + pages: []armsql.ServerKeysClientListByServerResponse{ + { + ServerKeyListResult: armsql.ServerKeyListResult{ + Value: []*armsql.ServerKey{serverKey1, serverKey2}, + }, + }, + }, + } + + testClient := &testSqlServerKeysClient{ + MockSqlServerKeysClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + if len(sdpItems) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(sdpItems)) + } + + for _, item := range sdpItems { + if err := item.Validate(); err != nil { + t.Fatalf("Expected no validation error, got: %v", err) + } + + if item.GetType() != azureshared.SQLServerKey.String() { + t.Errorf("Expected type %s, got %s", azureshared.SQLServerKey, item.GetType()) + } + } + }) + + t.Run("SearchStream", func(t *testing.T) { + serverKey1 := createAzureSqlServerKey(serverName, "key-1", "") + serverKey2 := createAzureSqlServerKey(serverName, "key-2", "") + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockPager := &mockSqlServerKeysPager{ + pages: []armsql.ServerKeysClientListByServerResponse{ + { + ServerKeyListResult: armsql.ServerKeyListResult{ + Value: []*armsql.ServerKey{serverKey1, serverKey2}, + }, + }, + }, + } + + testClient := &testSqlServerKeysClient{ + MockSqlServerKeysClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + wg := &sync.WaitGroup{} + wg.Add(2) + + var items []*sdp.Item + mockItemHandler := func(item *sdp.Item) { + items = append(items, item) + wg.Done() + } + + var errs []error + mockErrorHandler := func(err error) { + errs = append(errs, err) + } + + stream := discovery.NewQueryResultStream(mockItemHandler, mockErrorHandler) + + searchStreamable, ok := adapter.(discovery.SearchStreamableAdapter) + if !ok { + t.Fatalf("Adapter does not support SearchStream operation") + } + + searchStreamable.SearchStream(ctx, wrapper.Scopes()[0], serverName, true, stream) + wg.Wait() + + if len(errs) != 0 { + t.Fatalf("Expected no errors, got: %v", errs) + } + + if len(items) != 2 { + t.Fatalf("Expected 2 items, got: %d", len(items)) + } + }) + + t.Run("Search_WithNilName", func(t *testing.T) { + serverKey1 := createAzureSqlServerKey(serverName, "key-1", "") + serverKey2 := &armsql.ServerKey{ + Name: nil, // Key with nil name should be skipped + Location: new("eastus"), + ID: new("/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Sql/servers/test-server/keys/key-2"), + Properties: &armsql.ServerKeyProperties{ + ServerKeyType: new(armsql.ServerKeyTypeServiceManaged), + }, + } + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockPager := &mockSqlServerKeysPager{ + pages: []armsql.ServerKeysClientListByServerResponse{ + { + ServerKeyListResult: armsql.ServerKeyListResult{ + Value: []*armsql.ServerKey{serverKey1, serverKey2}, + }, + }, + }, + } + + testClient := &testSqlServerKeysClient{ + MockSqlServerKeysClient: mockClient, + pager: mockPager, + } + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + sdpItems, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + if err != nil { + t.Fatalf("Expected no error, got: %v", err) + } + + // Should only return 1 item (key with nil name is skipped) + if len(sdpItems) != 1 { + t.Fatalf("Expected 1 item (nil name filtered out), got: %d", len(sdpItems)) + } + + if sdpItems[0].UniqueAttributeValue() != shared.CompositeLookupKey(serverName, "key-1") { + t.Fatalf("Expected key name 'key-1', got: %s", sdpItems[0].UniqueAttributeValue()) + } + }) + + t.Run("Search_InvalidQueryParts", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search directly with no query parts - should return error before calling NewListByServerPager + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0]) + if qErr == nil { + t.Error("Expected error when providing no query parts, but got nil") + } + }) + + t.Run("SearchWithEmptyServerName", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Test Search with empty server name + _, qErr := wrapper.Search(ctx, wrapper.Scopes()[0], "") + if qErr == nil { + t.Error("Expected error when providing empty server name in Search, but got nil") + } + }) + + t.Run("ErrorHandling_Get", func(t *testing.T) { + expectedErr := errors.New("key not found") + + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + mockClient.EXPECT().Get(ctx, resourceGroup, serverName, "nonexistent-key").Return( + armsql.ServerKeysClientGetResponse{}, expectedErr) + + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + query := shared.CompositeLookupKey(serverName, "nonexistent-key") + _, qErr := adapter.Get(ctx, wrapper.Scopes()[0], query, true) + if qErr == nil { + t.Error("Expected error when getting non-existent key, but got nil") + } + }) + + t.Run("ErrorHandling_Search", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + // Create a pager that returns an error when NextPage is called + errorPager := &errorSqlServerKeysPager{} + + testClient := &testSqlServerKeysClient{ + MockSqlServerKeysClient: mockClient, + pager: errorPager, + } + + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + adapter := sources.WrapperToAdapter(wrapper, sdpcache.NewNoOpCache()) + + searchable, ok := adapter.(discovery.SearchableAdapter) + if !ok { + t.Fatalf("Adapter does not support Search operation") + } + + _, err := searchable.Search(ctx, wrapper.Scopes()[0], serverName, true) + // The Search implementation should return an error when pager.NextPage returns an error + if err == nil { + t.Error("Expected error from pager when NextPage returns an error, but got nil") + } + }) + + t.Run("InterfaceCompliance", func(t *testing.T) { + mockClient := mocks.NewMockSqlServerKeysClient(ctrl) + testClient := &testSqlServerKeysClient{MockSqlServerKeysClient: mockClient} + wrapper := manual.NewSqlServerKey(testClient, []azureshared.ResourceGroupScope{azureshared.NewResourceGroupScope(subscriptionID, resourceGroup)}) + + // Cast to sources.Wrapper to access interface methods + w := wrapper.(sources.Wrapper) + + // Verify IAMPermissions + permissions := w.IAMPermissions() + if len(permissions) == 0 { + t.Error("Expected IAMPermissions to return at least one permission") + } + expectedPermission := "Microsoft.Sql/servers/keys/read" + found := slices.Contains(permissions, expectedPermission) + if !found { + t.Errorf("Expected IAMPermissions to include %s", expectedPermission) + } + + // Verify PotentialLinks + potentialLinks := w.PotentialLinks() + if len(potentialLinks) == 0 { + t.Error("Expected PotentialLinks to return at least one link") + } + if !potentialLinks[azureshared.SQLServer] { + t.Error("Expected PotentialLinks to include SQLServer") + } + if !potentialLinks[azureshared.KeyVaultKey] { + t.Error("Expected PotentialLinks to include KeyVaultKey") + } + + // Verify PredefinedRole using type assertion to the searchable wrapper + if sw, ok := wrapper.(interface{ PredefinedRole() string }); ok { + role := sw.PredefinedRole() + if role != "Reader" { + t.Errorf("Expected PredefinedRole to be 'Reader', got %s", role) + } + } + }) +} + +// createAzureSqlServerKey creates a mock Azure SQL Server Key for testing +func createAzureSqlServerKey(serverName, keyName, keyVaultKeyURI string) *armsql.ServerKey { + keyID := "/subscriptions/test-subscription/resourceGroups/test-rg/providers/Microsoft.Sql/servers/" + serverName + "/keys/" + keyName + + keyType := armsql.ServerKeyTypeServiceManaged + if keyVaultKeyURI != "" { + keyType = armsql.ServerKeyTypeAzureKeyVault + } + + serverKey := &armsql.ServerKey{ + Name: new(keyName), + Location: new("eastus"), + ID: new(keyID), + Properties: &armsql.ServerKeyProperties{ + ServerKeyType: new(keyType), + }, + } + + if keyVaultKeyURI != "" { + serverKey.Properties.URI = new(keyVaultKeyURI) + } + + return serverKey +} diff --git a/sources/azure/shared/item-types.go b/sources/azure/shared/item-types.go index 1b471196..c4ee9d1b 100644 --- a/sources/azure/shared/item-types.go +++ b/sources/azure/shared/item-types.go @@ -71,6 +71,8 @@ var ( NetworkSecurityRule = shared.NewItemType(Azure, Network, SecurityRule) NetworkDefaultSecurityRule = shared.NewItemType(Azure, Network, DefaultSecurityRule) NetworkIPGroup = shared.NewItemType(Azure, Network, IPGroup) + NetworkFirewall = shared.NewItemType(Azure, Network, Firewall) + NetworkFirewallPolicy = shared.NewItemType(Azure, Network, FirewallPolicy) NetworkRouteTable = shared.NewItemType(Azure, Network, RouteTable) NetworkRoute = shared.NewItemType(Azure, Network, Route) NetworkVirtualNetworkGateway = shared.NewItemType(Azure, Network, VirtualNetworkGateway) @@ -184,8 +186,16 @@ var ( // OperationalInsights item types OperationalInsightsWorkspace = shared.NewItemType(Azure, OperationalInsights, Workspace) + OperationalInsightsCluster = shared.NewItemType(Azure, OperationalInsights, Cluster) + + // Insights (Azure Monitor) item types + InsightsPrivateLinkScopeScopedResource = shared.NewItemType(Azure, Insights, PrivateLinkScopeScopedResource) // Authorization item types AuthorizationRoleAssignment = shared.NewItemType(Azure, Authorization, RoleAssignment) AuthorizationRoleDefinition = shared.NewItemType(Azure, Authorization, RoleDefinition) + + // Resources item types + ResourcesSubscription = shared.NewItemType(Azure, Resources, Subscription) + ResourcesResourceGroup = shared.NewItemType(Azure, Resources, ResourceGroup) ) diff --git a/sources/azure/shared/mocks/mock_batch_private_endpoint_connection_client.go b/sources/azure/shared/mocks/mock_batch_private_endpoint_connection_client.go new file mode 100644 index 00000000..27560518 --- /dev/null +++ b/sources/azure/shared/mocks/mock_batch_private_endpoint_connection_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: batch-private-endpoint-connection-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_batch_private_endpoint_connection_client.go -package=mocks -source=batch-private-endpoint-connection-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armbatch "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/batch/armbatch/v4" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockBatchPrivateEndpointConnectionClient is a mock of BatchPrivateEndpointConnectionClient interface. +type MockBatchPrivateEndpointConnectionClient struct { + ctrl *gomock.Controller + recorder *MockBatchPrivateEndpointConnectionClientMockRecorder + isgomock struct{} +} + +// MockBatchPrivateEndpointConnectionClientMockRecorder is the mock recorder for MockBatchPrivateEndpointConnectionClient. +type MockBatchPrivateEndpointConnectionClientMockRecorder struct { + mock *MockBatchPrivateEndpointConnectionClient +} + +// NewMockBatchPrivateEndpointConnectionClient creates a new mock instance. +func NewMockBatchPrivateEndpointConnectionClient(ctrl *gomock.Controller) *MockBatchPrivateEndpointConnectionClient { + mock := &MockBatchPrivateEndpointConnectionClient{ctrl: ctrl} + mock.recorder = &MockBatchPrivateEndpointConnectionClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockBatchPrivateEndpointConnectionClient) EXPECT() *MockBatchPrivateEndpointConnectionClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockBatchPrivateEndpointConnectionClient) Get(ctx context.Context, resourceGroupName, accountName, privateEndpointConnectionName string) (armbatch.PrivateEndpointConnectionClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, accountName, privateEndpointConnectionName) + ret0, _ := ret[0].(armbatch.PrivateEndpointConnectionClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockBatchPrivateEndpointConnectionClientMockRecorder) Get(ctx, resourceGroupName, accountName, privateEndpointConnectionName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockBatchPrivateEndpointConnectionClient)(nil).Get), ctx, resourceGroupName, accountName, privateEndpointConnectionName) +} + +// ListByBatchAccount mocks base method. +func (m *MockBatchPrivateEndpointConnectionClient) ListByBatchAccount(ctx context.Context, resourceGroupName, accountName string) clients.BatchPrivateEndpointConnectionPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByBatchAccount", ctx, resourceGroupName, accountName) + ret0, _ := ret[0].(clients.BatchPrivateEndpointConnectionPager) + return ret0 +} + +// ListByBatchAccount indicates an expected call of ListByBatchAccount. +func (mr *MockBatchPrivateEndpointConnectionClientMockRecorder) ListByBatchAccount(ctx, resourceGroupName, accountName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByBatchAccount", reflect.TypeOf((*MockBatchPrivateEndpointConnectionClient)(nil).ListByBatchAccount), ctx, resourceGroupName, accountName) +} diff --git a/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_administrator_client.go b/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_administrator_client.go new file mode 100644 index 00000000..0ef54f40 --- /dev/null +++ b/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_administrator_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: dbforpostgresql-flexible-server-administrator-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_dbforpostgresql_flexible_server_administrator_client.go -package=mocks -source=dbforpostgresql-flexible-server-administrator-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armpostgresqlflexibleservers "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockDBforPostgreSQLFlexibleServerAdministratorClient is a mock of DBforPostgreSQLFlexibleServerAdministratorClient interface. +type MockDBforPostgreSQLFlexibleServerAdministratorClient struct { + ctrl *gomock.Controller + recorder *MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder + isgomock struct{} +} + +// MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder is the mock recorder for MockDBforPostgreSQLFlexibleServerAdministratorClient. +type MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder struct { + mock *MockDBforPostgreSQLFlexibleServerAdministratorClient +} + +// NewMockDBforPostgreSQLFlexibleServerAdministratorClient creates a new mock instance. +func NewMockDBforPostgreSQLFlexibleServerAdministratorClient(ctrl *gomock.Controller) *MockDBforPostgreSQLFlexibleServerAdministratorClient { + mock := &MockDBforPostgreSQLFlexibleServerAdministratorClient{ctrl: ctrl} + mock.recorder = &MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBforPostgreSQLFlexibleServerAdministratorClient) EXPECT() *MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockDBforPostgreSQLFlexibleServerAdministratorClient) Get(ctx context.Context, resourceGroupName, serverName, objectID string) (armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, serverName, objectID) + ret0, _ := ret[0].(armpostgresqlflexibleservers.AdministratorsMicrosoftEntraClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder) Get(ctx, resourceGroupName, serverName, objectID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDBforPostgreSQLFlexibleServerAdministratorClient)(nil).Get), ctx, resourceGroupName, serverName, objectID) +} + +// ListByServer mocks base method. +func (m *MockDBforPostgreSQLFlexibleServerAdministratorClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.DBforPostgreSQLFlexibleServerAdministratorPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByServer", ctx, resourceGroupName, serverName) + ret0, _ := ret[0].(clients.DBforPostgreSQLFlexibleServerAdministratorPager) + return ret0 +} + +// ListByServer indicates an expected call of ListByServer. +func (mr *MockDBforPostgreSQLFlexibleServerAdministratorClientMockRecorder) ListByServer(ctx, resourceGroupName, serverName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByServer", reflect.TypeOf((*MockDBforPostgreSQLFlexibleServerAdministratorClient)(nil).ListByServer), ctx, resourceGroupName, serverName) +} diff --git a/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_virtual_endpoint_client.go b/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_virtual_endpoint_client.go new file mode 100644 index 00000000..44ede56c --- /dev/null +++ b/sources/azure/shared/mocks/mock_dbforpostgresql_flexible_server_virtual_endpoint_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: dbforpostgresql-flexible-server-virtual-endpoint-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_dbforpostgresql_flexible_server_virtual_endpoint_client.go -package=mocks -source=dbforpostgresql-flexible-server-virtual-endpoint-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armpostgresqlflexibleservers "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/postgresql/armpostgresqlflexibleservers/v5" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockDBforPostgreSQLFlexibleServerVirtualEndpointClient is a mock of DBforPostgreSQLFlexibleServerVirtualEndpointClient interface. +type MockDBforPostgreSQLFlexibleServerVirtualEndpointClient struct { + ctrl *gomock.Controller + recorder *MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder + isgomock struct{} +} + +// MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder is the mock recorder for MockDBforPostgreSQLFlexibleServerVirtualEndpointClient. +type MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder struct { + mock *MockDBforPostgreSQLFlexibleServerVirtualEndpointClient +} + +// NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient creates a new mock instance. +func NewMockDBforPostgreSQLFlexibleServerVirtualEndpointClient(ctrl *gomock.Controller) *MockDBforPostgreSQLFlexibleServerVirtualEndpointClient { + mock := &MockDBforPostgreSQLFlexibleServerVirtualEndpointClient{ctrl: ctrl} + mock.recorder = &MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockDBforPostgreSQLFlexibleServerVirtualEndpointClient) EXPECT() *MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockDBforPostgreSQLFlexibleServerVirtualEndpointClient) Get(ctx context.Context, resourceGroupName, serverName, virtualEndpointName string) (armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, serverName, virtualEndpointName) + ret0, _ := ret[0].(armpostgresqlflexibleservers.VirtualEndpointsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder) Get(ctx, resourceGroupName, serverName, virtualEndpointName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockDBforPostgreSQLFlexibleServerVirtualEndpointClient)(nil).Get), ctx, resourceGroupName, serverName, virtualEndpointName) +} + +// ListByServer mocks base method. +func (m *MockDBforPostgreSQLFlexibleServerVirtualEndpointClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.DBforPostgreSQLFlexibleServerVirtualEndpointPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByServer", ctx, resourceGroupName, serverName) + ret0, _ := ret[0].(clients.DBforPostgreSQLFlexibleServerVirtualEndpointPager) + return ret0 +} + +// ListByServer indicates an expected call of ListByServer. +func (mr *MockDBforPostgreSQLFlexibleServerVirtualEndpointClientMockRecorder) ListByServer(ctx, resourceGroupName, serverName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByServer", reflect.TypeOf((*MockDBforPostgreSQLFlexibleServerVirtualEndpointClient)(nil).ListByServer), ctx, resourceGroupName, serverName) +} diff --git a/sources/azure/shared/mocks/mock_elastic_san_volume_client.go b/sources/azure/shared/mocks/mock_elastic_san_volume_client.go new file mode 100644 index 00000000..f2b6ad6e --- /dev/null +++ b/sources/azure/shared/mocks/mock_elastic_san_volume_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: elastic-san-volume-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_elastic_san_volume_client.go -package=mocks -source=elastic-san-volume-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armelasticsan "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/elasticsan/armelasticsan" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockElasticSanVolumeClient is a mock of ElasticSanVolumeClient interface. +type MockElasticSanVolumeClient struct { + ctrl *gomock.Controller + recorder *MockElasticSanVolumeClientMockRecorder + isgomock struct{} +} + +// MockElasticSanVolumeClientMockRecorder is the mock recorder for MockElasticSanVolumeClient. +type MockElasticSanVolumeClientMockRecorder struct { + mock *MockElasticSanVolumeClient +} + +// NewMockElasticSanVolumeClient creates a new mock instance. +func NewMockElasticSanVolumeClient(ctrl *gomock.Controller) *MockElasticSanVolumeClient { + mock := &MockElasticSanVolumeClient{ctrl: ctrl} + mock.recorder = &MockElasticSanVolumeClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockElasticSanVolumeClient) EXPECT() *MockElasticSanVolumeClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockElasticSanVolumeClient) Get(ctx context.Context, resourceGroupName, elasticSanName, volumeGroupName, volumeName string, options *armelasticsan.VolumesClientGetOptions) (armelasticsan.VolumesClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, options) + ret0, _ := ret[0].(armelasticsan.VolumesClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockElasticSanVolumeClientMockRecorder) Get(ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockElasticSanVolumeClient)(nil).Get), ctx, resourceGroupName, elasticSanName, volumeGroupName, volumeName, options) +} + +// NewListByVolumeGroupPager mocks base method. +func (m *MockElasticSanVolumeClient) NewListByVolumeGroupPager(resourceGroupName, elasticSanName, volumeGroupName string, options *armelasticsan.VolumesClientListByVolumeGroupOptions) clients.ElasticSanVolumePager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListByVolumeGroupPager", resourceGroupName, elasticSanName, volumeGroupName, options) + ret0, _ := ret[0].(clients.ElasticSanVolumePager) + return ret0 +} + +// NewListByVolumeGroupPager indicates an expected call of NewListByVolumeGroupPager. +func (mr *MockElasticSanVolumeClientMockRecorder) NewListByVolumeGroupPager(resourceGroupName, elasticSanName, volumeGroupName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListByVolumeGroupPager", reflect.TypeOf((*MockElasticSanVolumeClient)(nil).NewListByVolumeGroupPager), resourceGroupName, elasticSanName, volumeGroupName, options) +} diff --git a/sources/azure/shared/mocks/mock_federated_identity_credentials_client.go b/sources/azure/shared/mocks/mock_federated_identity_credentials_client.go new file mode 100644 index 00000000..6bf4fbf0 --- /dev/null +++ b/sources/azure/shared/mocks/mock_federated_identity_credentials_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: federated-identity-credentials-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_federated_identity_credentials_client.go -package=mocks -source=federated-identity-credentials-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armmsi "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockFederatedIdentityCredentialsClient is a mock of FederatedIdentityCredentialsClient interface. +type MockFederatedIdentityCredentialsClient struct { + ctrl *gomock.Controller + recorder *MockFederatedIdentityCredentialsClientMockRecorder + isgomock struct{} +} + +// MockFederatedIdentityCredentialsClientMockRecorder is the mock recorder for MockFederatedIdentityCredentialsClient. +type MockFederatedIdentityCredentialsClientMockRecorder struct { + mock *MockFederatedIdentityCredentialsClient +} + +// NewMockFederatedIdentityCredentialsClient creates a new mock instance. +func NewMockFederatedIdentityCredentialsClient(ctrl *gomock.Controller) *MockFederatedIdentityCredentialsClient { + mock := &MockFederatedIdentityCredentialsClient{ctrl: ctrl} + mock.recorder = &MockFederatedIdentityCredentialsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockFederatedIdentityCredentialsClient) EXPECT() *MockFederatedIdentityCredentialsClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockFederatedIdentityCredentialsClient) Get(ctx context.Context, resourceGroupName, resourceName, federatedIdentityCredentialResourceName string, options *armmsi.FederatedIdentityCredentialsClientGetOptions) (armmsi.FederatedIdentityCredentialsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, resourceName, federatedIdentityCredentialResourceName, options) + ret0, _ := ret[0].(armmsi.FederatedIdentityCredentialsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockFederatedIdentityCredentialsClientMockRecorder) Get(ctx, resourceGroupName, resourceName, federatedIdentityCredentialResourceName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockFederatedIdentityCredentialsClient)(nil).Get), ctx, resourceGroupName, resourceName, federatedIdentityCredentialResourceName, options) +} + +// NewListPager mocks base method. +func (m *MockFederatedIdentityCredentialsClient) NewListPager(resourceGroupName, resourceName string, options *armmsi.FederatedIdentityCredentialsClientListOptions) clients.FederatedIdentityCredentialsPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, resourceName, options) + ret0, _ := ret[0].(clients.FederatedIdentityCredentialsPager) + return ret0 +} + +// NewListPager indicates an expected call of NewListPager. +func (mr *MockFederatedIdentityCredentialsClientMockRecorder) NewListPager(resourceGroupName, resourceName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockFederatedIdentityCredentialsClient)(nil).NewListPager), resourceGroupName, resourceName, options) +} diff --git a/sources/azure/shared/mocks/mock_interface_ip_configurations_client.go b/sources/azure/shared/mocks/mock_interface_ip_configurations_client.go new file mode 100644 index 00000000..cfdf8205 --- /dev/null +++ b/sources/azure/shared/mocks/mock_interface_ip_configurations_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: interface-ip-configurations-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_interface_ip_configurations_client.go -package=mocks -source=interface-ip-configurations-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockInterfaceIPConfigurationsClient is a mock of InterfaceIPConfigurationsClient interface. +type MockInterfaceIPConfigurationsClient struct { + ctrl *gomock.Controller + recorder *MockInterfaceIPConfigurationsClientMockRecorder + isgomock struct{} +} + +// MockInterfaceIPConfigurationsClientMockRecorder is the mock recorder for MockInterfaceIPConfigurationsClient. +type MockInterfaceIPConfigurationsClientMockRecorder struct { + mock *MockInterfaceIPConfigurationsClient +} + +// NewMockInterfaceIPConfigurationsClient creates a new mock instance. +func NewMockInterfaceIPConfigurationsClient(ctrl *gomock.Controller) *MockInterfaceIPConfigurationsClient { + mock := &MockInterfaceIPConfigurationsClient{ctrl: ctrl} + mock.recorder = &MockInterfaceIPConfigurationsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockInterfaceIPConfigurationsClient) EXPECT() *MockInterfaceIPConfigurationsClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockInterfaceIPConfigurationsClient) Get(ctx context.Context, resourceGroupName, networkInterfaceName, ipConfigurationName string) (armnetwork.InterfaceIPConfigurationsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, networkInterfaceName, ipConfigurationName) + ret0, _ := ret[0].(armnetwork.InterfaceIPConfigurationsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockInterfaceIPConfigurationsClientMockRecorder) Get(ctx, resourceGroupName, networkInterfaceName, ipConfigurationName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockInterfaceIPConfigurationsClient)(nil).Get), ctx, resourceGroupName, networkInterfaceName, ipConfigurationName) +} + +// List mocks base method. +func (m *MockInterfaceIPConfigurationsClient) List(ctx context.Context, resourceGroupName, networkInterfaceName string) clients.InterfaceIPConfigurationsPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", ctx, resourceGroupName, networkInterfaceName) + ret0, _ := ret[0].(clients.InterfaceIPConfigurationsPager) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockInterfaceIPConfigurationsClientMockRecorder) List(ctx, resourceGroupName, networkInterfaceName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockInterfaceIPConfigurationsClient)(nil).List), ctx, resourceGroupName, networkInterfaceName) +} diff --git a/sources/azure/shared/mocks/mock_ip_groups_client.go b/sources/azure/shared/mocks/mock_ip_groups_client.go new file mode 100644 index 00000000..034ea5fb --- /dev/null +++ b/sources/azure/shared/mocks/mock_ip_groups_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ip-groups-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_ip_groups_client.go -package=mocks -source=ip-groups-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockIPGroupsClient is a mock of IPGroupsClient interface. +type MockIPGroupsClient struct { + ctrl *gomock.Controller + recorder *MockIPGroupsClientMockRecorder + isgomock struct{} +} + +// MockIPGroupsClientMockRecorder is the mock recorder for MockIPGroupsClient. +type MockIPGroupsClientMockRecorder struct { + mock *MockIPGroupsClient +} + +// NewMockIPGroupsClient creates a new mock instance. +func NewMockIPGroupsClient(ctrl *gomock.Controller) *MockIPGroupsClient { + mock := &MockIPGroupsClient{ctrl: ctrl} + mock.recorder = &MockIPGroupsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIPGroupsClient) EXPECT() *MockIPGroupsClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockIPGroupsClient) Get(ctx context.Context, resourceGroupName, ipGroupsName string, options *armnetwork.IPGroupsClientGetOptions) (armnetwork.IPGroupsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, ipGroupsName, options) + ret0, _ := ret[0].(armnetwork.IPGroupsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockIPGroupsClientMockRecorder) Get(ctx, resourceGroupName, ipGroupsName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockIPGroupsClient)(nil).Get), ctx, resourceGroupName, ipGroupsName, options) +} + +// NewListByResourceGroupPager mocks base method. +func (m *MockIPGroupsClient) NewListByResourceGroupPager(resourceGroupName string, options *armnetwork.IPGroupsClientListByResourceGroupOptions) clients.IPGroupsPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListByResourceGroupPager", resourceGroupName, options) + ret0, _ := ret[0].(clients.IPGroupsPager) + return ret0 +} + +// NewListByResourceGroupPager indicates an expected call of NewListByResourceGroupPager. +func (mr *MockIPGroupsClientMockRecorder) NewListByResourceGroupPager(resourceGroupName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListByResourceGroupPager", reflect.TypeOf((*MockIPGroupsClient)(nil).NewListByResourceGroupPager), resourceGroupName, options) +} diff --git a/sources/azure/shared/mocks/mock_local_network_gateways_client.go b/sources/azure/shared/mocks/mock_local_network_gateways_client.go new file mode 100644 index 00000000..ad625817 --- /dev/null +++ b/sources/azure/shared/mocks/mock_local_network_gateways_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: local-network-gateways-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_local_network_gateways_client.go -package=mocks -source=local-network-gateways-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockLocalNetworkGatewaysClient is a mock of LocalNetworkGatewaysClient interface. +type MockLocalNetworkGatewaysClient struct { + ctrl *gomock.Controller + recorder *MockLocalNetworkGatewaysClientMockRecorder + isgomock struct{} +} + +// MockLocalNetworkGatewaysClientMockRecorder is the mock recorder for MockLocalNetworkGatewaysClient. +type MockLocalNetworkGatewaysClientMockRecorder struct { + mock *MockLocalNetworkGatewaysClient +} + +// NewMockLocalNetworkGatewaysClient creates a new mock instance. +func NewMockLocalNetworkGatewaysClient(ctrl *gomock.Controller) *MockLocalNetworkGatewaysClient { + mock := &MockLocalNetworkGatewaysClient{ctrl: ctrl} + mock.recorder = &MockLocalNetworkGatewaysClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockLocalNetworkGatewaysClient) EXPECT() *MockLocalNetworkGatewaysClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockLocalNetworkGatewaysClient) Get(ctx context.Context, resourceGroupName, localNetworkGatewayName string, options *armnetwork.LocalNetworkGatewaysClientGetOptions) (armnetwork.LocalNetworkGatewaysClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, localNetworkGatewayName, options) + ret0, _ := ret[0].(armnetwork.LocalNetworkGatewaysClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockLocalNetworkGatewaysClientMockRecorder) Get(ctx, resourceGroupName, localNetworkGatewayName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockLocalNetworkGatewaysClient)(nil).Get), ctx, resourceGroupName, localNetworkGatewayName, options) +} + +// NewListPager mocks base method. +func (m *MockLocalNetworkGatewaysClient) NewListPager(resourceGroupName string, options *armnetwork.LocalNetworkGatewaysClientListOptions) clients.LocalNetworkGatewaysPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, options) + ret0, _ := ret[0].(clients.LocalNetworkGatewaysPager) + return ret0 +} + +// NewListPager indicates an expected call of NewListPager. +func (mr *MockLocalNetworkGatewaysClientMockRecorder) NewListPager(resourceGroupName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockLocalNetworkGatewaysClient)(nil).NewListPager), resourceGroupName, options) +} diff --git a/sources/azure/shared/mocks/mock_network_watchers_client.go b/sources/azure/shared/mocks/mock_network_watchers_client.go new file mode 100644 index 00000000..cdc8a973 --- /dev/null +++ b/sources/azure/shared/mocks/mock_network_watchers_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: network-watchers-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_network_watchers_client.go -package=mocks -source=network-watchers-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockNetworkWatchersClient is a mock of NetworkWatchersClient interface. +type MockNetworkWatchersClient struct { + ctrl *gomock.Controller + recorder *MockNetworkWatchersClientMockRecorder + isgomock struct{} +} + +// MockNetworkWatchersClientMockRecorder is the mock recorder for MockNetworkWatchersClient. +type MockNetworkWatchersClientMockRecorder struct { + mock *MockNetworkWatchersClient +} + +// NewMockNetworkWatchersClient creates a new mock instance. +func NewMockNetworkWatchersClient(ctrl *gomock.Controller) *MockNetworkWatchersClient { + mock := &MockNetworkWatchersClient{ctrl: ctrl} + mock.recorder = &MockNetworkWatchersClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockNetworkWatchersClient) EXPECT() *MockNetworkWatchersClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockNetworkWatchersClient) Get(ctx context.Context, resourceGroupName, networkWatcherName string, options *armnetwork.WatchersClientGetOptions) (armnetwork.WatchersClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, networkWatcherName, options) + ret0, _ := ret[0].(armnetwork.WatchersClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockNetworkWatchersClientMockRecorder) Get(ctx, resourceGroupName, networkWatcherName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockNetworkWatchersClient)(nil).Get), ctx, resourceGroupName, networkWatcherName, options) +} + +// NewListPager mocks base method. +func (m *MockNetworkWatchersClient) NewListPager(resourceGroupName string, options *armnetwork.WatchersClientListOptions) clients.NetworkWatchersPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListPager", resourceGroupName, options) + ret0, _ := ret[0].(clients.NetworkWatchersPager) + return ret0 +} + +// NewListPager indicates an expected call of NewListPager. +func (mr *MockNetworkWatchersClientMockRecorder) NewListPager(resourceGroupName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockNetworkWatchersClient)(nil).NewListPager), resourceGroupName, options) +} diff --git a/sources/azure/shared/mocks/mock_operational_insights_workspace_client.go b/sources/azure/shared/mocks/mock_operational_insights_workspace_client.go new file mode 100644 index 00000000..c4e12017 --- /dev/null +++ b/sources/azure/shared/mocks/mock_operational_insights_workspace_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: operational-insights-workspace-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_operational_insights_workspace_client.go -package=mocks -source=operational-insights-workspace-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armoperationalinsights "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/operationalinsights/armoperationalinsights" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockOperationalInsightsWorkspaceClient is a mock of OperationalInsightsWorkspaceClient interface. +type MockOperationalInsightsWorkspaceClient struct { + ctrl *gomock.Controller + recorder *MockOperationalInsightsWorkspaceClientMockRecorder + isgomock struct{} +} + +// MockOperationalInsightsWorkspaceClientMockRecorder is the mock recorder for MockOperationalInsightsWorkspaceClient. +type MockOperationalInsightsWorkspaceClientMockRecorder struct { + mock *MockOperationalInsightsWorkspaceClient +} + +// NewMockOperationalInsightsWorkspaceClient creates a new mock instance. +func NewMockOperationalInsightsWorkspaceClient(ctrl *gomock.Controller) *MockOperationalInsightsWorkspaceClient { + mock := &MockOperationalInsightsWorkspaceClient{ctrl: ctrl} + mock.recorder = &MockOperationalInsightsWorkspaceClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockOperationalInsightsWorkspaceClient) EXPECT() *MockOperationalInsightsWorkspaceClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockOperationalInsightsWorkspaceClient) Get(ctx context.Context, resourceGroupName, workspaceName string, options *armoperationalinsights.WorkspacesClientGetOptions) (armoperationalinsights.WorkspacesClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, workspaceName, options) + ret0, _ := ret[0].(armoperationalinsights.WorkspacesClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockOperationalInsightsWorkspaceClientMockRecorder) Get(ctx, resourceGroupName, workspaceName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockOperationalInsightsWorkspaceClient)(nil).Get), ctx, resourceGroupName, workspaceName, options) +} + +// NewListByResourceGroupPager mocks base method. +func (m *MockOperationalInsightsWorkspaceClient) NewListByResourceGroupPager(resourceGroupName string, options *armoperationalinsights.WorkspacesClientListByResourceGroupOptions) clients.OperationalInsightsWorkspacePager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListByResourceGroupPager", resourceGroupName, options) + ret0, _ := ret[0].(clients.OperationalInsightsWorkspacePager) + return ret0 +} + +// NewListByResourceGroupPager indicates an expected call of NewListByResourceGroupPager. +func (mr *MockOperationalInsightsWorkspaceClientMockRecorder) NewListByResourceGroupPager(resourceGroupName, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListByResourceGroupPager", reflect.TypeOf((*MockOperationalInsightsWorkspaceClient)(nil).NewListByResourceGroupPager), resourceGroupName, options) +} diff --git a/sources/azure/shared/mocks/mock_private_link_services_client.go b/sources/azure/shared/mocks/mock_private_link_services_client.go new file mode 100644 index 00000000..c3775bd3 --- /dev/null +++ b/sources/azure/shared/mocks/mock_private_link_services_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: private-link-services-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_private_link_services_client.go -package=mocks -source=private-link-services-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armnetwork "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/network/armnetwork/v9" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockPrivateLinkServicesClient is a mock of PrivateLinkServicesClient interface. +type MockPrivateLinkServicesClient struct { + ctrl *gomock.Controller + recorder *MockPrivateLinkServicesClientMockRecorder + isgomock struct{} +} + +// MockPrivateLinkServicesClientMockRecorder is the mock recorder for MockPrivateLinkServicesClient. +type MockPrivateLinkServicesClientMockRecorder struct { + mock *MockPrivateLinkServicesClient +} + +// NewMockPrivateLinkServicesClient creates a new mock instance. +func NewMockPrivateLinkServicesClient(ctrl *gomock.Controller) *MockPrivateLinkServicesClient { + mock := &MockPrivateLinkServicesClient{ctrl: ctrl} + mock.recorder = &MockPrivateLinkServicesClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPrivateLinkServicesClient) EXPECT() *MockPrivateLinkServicesClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockPrivateLinkServicesClient) Get(ctx context.Context, resourceGroupName, serviceName string) (armnetwork.PrivateLinkServicesClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, serviceName) + ret0, _ := ret[0].(armnetwork.PrivateLinkServicesClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockPrivateLinkServicesClientMockRecorder) Get(ctx, resourceGroupName, serviceName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPrivateLinkServicesClient)(nil).Get), ctx, resourceGroupName, serviceName) +} + +// List mocks base method. +func (m *MockPrivateLinkServicesClient) List(resourceGroupName string) clients.PrivateLinkServicesPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "List", resourceGroupName) + ret0, _ := ret[0].(clients.PrivateLinkServicesPager) + return ret0 +} + +// List indicates an expected call of List. +func (mr *MockPrivateLinkServicesClientMockRecorder) List(resourceGroupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "List", reflect.TypeOf((*MockPrivateLinkServicesClient)(nil).List), resourceGroupName) +} diff --git a/sources/azure/shared/mocks/mock_role_definitions_client.go b/sources/azure/shared/mocks/mock_role_definitions_client.go new file mode 100644 index 00000000..abf6862c --- /dev/null +++ b/sources/azure/shared/mocks/mock_role_definitions_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: role-definitions-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_role_definitions_client.go -package=mocks -source=role-definitions-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armauthorization "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/authorization/armauthorization/v3" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockRoleDefinitionsClient is a mock of RoleDefinitionsClient interface. +type MockRoleDefinitionsClient struct { + ctrl *gomock.Controller + recorder *MockRoleDefinitionsClientMockRecorder + isgomock struct{} +} + +// MockRoleDefinitionsClientMockRecorder is the mock recorder for MockRoleDefinitionsClient. +type MockRoleDefinitionsClientMockRecorder struct { + mock *MockRoleDefinitionsClient +} + +// NewMockRoleDefinitionsClient creates a new mock instance. +func NewMockRoleDefinitionsClient(ctrl *gomock.Controller) *MockRoleDefinitionsClient { + mock := &MockRoleDefinitionsClient{ctrl: ctrl} + mock.recorder = &MockRoleDefinitionsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRoleDefinitionsClient) EXPECT() *MockRoleDefinitionsClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockRoleDefinitionsClient) Get(ctx context.Context, scope, roleDefinitionID string, options *armauthorization.RoleDefinitionsClientGetOptions) (armauthorization.RoleDefinitionsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, scope, roleDefinitionID, options) + ret0, _ := ret[0].(armauthorization.RoleDefinitionsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockRoleDefinitionsClientMockRecorder) Get(ctx, scope, roleDefinitionID, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockRoleDefinitionsClient)(nil).Get), ctx, scope, roleDefinitionID, options) +} + +// NewListPager mocks base method. +func (m *MockRoleDefinitionsClient) NewListPager(scope string, options *armauthorization.RoleDefinitionsClientListOptions) clients.RoleDefinitionsPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListPager", scope, options) + ret0, _ := ret[0].(clients.RoleDefinitionsPager) + return ret0 +} + +// NewListPager indicates an expected call of NewListPager. +func (mr *MockRoleDefinitionsClientMockRecorder) NewListPager(scope, options any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListPager", reflect.TypeOf((*MockRoleDefinitionsClient)(nil).NewListPager), scope, options) +} diff --git a/sources/azure/shared/mocks/mock_sql_failover_groups_client.go b/sources/azure/shared/mocks/mock_sql_failover_groups_client.go new file mode 100644 index 00000000..0ed8e7c7 --- /dev/null +++ b/sources/azure/shared/mocks/mock_sql_failover_groups_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sql-failover-groups-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_sql_failover_groups_client.go -package=mocks -source=sql-failover-groups-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armsql "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockSqlFailoverGroupsClient is a mock of SqlFailoverGroupsClient interface. +type MockSqlFailoverGroupsClient struct { + ctrl *gomock.Controller + recorder *MockSqlFailoverGroupsClientMockRecorder + isgomock struct{} +} + +// MockSqlFailoverGroupsClientMockRecorder is the mock recorder for MockSqlFailoverGroupsClient. +type MockSqlFailoverGroupsClientMockRecorder struct { + mock *MockSqlFailoverGroupsClient +} + +// NewMockSqlFailoverGroupsClient creates a new mock instance. +func NewMockSqlFailoverGroupsClient(ctrl *gomock.Controller) *MockSqlFailoverGroupsClient { + mock := &MockSqlFailoverGroupsClient{ctrl: ctrl} + mock.recorder = &MockSqlFailoverGroupsClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSqlFailoverGroupsClient) EXPECT() *MockSqlFailoverGroupsClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockSqlFailoverGroupsClient) Get(ctx context.Context, resourceGroupName, serverName, failoverGroupName string) (armsql.FailoverGroupsClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, serverName, failoverGroupName) + ret0, _ := ret[0].(armsql.FailoverGroupsClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockSqlFailoverGroupsClientMockRecorder) Get(ctx, resourceGroupName, serverName, failoverGroupName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSqlFailoverGroupsClient)(nil).Get), ctx, resourceGroupName, serverName, failoverGroupName) +} + +// ListByServer mocks base method. +func (m *MockSqlFailoverGroupsClient) ListByServer(ctx context.Context, resourceGroupName, serverName string) clients.SqlFailoverGroupsPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListByServer", ctx, resourceGroupName, serverName) + ret0, _ := ret[0].(clients.SqlFailoverGroupsPager) + return ret0 +} + +// ListByServer indicates an expected call of ListByServer. +func (mr *MockSqlFailoverGroupsClientMockRecorder) ListByServer(ctx, resourceGroupName, serverName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListByServer", reflect.TypeOf((*MockSqlFailoverGroupsClient)(nil).ListByServer), ctx, resourceGroupName, serverName) +} diff --git a/sources/azure/shared/mocks/mock_sql_server_keys_client.go b/sources/azure/shared/mocks/mock_sql_server_keys_client.go new file mode 100644 index 00000000..e48679a8 --- /dev/null +++ b/sources/azure/shared/mocks/mock_sql_server_keys_client.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: sql-server-keys-client.go +// +// Generated by this command: +// +// mockgen -destination=../shared/mocks/mock_sql_server_keys_client.go -package=mocks -source=sql-server-keys-client.go +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + armsql "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/sql/armsql/v2" + clients "github.com/overmindtech/cli/sources/azure/clients" + gomock "go.uber.org/mock/gomock" +) + +// MockSqlServerKeysClient is a mock of SqlServerKeysClient interface. +type MockSqlServerKeysClient struct { + ctrl *gomock.Controller + recorder *MockSqlServerKeysClientMockRecorder + isgomock struct{} +} + +// MockSqlServerKeysClientMockRecorder is the mock recorder for MockSqlServerKeysClient. +type MockSqlServerKeysClientMockRecorder struct { + mock *MockSqlServerKeysClient +} + +// NewMockSqlServerKeysClient creates a new mock instance. +func NewMockSqlServerKeysClient(ctrl *gomock.Controller) *MockSqlServerKeysClient { + mock := &MockSqlServerKeysClient{ctrl: ctrl} + mock.recorder = &MockSqlServerKeysClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSqlServerKeysClient) EXPECT() *MockSqlServerKeysClientMockRecorder { + return m.recorder +} + +// Get mocks base method. +func (m *MockSqlServerKeysClient) Get(ctx context.Context, resourceGroupName, serverName, keyName string) (armsql.ServerKeysClientGetResponse, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", ctx, resourceGroupName, serverName, keyName) + ret0, _ := ret[0].(armsql.ServerKeysClientGetResponse) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockSqlServerKeysClientMockRecorder) Get(ctx, resourceGroupName, serverName, keyName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSqlServerKeysClient)(nil).Get), ctx, resourceGroupName, serverName, keyName) +} + +// NewListByServerPager mocks base method. +func (m *MockSqlServerKeysClient) NewListByServerPager(resourceGroupName, serverName string) clients.SqlServerKeysPager { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewListByServerPager", resourceGroupName, serverName) + ret0, _ := ret[0].(clients.SqlServerKeysPager) + return ret0 +} + +// NewListByServerPager indicates an expected call of NewListByServerPager. +func (mr *MockSqlServerKeysClientMockRecorder) NewListByServerPager(resourceGroupName, serverName any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewListByServerPager", reflect.TypeOf((*MockSqlServerKeysClient)(nil).NewListByServerPager), resourceGroupName, serverName) +} diff --git a/sources/azure/shared/models.go b/sources/azure/shared/models.go index 2ed1b1ef..f8504a8e 100644 --- a/sources/azure/shared/models.go +++ b/sources/azure/shared/models.go @@ -53,6 +53,9 @@ const ( // OperationalInsights OperationalInsights shared.API = "operationalinsights" // Microsoft.OperationalInsights + // Insights (Azure Monitor) + Insights shared.API = "insights" // microsoft.insights + // ExtendedLocation (custom locations, edge zones) ExtendedLocation shared.API = "extendedlocation" // Microsoft.ExtendedLocation ) @@ -125,6 +128,8 @@ const ( SecurityRule shared.Resource = "security-rule" DefaultSecurityRule shared.Resource = "default-security-rule" IPGroup shared.Resource = "ip-group" + Firewall shared.Resource = "firewall" + FirewallPolicy shared.Resource = "firewall-policy" RouteTable shared.Resource = "route-table" Route shared.Resource = "route" VirtualNetworkGateway shared.Resource = "virtual-network-gateway" @@ -236,8 +241,16 @@ const ( RoleAssignment shared.Resource = "role-assignment" RoleDefinition shared.Resource = "role-definition" + // Resources (subscriptions, resource groups) + Subscription shared.Resource = "subscription" + ResourceGroup shared.Resource = "resource-group" + // OperationalInsights resources Workspace shared.Resource = "workspace" + Cluster shared.Resource = "cluster" + + // Insights (Azure Monitor) resources + PrivateLinkScopeScopedResource shared.Resource = "private-link-scope-scoped-resource" // ExtendedLocation resources CustomLocation shared.Resource = "custom-location" diff --git a/sources/azure/shared/utils.go b/sources/azure/shared/utils.go index 09346736..e344ed4b 100644 --- a/sources/azure/shared/utils.go +++ b/sources/azure/shared/utils.go @@ -55,6 +55,7 @@ func GetResourceIDPathKeys(resourceType string) []string { "azure-batch-batch-pool": {"batchAccounts", "pools"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Batch/batchAccounts/{accountName}/pools/{poolName}", "azure-network-dns-record-set": {"dnszones"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Network/dnszones/{zoneName}/{recordType}/{relativeRecordSetName}" "azure-elasticsan-elastic-san-volume-group": {"elasticSans", "volumegroups"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ElasticSan/elasticSans/{elasticSanName}/volumegroups/{volumeGroupName}" + "azure-elasticsan-volume": {"elasticSans", "volumegroups", "volumes"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ElasticSan/elasticSans/{elasticSanName}/volumegroups/{volumeGroupName}/volumes/{volumeName}" "azure-elasticsan-elastic-san-volume-snapshot": {"elasticSans", "volumegroups", "snapshots"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.ElasticSan/elasticSans/{elasticSanName}/volumegroups/{volumeGroupName}/snapshots/{snapshotName}" "azure-compute-disk-access-private-endpoint-connection": {"diskAccesses", "privateEndpointConnections"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Compute/diskAccesses/{diskAccessName}/privateEndpointConnections/{connectionName}" "azure-network-dns-virtual-network-link": {"privateDnsZones", "virtualNetworkLinks"}, // "/subscriptions/{sub}/resourceGroups/{rg}/providers/Microsoft.Network/privateDnsZones/{zoneName}/virtualNetworkLinks/{linkName}" @@ -419,6 +420,25 @@ func ExtractSubscriptionIDFromResourceID(resourceID string) string { return "" } +// ExtractResourceGroupFromResourceID extracts the resource group name from an Azure resource ID +// Azure resource IDs follow the format: +// /subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/... +// Returns empty string if the resource ID doesn't contain a resource group +func ExtractResourceGroupFromResourceID(resourceID string) string { + if resourceID == "" { + return "" + } + + parts := strings.Split(strings.Trim(resourceID, "/"), "/") + for i, part := range parts { + if strings.EqualFold(part, "resourceGroups") && i+1 < len(parts) { + return parts[i+1] + } + } + + return "" +} + // ExtractScopeFromResourceID extracts the scope (subscription.resourceGroup) from an Azure resource ID // Azure resource IDs follow the format: // /subscriptions/{subscriptionId}/resourceGroups/{resourceGroupName}/providers/... diff --git a/sources/gcp/build/package/Dockerfile b/sources/gcp/build/package/Dockerfile index f90fad0a..10958189 100644 --- a/sources/gcp/build/package/Dockerfile +++ b/sources/gcp/build/package/Dockerfile @@ -10,8 +10,12 @@ RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY sources/ sources/ # Build RUN --mount=type=cache,target=/go/pkg \ diff --git a/sources/gcp/proc/proc.go b/sources/gcp/proc/proc.go index d59b1ff6..d07766f4 100644 --- a/sources/gcp/proc/proc.go +++ b/sources/gcp/proc/proc.go @@ -22,6 +22,7 @@ import ( "github.com/overmindtech/cli/go/discovery" "github.com/overmindtech/cli/go/sdp-go" "github.com/overmindtech/cli/go/sdpcache" + "github.com/overmindtech/cli/go/tracing" "github.com/overmindtech/cli/sources/gcp/dynamic" _ "github.com/overmindtech/cli/sources/gcp/dynamic/adapters" // Import all adapters to register them "github.com/overmindtech/cli/sources/gcp/manual" @@ -512,7 +513,9 @@ func InitializeAdapters(ctx context.Context, engine *discovery.Engine, cfg *GCPC // Run initial permission check before starting the source to fail fast if // we don't have the required permissions. This validates that we can access // the Cloud Resource Manager API for all configured projects. - result, err := healthChecker.Check(ctx) + checkCtx, checkSpan := tracing.Tracer().Start(ctx, "InitializeAdapters.HealthCheck") + result, err := healthChecker.Check(checkCtx) + checkSpan.End() if err != nil { log.WithContext(ctx).WithError(err).WithFields(log.Fields{ "ovm.source.type": "gcp", diff --git a/sources/snapshot/build/package/Dockerfile b/sources/snapshot/build/package/Dockerfile index 8780ed41..04c7be06 100644 --- a/sources/snapshot/build/package/Dockerfile +++ b/sources/snapshot/build/package/Dockerfile @@ -10,8 +10,13 @@ RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY sources/ sources/ +COPY docs.overmind.tech/docs/sources/ docs.overmind.tech/docs/sources/ # Build RUN --mount=type=cache,target=/go/pkg \ diff --git a/stdlib-source/build/package/Dockerfile b/stdlib-source/build/package/Dockerfile index 0ea8527d..cff25871 100644 --- a/stdlib-source/build/package/Dockerfile +++ b/stdlib-source/build/package/Dockerfile @@ -6,12 +6,16 @@ ARG BUILD_VERSION ARG BUILD_COMMIT # required for accessing the private dependencies and generating version descriptor -RUN apk upgrade --no-cache && apk add --no-cache git curl +RUN apk upgrade --no-cache && apk add --no-cache git WORKDIR /workspace -# Copy the go source -COPY . . +COPY go.mod go.sum ./ +RUN --mount=type=cache,target=/go/pkg \ + go mod download + +COPY go/ go/ +COPY stdlib-source/ stdlib-source/ # Build RUN --mount=type=cache,target=/go/pkg \