Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions flow/activities/snapshot_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,25 @@ func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, flo
}
defer connClose(ctx)

logger := internal.LoggerFromCtx(ctx)

exportSnapshotOutput, tx, err := conn.ExportTxSnapshot(ctx, flowName, env)
if err != nil {
return err
}

// Ensure transaction is always finished, even if activity is terminated
defer func() {
a.SnapshotStatesMutex.Lock()
delete(a.TxSnapshotStates, sessionID)
a.SnapshotStatesMutex.Unlock()
if tx != nil {
if err := conn.FinishExport(tx); err != nil {
logger.Error("finish export error in defer", slog.Any("error", err))
}
}
}()

a.SnapshotStatesMutex.Lock()
if exportSnapshotOutput != nil {
a.TxSnapshotStates[sessionID] = TxSnapshotState{
Expand All @@ -130,21 +144,18 @@ func (a *SnapshotActivity) MaintainTx(ctx context.Context, sessionID string, flo
}
a.SnapshotStatesMutex.Unlock()

logger := internal.LoggerFromCtx(ctx)
start := time.Now()
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()

for {
logger.Info("maintaining export snapshot transaction", slog.Int64("seconds", int64(time.Since(start).Round(time.Second)/time.Second)))
if ctx.Err() != nil {
a.SnapshotStatesMutex.Lock()
delete(a.TxSnapshotStates, sessionID)
a.SnapshotStatesMutex.Unlock()
if err := conn.FinishExport(tx); err != nil {
logger.Error("finish export error", slog.Any("error", err))
return err
}
select {
case <-ctx.Done():
return nil
case <-ticker.C:
// Continue loop
}
time.Sleep(time.Minute)
}
}

Expand Down
16 changes: 16 additions & 0 deletions flow/connectors/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"maps"
"net"
"slices"
"strings"
"sync"
Expand Down Expand Up @@ -200,6 +201,17 @@ func (c *PostgresConnector) SetupReplConn(ctx context.Context, env map[string]st
return nil
}

// clearConnectionDeadline resets any deadline on the connection's underlying
// net.Conn. When a cancellable context fires, pgx's ContextWatcher may leave a
// deadline on the socket; operations subsequently issued with
// context.Background() would then fail spuriously, so the deadline must be
// cleared first. A failure to clear is logged at warn level (with the calling
// operation for context) rather than returned, since the caller proceeds
// either way.
func clearConnectionDeadline(conn interface{ Conn() net.Conn }, logger log.Logger, operation string) {
	err := conn.Conn().SetDeadline(time.Time{})
	if err == nil {
		return
	}
	logger.Warn("failed to clear connection deadline",
		slog.String("operation", operation),
		slog.Any("error", err))
}

// To keep connection alive between sync batches.
// By default postgres drops connection after 1 minute of inactivity.
func (c *PostgresConnector) ReplPing(ctx context.Context) error {
Expand Down Expand Up @@ -1447,6 +1459,10 @@ func (c *PostgresConnector) FinishExport(tx any) error {
return nil
}
pgtx := tx.(pgx.Tx)

// Clear any deadline set by cancelled context to ensure commit can proceed
clearConnectionDeadline(pgtx.Conn().PgConn(), c.logger, "finishing export")

timeout, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
return pgtx.Commit(timeout)
Expand Down
37 changes: 29 additions & 8 deletions flow/connectors/postgres/qrep.go
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,13 @@ func syncQRepRecords(
if err != nil {
return 0, nil, fmt.Errorf("failed to create tx pool: %w", err)
}
defer txConn.Close(ctx)
defer func() {
closeCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := txConn.Close(closeCtx); err != nil {
c.logger.Warn("failed to close transaction connection", slog.Any("error", err))
}
}()

if err := shared.RegisterExtensions(ctx, txConn, config.Version); err != nil {
return 0, nil, fmt.Errorf("failed to register extensions: %w", err)
Expand All @@ -494,6 +500,13 @@ func syncQRepRecords(
}
defer shared.RollbackTx(tx, c.logger)

// Clear any existing deadline at the start to ensure clean state
clearConnectionDeadline(txConn.PgConn(), c.logger, "qrep start")

// Clear any deadline set during execution to ensure commit/rollback can proceed
// Must happen regardless of function exit path, so use defer
defer clearConnectionDeadline(txConn.PgConn(), c.logger, "qrep cleanup")

// Step 2: Insert records into destination table
var numRowsSynced int64

Expand All @@ -503,7 +516,8 @@ func syncQRepRecords(
if writeMode != nil && writeMode.WriteType == protos.QRepWriteType_QREP_WRITE_MODE_OVERWRITE {
// Truncate destination table before copying records
c.logger.Info(fmt.Sprintf("Truncating table %s for overwrite mode", dstTable), syncLog)
_, err = c.execWithLoggingTx(ctx,
// Use context.Background() to prevent ContextWatcher creation
_, err = c.execWithLoggingTx(context.Background(),
"TRUNCATE TABLE "+dstTable.String(), tx)
if err != nil {
return -1, nil, fmt.Errorf("failed to TRUNCATE table before copy: %w", err)
Expand All @@ -522,7 +536,8 @@ func syncQRepRecords(
common.QuoteIdentifier(syncedAtCol),
common.QuoteIdentifier(syncedAtCol),
)
if _, err := tx.Exec(ctx, updateSyncedAtStmt); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), updateSyncedAtStmt); err != nil {
return -1, nil, fmt.Errorf("failed to update synced_at column: %w", err)
}
}
Expand All @@ -534,7 +549,8 @@ func syncQRepRecords(

// From PG docs: The cost of setting a large value in sessions that do not actually need many
// temporary buffers is only a buffer descriptor, or about 64 bytes, per increment in temp_buffers.
if _, err := tx.Exec(ctx, "SET temp_buffers = '4GB';"); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), "SET temp_buffers = '4GB';"); err != nil {
return -1, nil, fmt.Errorf("failed to set temp_buffers: %w", err)
}

Expand All @@ -546,7 +562,8 @@ func syncQRepRecords(

c.logger.Info(fmt.Sprintf("Creating staging table %s - '%s'",
stagingTableName, createStagingTableStmt), syncLog)
if _, err := c.execWithLoggingTx(ctx, createStagingTableStmt, tx); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := c.execWithLoggingTx(context.Background(), createStagingTableStmt, tx); err != nil {
return -1, nil, fmt.Errorf("failed to create staging table: %w", err)
}

Expand Down Expand Up @@ -594,7 +611,8 @@ func syncQRepRecords(
setClause,
)
c.logger.Info("Performing upsert operation", slog.String("upsertStmt", upsertStmt), syncLog)
if _, err := tx.Exec(ctx, upsertStmt); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), upsertStmt); err != nil {
return -1, nil, fmt.Errorf("failed to perform upsert operation: %w", err)
}
}
Expand All @@ -613,8 +631,9 @@ func syncQRepRecords(
metadataTableIdentifier.Sanitize(),
)
c.logger.Info("Executing transaction inside QRep sync", syncLog)
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(
ctx,
context.Background(),
insertMetadataStmt,
flowJobName,
partitionID,
Expand All @@ -625,7 +644,9 @@ func syncQRepRecords(
return -1, nil, fmt.Errorf("failed to execute statements in a transaction: %w", err)
}

if err := tx.Commit(ctx); err != nil {
commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer commitCancel()
if err := tx.Commit(commitCtx); err != nil {
return -1, nil, fmt.Errorf("failed to commit transaction: %w", err)
}

Expand Down
62 changes: 57 additions & 5 deletions flow/connectors/postgres/sink_pg.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"log/slog"
"strings"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
Expand Down Expand Up @@ -52,10 +53,19 @@ func (p PgCopyWriter) ExecuteQueryWithTx(
query string,
args ...any,
) (int64, int64, error) {
defer qe.Conn().Close(context.Background())
defer shared.RollbackTx(tx, qe.logger)

// Clear any existing deadline at the start to ensure clean state
clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_pg start")

// Clear any deadline set during execution to ensure commit/rollback can proceed
// Must happen regardless of function exit path, so use defer
defer clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_pg cleanup")

if qe.snapshot != "" {
if _, err := tx.Exec(ctx, "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil {
qe.logger.Error("[pg_query_executor] failed to set snapshot",
slog.Any("error", err), slog.String("query", query))
if shared.IsSQLStateError(err, pgerrcode.UndefinedObject, pgerrcode.InvalidParameterValue) {
Expand All @@ -70,7 +80,8 @@ func (p PgCopyWriter) ExecuteQueryWithTx(
}
}

norows, err := tx.Query(ctx, query+" limit 0", args...)
// Use context.Background() to prevent ContextWatcher creation
norows, err := tx.Query(context.Background(), query+" limit 0", args...)
if err != nil {
return 0, 0, err
}
Expand All @@ -90,15 +101,37 @@ func (p PgCopyWriter) ExecuteQueryWithTx(

copyQuery := fmt.Sprintf("COPY (%s) TO STDOUT", query)
qe.logger.Info("[pg_query_executor] executing copy", slog.String("query", copyQuery))
ct, err := qe.conn.PgConn().CopyTo(ctx, p.PipeWriter, copyQuery)

// Monitor context cancellation and close pipe to trigger clean exit
// Use context.Background() for CopyTo to avoid ContextWatcher entirely
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
p.PipeWriter.CloseWithError(ctx.Err())
case <-done:
}
}()

// Use Background context to prevent ContextWatcher creation (ctx == context.Background() check)
// Cancellation is handled via pipe closing above, timeout is handled by Temporal activity timeout
ct, err := qe.conn.PgConn().CopyTo(context.Background(), p.PipeWriter, copyQuery)
if err != nil {
// Close pipe explicitly to ensure destination side exits cleanly
if closeErr := p.PipeWriter.CloseWithError(err); closeErr != nil {
qe.logger.Warn("[pg_query_executor] failed to close pipe on copy error",
slog.Any("closeError", closeErr), slog.Any("copyError", err))
}
qe.logger.Info("[pg_query_executor] failed to copy",
slog.String("copyQuery", copyQuery), slog.Any("error", err))
return 0, 0, fmt.Errorf("[pg_query_executor] failed to copy: %w", err)
}

qe.logger.Info("Committing transaction")
if err := tx.Commit(ctx); err != nil {
commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer commitCancel()
if err := tx.Commit(commitCtx); err != nil {
qe.logger.Error("[pg_query_executor] failed to commit transaction", slog.Any("error", err))
return 0, 0, fmt.Errorf("[pg_query_executor] failed to commit transaction: %w", err)
}
Expand Down Expand Up @@ -128,8 +161,27 @@ func (p PgCopyReader) CopyInto(ctx context.Context, c *PostgresConnector, tx pgx
for _, col := range cols {
quotedCols = append(quotedCols, common.QuoteIdentifier(col))
}

// Monitor context cancellation and close pipe to trigger clean exit
// Use context.Background() for CopyFrom to avoid ContextWatcher entirely
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
p.PipeReader.CloseWithError(ctx.Err())
case <-done:
}
}()

// Clear deadline immediately before CopyFrom as final safeguard against races in BeginTx
// This handles the case where context was cancelled between BeginTx and here
clearConnectionDeadline(tx.Conn().PgConn(), c.logger, "before CopyFrom")

// Use Background context to prevent ContextWatcher creation (ctx == context.Background() check)
// Cancellation is handled via pipe closing above, timeout is handled by Temporal activity timeout
ct, err := tx.Conn().PgConn().CopyFrom(
ctx,
context.Background(),
p.PipeReader,
fmt.Sprintf("COPY %s (%s) FROM STDIN", table.Sanitize(), strings.Join(quotedCols, ",")),
)
Expand Down
39 changes: 34 additions & 5 deletions flow/connectors/postgres/sink_q.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"log/slog"
"math/rand/v2"
"time"

"github.com/jackc/pgerrcode"
"github.com/jackc/pgx/v5"
Expand All @@ -31,8 +32,16 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
) (int64, int64, error) {
defer shared.RollbackTx(tx, qe.logger)

// Clear any existing deadline at the start to ensure clean state
clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_q start")

// Clear any deadline set during execution to ensure commit/rollback can proceed
// Must happen regardless of function exit path, so use defer
defer clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "sink_q cleanup")

if qe.snapshot != "" {
if _, err := tx.Exec(ctx, "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), "SET TRANSACTION SNAPSHOT "+utils.QuoteLiteral(qe.snapshot)); err != nil {
qe.logger.Error("[pg_query_executor] failed to set snapshot",
slog.Any("error", err), slog.String("query", query))
if shared.IsSQLStateError(err, pgerrcode.UndefinedObject, pgerrcode.InvalidParameterValue) {
Expand All @@ -53,7 +62,8 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
cursorName := fmt.Sprintf("peerdb_cursor_%d", randomUint)
cursorQuery := fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursorName, query)

if _, err := tx.Exec(ctx, cursorQuery, args...); err != nil {
// Use context.Background() to prevent ContextWatcher creation
if _, err := tx.Exec(context.Background(), cursorQuery, args...); err != nil {
qe.logger.Info("[pg_query_executor] failed to declare cursor",
slog.String("cursorQuery", cursorQuery), slog.Any("args", args), slog.Any("error", err))
return 0, 0, fmt.Errorf("[pg_query_executor] failed to declare cursor: %w", err)
Expand All @@ -66,18 +76,23 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
slog.Int("channelLen", len(stream.Records)))

if !stream.IsSchemaSet() {
schema, schemaDebug, err := qe.cursorToSchema(ctx, tx, cursorName)
// Use context.Background() to prevent ContextWatcher creation
schema, schemaDebug, err := qe.cursorToSchema(context.Background(), tx, cursorName)
if err != nil {
return 0, 0, err
}
stream.SetSchema(schema)
stream.SetSchemaDebug(schemaDebug)
}

// Clear deadline immediately before fetch loop as final safeguard
clearConnectionDeadline(qe.conn.PgConn(), qe.logger, "before fetch loop")

var totalNumRows int64
var totalNumBytes int64
for {
numRows, numBytes, err := qe.processFetchedRows(ctx, query, tx, cursorName, shared.QRepFetchSize,
// Use context.Background() to prevent ContextWatcher creation during fetch
numRows, numBytes, err := qe.processFetchedRows(context.Background(), query, tx, cursorName, shared.QRepFetchSize,
stream.DestinationType, stream.QRecordStream)
if err != nil {
qe.logger.Error("[pg_query_executor] failed to process fetched rows", slog.Any("error", err))
Expand All @@ -98,7 +113,9 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
}

qe.logger.Info("[pg_query_executor] committing transaction")
if err := tx.Commit(ctx); err != nil {
commitCtx, commitCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer commitCancel()
if err := tx.Commit(commitCtx); err != nil {
qe.logger.Error("[pg_query_executor] failed to commit transaction", slog.Any("error", err))
return totalNumRows, totalNumBytes, fmt.Errorf("[pg_query_executor] failed to commit transaction: %w", err)
}
Expand All @@ -116,6 +133,18 @@ func (stream RecordStreamSink) CopyInto(ctx context.Context, _ *PostgresConnecto
if err != nil {
return 0, err
}

// Monitor context cancellation and close stream to unblock reads
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-ctx.Done():
stream.QRecordStream.Close(ctx.Err())
case <-done:
}
}()

return tx.CopyFrom(ctx, table, columnNames, model.NewQRecordCopyFromSource(stream.QRecordStream))
}

Expand Down
Loading