Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flow/connectors/utils/avro_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"log/slog"
"os"
"runtime/debug"
"strings"
"sync/atomic"
"time"

Expand Down Expand Up @@ -139,6 +140,8 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
return AvroFile{}, fmt.Errorf("could not get s3 part size config: %w", err)
}

isGCS := strings.Contains(s3Creds.GetEndpointURL(), "storage.googleapis.com")

// Create the uploader using the AWS SDK v2 manager
uploader := manager.NewUploader(s3svc, func(u *manager.Uploader) {
if partSize > 0 {
Expand All @@ -147,6 +150,9 @@ func (p *peerDBOCFWriter) WriteRecordsToS3(
u.Concurrency = 1
}
}
if isGCS {
u.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
}
})

if _, err := uploader.Upload(ctx, &s3.PutObjectInput{
Expand Down
98 changes: 40 additions & 58 deletions flow/connectors/utils/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ import (
"time"

"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/stscreds"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/sts"
smithyendpoints "github.com/aws/smithy-go/endpoints"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/google/uuid"

"github.com/PeerDB-io/peerdb/flow/generated/protos"
Expand Down Expand Up @@ -364,6 +365,26 @@ func (r *resolverV2) ResolveEndpoint(ctx context.Context, params s3.EndpointPara
}, nil
}

// removeAcceptEncodingMiddleware strips Accept-Encoding before signing.
// GCS's front end mutates this header, which can otherwise invalidate SigV4.
func removeAcceptEncodingMiddleware(stack *middleware.Stack) error {
mw := middleware.FinalizeMiddlewareFunc("RemoveAcceptEncoding", func(
ctx context.Context, input middleware.FinalizeInput, next middleware.FinalizeHandler,
) (middleware.FinalizeOutput, middleware.Metadata, error) {
if req, ok := input.Request.(*smithyhttp.Request); ok {
req.Header.Del("Accept-Encoding")
}
return next.HandleFinalize(ctx, input)
})

// Place immediately before signing so any earlier middleware can't re-add it.
if err := stack.Finalize.Insert(mw, "Signing", middleware.Before); err != nil {
// Fallback if Signing isn't present for some reason.
return stack.Finalize.Add(mw, middleware.Before)
}
return nil
}

func CreateS3Client(ctx context.Context, credsProvider AWSCredentialsProvider) (*s3.Client, error) {
awsCredentials, err := credsProvider.Retrieve(ctx)
if err != nil {
Expand All @@ -375,6 +396,7 @@ func CreateS3Client(ctx context.Context, credsProvider AWSCredentialsProvider) (
Credentials: credsProvider.GetUnderlyingProvider(),
}
if awsCredentials.EndpointUrl != nil && *awsCredentials.EndpointUrl != "" {
isGCS := strings.Contains(*awsCredentials.EndpointUrl, "storage.googleapis.com")
options.BaseEndpoint = awsCredentials.EndpointUrl
options.UsePathStyle = true
url, err := url.Parse(*awsCredentials.EndpointUrl)
Expand All @@ -385,68 +407,28 @@ func CreateS3Client(ctx context.Context, credsProvider AWSCredentialsProvider) (
URL: *url,
}

if strings.Contains(*awsCredentials.EndpointUrl, "storage.googleapis.com") {
// Assign custom client with our own transport
options.HTTPClient = &http.Client{
Transport: &RecalculateV4Signature{
next: http.DefaultTransport,
signer: v4.NewSigner(),
credentials: credsProvider.GetUnderlyingProvider(),
region: options.Region,
},
}
} else {
rootCAs, tlsHost := credsProvider.GetTlsConfig()
if rootCAs != nil || tlsHost != "" {
// start with a clone of DefaultTransport so we keep http2, idle-conns, etc.
tlsConfig, err := common.CreateTlsConfig(tls.VersionTLS13, rootCAs, tlsHost, tlsHost, tlsHost == "")
if err != nil {
return nil, err
}

tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig
options.HTTPClient = &http.Client{Transport: tr}
}
if isGCS {
// GCS S3 compatibility doesn't support the SDK's default CRC32 checksums.
options.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
options.ResponseChecksumValidation = aws.ResponseChecksumValidationWhenRequired
options.APIOptions = append(options.APIOptions, removeAcceptEncodingMiddleware)
}
}

return s3.New(options), nil
}

// RecalculateV4Signature allow GCS over S3, removing Accept-Encoding header from sign
// https://stackoverflow.com/a/74382598/1204665
// https://github.com/aws/aws-sdk-go-v2/issues/1816
type RecalculateV4Signature struct {
next http.RoundTripper
signer *v4.Signer
credentials aws.CredentialsProvider
region string
}

func (lt *RecalculateV4Signature) RoundTrip(req *http.Request) (*http.Response, error) {
// store for later use
acceptEncodingValue := req.Header.Get("Accept-Encoding")

// delete the header so the header doesn't account for in the signature
req.Header.Del("Accept-Encoding")

// sign with the same date
timeString := req.Header.Get("X-Amz-Date")
timeDate, _ := time.Parse("20060102T150405Z", timeString)
rootCAs, tlsHost := credsProvider.GetTlsConfig()
if rootCAs != nil || tlsHost != "" {
// start with a clone of DefaultTransport so we keep http2, idle-conns, etc.
tlsConfig, err := common.CreateTlsConfig(tls.VersionTLS13, rootCAs, tlsHost, tlsHost, tlsHost == "")
if err != nil {
return nil, err
}

creds, err := lt.credentials.Retrieve(req.Context())
if err != nil {
return nil, err
}
if err := lt.signer.SignHTTP(req.Context(), creds, req, v4.GetPayloadHash(req.Context()), "s3", lt.region, timeDate); err != nil {
return nil, err
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = tlsConfig
options.HTTPClient = &http.Client{Transport: tr}
}
}
// Reset Accept-Encoding if desired
req.Header.Set("Accept-Encoding", acceptEncodingValue)

// follows up the original round tripper
return lt.next.RoundTrip(req)
return s3.New(options), nil
}

// Write an empty file and then delete it
Expand Down
Loading