diff --git a/flow/activities/snapshot_activity.go b/flow/activities/snapshot_activity.go index fb47a34e0..1eb1c87da 100644 --- a/flow/activities/snapshot_activity.go +++ b/flow/activities/snapshot_activity.go @@ -115,11 +115,25 @@ func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, flo } defer connClose(ctx) + logger := internal.LoggerFromCtx(ctx) + exportSnapshotOutput, tx, err := conn.ExportTxSnapshot(ctx, flowName, env) if err != nil { return err } + // Ensure transaction is always finished, even if activity is terminated + defer func() { + a.SnapshotStatesMutex.Lock() + delete(a.TxSnapshotStates, sessionID) + a.SnapshotStatesMutex.Unlock() + if tx != nil { + if err := conn.FinishExport(tx); err != nil { + logger.Error("finish export error in defer", slog.Any("error", err)) + } + } + }() + a.SnapshotStatesMutex.Lock() if exportSnapshotOutput != nil { a.TxSnapshotStates[sessionID] = TxSnapshotState{ @@ -130,21 +144,18 @@ func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, flo } a.SnapshotStatesMutex.Unlock() - logger := internal.LoggerFromCtx(ctx) start := time.Now() + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + for { logger.Info("maintaining export snapshot transaction", slog.Int64("seconds", int64(time.Since(start).Round(time.Second)/time.Second))) - if ctx.Err() != nil { - a.SnapshotStatesMutex.Lock() - delete(a.TxSnapshotStates, sessionID) - a.SnapshotStatesMutex.Unlock() - if err := conn.FinishExport(tx); err != nil { - logger.Error("finish export error", slog.Any("error", err)) - return err - } + select { + case <-ctx.Done(): return nil + case <-ticker.C: + // Continue loop } - time.Sleep(time.Minute) } } diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index b31fdcee7..982fcc44f 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -7,6 +7,7 @@ import ( "fmt" "log/slog" "maps" + "net" "slices" "strings" "sync" @@ -200,6 +201,17 @@ func (c *PostgresConnector) SetupReplConn(ctx context.Context, env map[string]st return nil } +// clearConnectionDeadline clears any TCP deadline set on the underlying net.Conn. +// This is necessary when using context.Background() for database operations after +// a cancellable context may have triggered pgx's ContextWatcher to set a deadline. +func clearConnectionDeadline(conn interface{ Conn() net.Conn }, logger log.Logger, operation string) { + if clearErr := conn.Conn().SetDeadline(time.Time{}); clearErr != nil { + logger.Warn("failed to clear connection deadline", + slog.String("operation", operation), + slog.Any("error", clearErr)) + } +} + // To keep connection alive between sync batches. // By default postgres drops connection after 1 minute of inactivity. func (c *PostgresConnector) ReplPing(ctx context.Context) error { @@ -1447,6 +1459,10 @@ func (c *PostgresConnector) FinishExport(tx any) error { return nil } pgtx := tx.(pgx.Tx) + + // Clear any deadline set by cancelled context to ensure commit can proceed + clearConnectionDeadline(pgtx.Conn().PgConn(), c.logger, "finishing export") + timeout, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() return pgtx.Commit(timeout) diff --git a/flow/connectors/postgres/qrep.go b/flow/connectors/postgres/qrep.go index 8a5a55bf7..c20356474 100644 --- a/flow/connectors/postgres/qrep.go +++ b/flow/connectors/postgres/qrep.go @@ -482,7 +482,13 @@ func syncQRepRecords( if err != nil { return 0, nil, fmt.Errorf("failed to create tx pool: %w", err) } - defer txConn.Close(ctx) + defer func() { + closeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := txConn.Close(closeCtx); err != nil { + c.logger.Warn("failed to close transaction connection", slog.Any("error", err)) + } + }() if err := shared.RegisterExtensions(ctx, txConn, config.Version); err != nil { return 0, nil, fmt.Errorf("failed to register extensions: %w", err) @@ -494,6 +500,13 @@ func syncQRepRecords( } defer shared.RollbackTx(tx, c.logger) + // Clear any existing deadline at the start to ensure clean state + clearConnectionDeadline(txConn.PgConn(), c.logger, "qrep start") + + // Clear any deadline set during execution to ensure commit/rollback can proceed + // Must happen regardless of function exit path, so use defer + defer clearConnectionDeadline(txConn.PgConn(), c.logger, "qrep cleanup") + // Step 2: Insert records into destination table var numRowsSynced int64 @@ -503,7 +516,8 @@ func syncQRepRecords( if writeMode != nil && writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE { // Truncate destination table before copying records c.logger.Info(fmt.Sprintf("Truncating table %s for overwrite mode", dstTable), syncLog) - _, err = c.execWithLoggingTx(ctx, + // Use context.Background() to prevent ContextWatcher creation + _, err = c.execWithLoggingTx(context.Background(), "TRUNCATE TABLE "+dstTable.String(), tx) if err != nil { return -1, nil, fmt.Errorf("failed to TRUNCATE table before copy: %w", err) @@ -522,7 +536,8 @@ func syncQRepRecords( common.QuoteIdentifier(syncedAtCol), common.QuoteIdentifier(syncedAtCol), ) - if _, err := tx.Exec(ctx, updateSyncedAtStmt); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), updateSyncedAtStmt); err != nil { return -1, nil, fmt.Errorf("failed to update synced_at column: %w", err) } } @@ -534,7 +549,8 @@ func syncQRepRecords( // From PG docs: The cost of setting a large value in sessions that do not actually need many // temporary buffers is only a buffer descriptor, or about 64 bytes, per increment in temp_buffers. - if _, err := tx.Exec(ctx, "SET temp_buffers = '4GB';"); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), "SET temp_buffers = '4GB';"); err != nil { return -1, nil, fmt.Errorf("failed to set temp_buffers: %w", err) } @@ -546,7 +562,8 @@ func syncQRepRecords( c.logger.Info(fmt.Sprintf("Creating staging table %s - '%s'", stagingTableName, createStagingTableStmt), syncLog) - if _, err := c.execWithLoggingTx(ctx, createStagingTableStmt, tx); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := c.execWithLoggingTx(context.Background(), createStagingTableStmt, tx); err != nil { return -1, nil, fmt.Errorf("failed to create staging table: %w", err) } @@ -594,7 +611,8 @@ func syncQRepRecords( setClause, ) c.logger.Info("Performing upsert operation", slog.String("upsertStmt", upsertStmt), syncLog) - if _, err := tx.Exec(ctx, upsertStmt); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), upsertStmt); err != nil { return -1, nil, fmt.Errorf("failed to perform upsert operation: %w", err) } } @@ -613,8 +631,9 @@ func syncQRepRecords( metadataTableIdentifier.Sanitize(), ) c.logger.Info("Executing transaction inside QRep sync", syncLog) + // Use context.Background() to prevent ContextWatcher creation if _, err := tx.Exec( - ctx, + context.Background(), insertMetadataStmt, flowJobName, partitionID, @@ -625,7 +644,9 @@ func syncQRepRecords( return -1, nil, fmt.Errorf("failed to execute statements in a transaction: %w", err) } - if err := tx.Commit(ctx); err != nil { + commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer commitCancel() + if err := tx.Commit(commitCtx); err != nil { return -1, nil, fmt.Errorf("failed to commit transaction: %w", err) } diff --git a/flow/connectors/postgres/sink_pg.go b/flow/connectors/postgres/sink_pg.go index 729fe9fce..1d19f8e87 100644 --- a/flow/connectors/postgres/sink_pg.go +++ b/flow/connectors/postgres/sink_pg.go @@ -6,6 +6,7 @@ import ( "io" "log/slog" "strings" + "time" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" @@ -52,10 +53,19 @@ func (p PgCopyWriter) ExecuteQueryWithTx( query string, args ...any, ) (int64, int64, error) { + defer qe.Conn().Close(context.Background()) defer shared.RollbackTx(tx, qe.logger) + // Clear any existing deadline at the start to ensure clean state + clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_pg start") + + // Clear any deadline set during execution to ensure commit/rollback can proceed + // Must happen regardless of function exit path, so use defer + defer clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_pg cleanup") + if qe.snapshot != "" { - if _, err := tx.Exec(ctx, "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil { qe.logger.Error("[pg_query_executor] failed to set snapshot", slog.Any("error", err), slog.String("query", query)) if shared.IsSQLStateError(err, pgerrcode.UndefinedObject, pgerrcode.InvalidParameterValue) { @@ -70,7 +80,8 @@ func (p PgCopyWriter) ExecuteQueryWithTx( } } - norows, err := tx.Query(ctx, query+" limit 0", args...) + // Use context.Background() to prevent ContextWatcher creation + norows, err := tx.Query(context.Background(), query+" limit 0", args...) if err != nil { return 0, 0, err } @@ -90,15 +101,37 @@ func (p PgCopyWriter) ExecuteQueryWithTx( copyQuery := fmt.Sprintf("COPY (%s) TO STDOUT", query) qe.logger.Info("[pg_query_executor] executing copy", slog.String("query", copyQuery)) - ct, err := qe.conn.PgConn().CopyTo(ctx, p.PipeWriter, copyQuery) + + // Monitor context cancellation and close pipe to trigger clean exit + // Use context.Background() for CopyTo to avoid ContextWatcher entirely + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + p.PipeWriter.CloseWithError(ctx.Err()) + case <-done: + } + }() + + // Use Background context to prevent ContextWatcher creation (ctx == context.Background() check) + // Cancellation is handled via pipe closing above, timeout is handled by Temporal activity timeout + ct, err := qe.conn.PgConn().CopyTo(context.Background(), p.PipeWriter, copyQuery) if err != nil { + // Close pipe explicitly to ensure destination side exits cleanly + if closeErr := p.PipeWriter.CloseWithError(err); closeErr != nil { + qe.logger.Warn("[pg_query_executor] failed to close pipe on copy error", + slog.Any("closeError", closeErr), slog.Any("copyError", err)) + } qe.logger.Info("[pg_query_executor] failed to copy", slog.String("copyQuery", copyQuery), slog.Any("error", err)) return 0, 0, fmt.Errorf("[pg_query_executor] failed to copy: %w", err) } qe.logger.Info("Committing transaction") - if err := tx.Commit(ctx); err != nil { + commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer commitCancel() + if err := tx.Commit(commitCtx); err != nil { qe.logger.Error("[pg_query_executor] failed to commit transaction", slog.Any("error", err)) return 0, 0, fmt.Errorf("[pg_query_executor] failed to commit transaction: %w", err) } @@ -128,8 +161,27 @@ func (p PgCopyReader) CopyInto(ctx context.Context, c *PostgresConnector, tx pgx for _, col := range cols { quotedCols = append(quotedCols, common.QuoteIdentifier(col)) } + + // Monitor context cancellation and close pipe to trigger clean exit + // Use context.Background() for CopyFrom to avoid ContextWatcher entirely + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + p.PipeReader.CloseWithError(ctx.Err()) + case <-done: + } + }() + + // Clear deadline immediately before CopyFrom as final safeguard against races in BeginTx + // This handles the case where context was cancelled between BeginTx and here + clearConnectionDeadline(tx.Conn().PgConn(), c.logger, "before CopyFrom") + + // Use Background context to prevent ContextWatcher creation (ctx == context.Background() check) + // Cancellation is handled via pipe closing above, timeout is handled by Temporal activity timeout ct, err := tx.Conn().PgConn().CopyFrom( - ctx, + context.Background(), p.PipeReader, fmt.Sprintf("COPY %s (%s) FROM STDIN", table.Sanitize(), strings.Join(quotedCols, ",")), ) diff --git a/flow/connectors/postgres/sink_q.go b/flow/connectors/postgres/sink_q.go index de02dbc92..2ae9da32a 100644 --- a/flow/connectors/postgres/sink_q.go +++ b/flow/connectors/postgres/sink_q.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "math/rand/v2" + "time" "github.com/jackc/pgerrcode" "github.com/jackc/pgx/v5" @@ -31,8 +32,16 @@ func (stream RecordStreamSink) ExecuteQueryWithTx( ) (int64, int64, error) { defer shared.RollbackTx(tx, qe.logger) + // Clear any existing deadline at the start to ensure clean state + clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_q start") + + // Clear any deadline set during execution to ensure commit/rollback can proceed + // Must happen regardless of function exit path, so use defer + defer clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_q cleanup") + if qe.snapshot != "" { - if _, err := tx.Exec(ctx, "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil { qe.logger.Error("[pg_query_executor] failed to set snapshot", slog.Any("error", err), slog.String("query", query)) if shared.IsSQLStateError(err, pgerrcode.UndefinedObject, pgerrcode.InvalidParameterValue) { @@ -53,7 +62,8 @@ func (stream RecordStreamSink) ExecuteQueryWithTx( cursorName := fmt.Sprintf("peerdb_cursor_%d", randomUint) cursorQuery := fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursorName, query) - if _, err := tx.Exec(ctx, cursorQuery, args...); err != nil { + // Use context.Background() to prevent ContextWatcher creation + if _, err := tx.Exec(context.Background(), cursorQuery, args...); err != nil { qe.logger.Info("[pg_query_executor] failed to declare cursor", slog.String("cursorQuery", cursorQuery), slog.Any("args", args), slog.Any("error", err)) return 0, 0, fmt.Errorf("[pg_query_executor] failed to declare cursor: %w", err) @@ -66,7 +76,8 @@ func (stream RecordStreamSink) ExecuteQueryWithTx( slog.Int("channelLen", len(stream.Records))) if !stream.IsSchemaSet() { - schema, schemaDebug, err := qe.cursorToSchema(ctx, tx, cursorName) + // Use context.Background() to prevent ContextWatcher creation + schema, schemaDebug, err := qe.cursorToSchema(context.Background(), tx, cursorName) if err != nil { return 0, 0, err } @@ -74,10 +85,14 @@ func (stream RecordStreamSink) ExecuteQueryWithTx( stream.SetSchemaDebug(schemaDebug) } + // Clear deadline immediately before fetch loop as final safeguard + clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "before fetch loop") + var totalNumRows int64 var totalNumBytes int64 for { - numRows, numBytes, err := qe.processFetchedRows(ctx, query, tx, cursorName, shared.QRepFetchSize, + // Use context.Background() to prevent ContextWatcher creation during fetch + numRows, numBytes, err := qe.processFetchedRows(context.Background(), query, tx, cursorName, shared.QRepFetchSize, stream.DestinationType, stream.QRecordStream) if err != nil { qe.logger.Error("[pg_query_executor] failed to process fetched rows", slog.Any("error", err)) @@ -98,7 +113,9 @@ func (stream RecordStreamSink) ExecuteQueryWithTx( } qe.logger.Info("[pg_query_executor] committing transaction") - if err := tx.Commit(ctx); err != nil { + commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer commitCancel() + if err := tx.Commit(commitCtx); err != nil { qe.logger.Error("[pg_query_executor] failed to commit transaction", slog.Any("error", err)) return totalNumRows, totalNumBytes, fmt.Errorf("[pg_query_executor] failed to commit transaction: %w", err) } @@ -116,6 +133,18 @@ func (stream RecordStreamSink) CopyInto(ctx context.Context, _ *PostgresConnecto if err != nil { return 0, err } + + // Monitor context cancellation and close stream to unblock reads + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + stream.QRecordStream.Close(ctx.Err()) + case <-done: + } + }() + return tx.CopyFrom(ctx, table, columnNames, model.NewQRecordCopyFromSource(stream.QRecordStream)) }