Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 70 additions & 33 deletions flow/connectors/postgres/cdc.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"github.com/PeerDB-io/peerdb/flow/shared/types"
)

//nolint:govet // fieldalignment: fields grouped by purpose for readability
type PostgresCDCSource struct {
*PostgresConnector
srcTableIDNameMapping map[uint32]string
Expand All @@ -51,6 +52,8 @@ type PostgresCDCSource struct {

// for partitioned tables, maps child relid to parent relid
childToParentRelIDMapping map[uint32]uint32
idToRelKindMap map[uint32]byte
publishViaPartitionRoot bool

// for storing schema delta audit logs to catalog
catalogPool shared.CatalogPool
Expand Down Expand Up @@ -82,13 +85,22 @@ type PostgresCDCConfig struct {

// NewPostgresCDCSource creates a new PostgresCDCSource.
func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig *PostgresCDCConfig) (*PostgresCDCSource, error) {
childToParentRelIDMap, err := getChildToParentRelIDMap(ctx,
childToParentRelIDMap, idToRelKindMap, err := getChildToParentRelIDMap(ctx,
c.conn, slices.Collect(maps.Keys(cdcConfig.SrcTableIDNameMapping)),
cdcConfig.HandleInheritanceForNonPartitionedTables)
if err != nil {
return nil, fmt.Errorf("error getting child to parent relid map: %w", err)
}

var publishViaPartitionRoot bool
if err := c.conn.QueryRow(ctx,
"SELECT COALESCE(pubviaroot, false) FROM pg_publication WHERE pubname=$1",
cdcConfig.Publication,
).Scan(&publishViaPartitionRoot); err != nil {
return nil, fmt.Errorf("error checking publish_via_partition_root for publication %s: %w",
cdcConfig.Publication, err)
}

var schemaNameForRelID map[uint32]string
if cdcConfig.SourceSchemaAsDestinationColumn {
schemaNameForRelID = make(map[uint32]string, len(cdcConfig.TableNameSchemaMapping))
Expand All @@ -107,6 +119,8 @@ func (c *PostgresConnector) NewPostgresCDCSource(ctx context.Context, cdcConfig
publication: cdcConfig.Publication,
commitLock: nil,
childToParentRelIDMapping: childToParentRelIDMap,
idToRelKindMap: idToRelKindMap,
publishViaPartitionRoot: publishViaPartitionRoot,
catalogPool: cdcConfig.CatalogPool,
otelManager: cdcConfig.OtelManager,
hushWarnUnhandledMessageType: make(map[pglogrepl.MessageType]struct{}),
Expand Down Expand Up @@ -136,14 +150,14 @@ func (p *PostgresCDCSource) getSourceSchemaForDestinationColumn(relID uint32, ta

func getChildToParentRelIDMap(ctx context.Context,
conn *pgx.Conn, parentTableOIDs []uint32, handleInheritanceForNonPartitionedTables bool,
) (map[uint32]uint32, error) {
) (map[uint32]uint32, map[uint32]byte, error) {
relkinds := "'p'"
if handleInheritanceForNonPartitionedTables {
relkinds = "'p', 'r'"
}

query := fmt.Sprintf(`
SELECT parent.oid AS parentrelid, child.oid AS childrelid
SELECT parent.oid AS parentrelid, child.oid AS childrelid, parent.relkind
FROM pg_inherits
JOIN pg_class parent ON pg_inherits.inhparent = parent.oid
JOIN pg_class child ON pg_inherits.inhrelid = child.oid
Expand All @@ -152,19 +166,22 @@ func getChildToParentRelIDMap(ctx context.Context,

rows, err := conn.Query(ctx, query, parentTableOIDs)
if err != nil {
return nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
return nil, nil, fmt.Errorf("error querying for child to parent relid map: %w", err)
}

childToParentRelIDMap := make(map[uint32]uint32)
idToRelKindMap := make(map[uint32]byte)
var parentRelID, childRelID pgtype.Uint32
if _, err := pgx.ForEachRow(rows, []any{&parentRelID, &childRelID}, func() error {
var relkind byte
if _, err := pgx.ForEachRow(rows, []any{&parentRelID, &childRelID, &relkind}, func() error {
childToParentRelIDMap[childRelID.Uint32] = parentRelID.Uint32
idToRelKindMap[parentRelID.Uint32] = relkind
return nil
}); err != nil {
return nil, fmt.Errorf("error iterating over child to parent relid map: %w", err)
return nil, nil, fmt.Errorf("error iterating over child to parent relid map: %w", err)
}

return childToParentRelIDMap, nil
return childToParentRelIDMap, idToRelKindMap, nil
}

// replProcessor implements ingesting PostgreSQL logical replication tuples into items.
Expand Down Expand Up @@ -904,8 +921,10 @@ func processMessage[Items model.Items](
p.otelManager.Metrics.CommitLagGauge.Record(ctx, time.Now().UTC().Sub(msg.CommitTime).Microseconds())
p.commitLock = nil
case *pglogrepl.RelationMessage:
originalRelID := msg.RelationID
var parentRelKind byte
// treat all relation messages as corresponding to parent if partitioned.
msg.RelationID, err = p.checkIfUnknownTableInherits(ctx, msg.RelationID)
msg.RelationID, parentRelKind, err = p.checkIfUnknownTableInherits(ctx, msg.RelationID)
if err != nil {
return nil, err
}
Expand All @@ -914,14 +933,22 @@ func processMessage[Items model.Items](
return nil, nil
}

// With publish_via_partition_root = true, PG emits a parent RelationMessage
// followed by a child RelationMessage for each partition. The parent's
// column list matches the tuple data wire format, so skip the child's
// to avoid overwriting with a potentially reordered column definition.
if originalRelID != msg.RelationID && parentRelKind == 'p' && p.publishViaPartitionRoot {
return nil, nil
}

logger.Info("processing RelationMessage",
slog.Any("LSN", currentClientXlogPos),
slog.Uint64("RelationID", uint64(msg.RelationID)),
slog.String("Namespace", msg.Namespace),
slog.String("RelationName", msg.RelationName),
slog.Any("Columns", msg.Columns))

return processRelationMessage[Items](ctx, p, currentClientXlogPos, msg)
return processRelationMessage[Items](ctx, p, currentClientXlogPos, msg, originalRelID)
case *pglogrepl.LogicalDecodingMessage:
logger.Debug("LogicalDecodingMessage",
slog.Bool("Transactional", msg.Transactional),
Expand Down Expand Up @@ -952,22 +979,25 @@ func processInsertMessage[Items model.Items](
processor replProcessor[Items],
customTypeMapping map[uint32]shared.CustomDataType,
) (model.Record[Items], error) {
relID := p.getParentRelIDIfPartitioned(msg.RelationID)
parentRelID := p.getParentRelIDIfPartitioned(msg.RelationID)

tableName, exists := p.srcTableIDNameMapping[relID]
tableName, exists := p.srcTableIDNameMapping[parentRelID]
if !exists {
return nil, nil
}

// log lsn and relation id for debugging
p.logger.Debug("InsertMessage", slog.Any("LSN", lsn), slog.Uint64("RelationID", uint64(relID)), slog.String("Relation Name", tableName))
p.logger.Debug("InsertMessage",
slog.Any("LSN", lsn),
slog.Uint64("RelationID", uint64(msg.RelationID)),
slog.String("Relation Name", tableName))

rel, ok := p.relationMessageMapping[relID]
rel, ok := p.relationMessageMapping[msg.RelationID]
if !ok {
return nil, fmt.Errorf("unknown relation id %d for table %s", relID, tableName)
return nil, fmt.Errorf("unknown relation id %d for table %s", msg.RelationID, tableName)
}

schemaName, err := p.getSourceSchemaForDestinationColumn(relID, tableName)
schemaName, err := p.getSourceSchemaForDestinationColumn(parentRelID, tableName)
if err != nil {
return nil, err
}
Expand All @@ -994,22 +1024,25 @@ func processUpdateMessage[Items model.Items](
processor replProcessor[Items],
customTypeMapping map[uint32]shared.CustomDataType,
) (model.Record[Items], error) {
relID := p.getParentRelIDIfPartitioned(msg.RelationID)
parentRelID := p.getParentRelIDIfPartitioned(msg.RelationID)

tableName, exists := p.srcTableIDNameMapping[relID]
tableName, exists := p.srcTableIDNameMapping[parentRelID]
if !exists {
return nil, nil
}

// log lsn and relation id for debugging
p.logger.Debug("UpdateMessage", slog.Any("LSN", lsn), slog.Uint64("RelationID", uint64(relID)), slog.String("Relation Name", tableName))
p.logger.Debug("UpdateMessage",
slog.Any("LSN", lsn),
slog.Uint64("RelationID", uint64(msg.RelationID)),
slog.String("Relation Name", tableName))

rel, ok := p.relationMessageMapping[relID]
rel, ok := p.relationMessageMapping[msg.RelationID]
if !ok {
return nil, fmt.Errorf("unknown relation id %d for table %s", relID, tableName)
return nil, fmt.Errorf("unknown relation id %d for table %s", msg.RelationID, tableName)
}

schemaName, err := p.getSourceSchemaForDestinationColumn(relID, tableName)
schemaName, err := p.getSourceSchemaForDestinationColumn(parentRelID, tableName)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1057,22 +1090,25 @@ func processDeleteMessage[Items model.Items](
processor replProcessor[Items],
customTypeMapping map[uint32]shared.CustomDataType,
) (model.Record[Items], error) {
relID := p.getParentRelIDIfPartitioned(msg.RelationID)
parentRelID := p.getParentRelIDIfPartitioned(msg.RelationID)

tableName, exists := p.srcTableIDNameMapping[relID]
tableName, exists := p.srcTableIDNameMapping[parentRelID]
if !exists {
return nil, nil
}

// log lsn and relation id for debugging
p.logger.Debug("DeleteMessage", slog.Any("LSN", lsn), slog.Uint64("RelationID", uint64(relID)), slog.String("Relation Name", tableName))
p.logger.Debug("DeleteMessage",
slog.Any("LSN", lsn),
slog.Uint64("RelationID", uint64(msg.RelationID)),
slog.String("Relation Name", tableName))

rel, ok := p.relationMessageMapping[relID]
rel, ok := p.relationMessageMapping[msg.RelationID]
if !ok {
return nil, fmt.Errorf("unknown relation id %d for table %s", relID, tableName)
return nil, fmt.Errorf("unknown relation id %d for table %s", msg.RelationID, tableName)
}

schemaName, err := p.getSourceSchemaForDestinationColumn(relID, tableName)
schemaName, err := p.getSourceSchemaForDestinationColumn(parentRelID, tableName)
if err != nil {
return nil, err
}
Expand All @@ -1097,6 +1133,7 @@ func processRelationMessage[Items model.Items](
p *PostgresCDCSource,
lsn pglogrepl.LSN,
currRel *pglogrepl.RelationMessage,
originalRelID uint32,
) (model.Record[Items], error) {
// not present in tables to sync, return immediately
currRelName, ok := p.srcTableIDNameMapping[currRel.RelationID]
Expand Down Expand Up @@ -1249,7 +1286,7 @@ func processRelationMessage[Items model.Items](
}
}

p.relationMessageMapping[currRel.RelationID] = currRel
p.relationMessageMapping[originalRelID] = currRel
// only log audit if there is actionable delta
if len(schemaDelta.AddedColumns) > 0 {
return &model.RelationRecord[Items]{
Expand Down Expand Up @@ -1286,7 +1323,7 @@ func (p *PostgresCDCSource) getParentRelIDIfPartitioned(relID uint32) uint32 {
// filtered by relkind; parent needs to be a partitioned table by default
func (p *PostgresCDCSource) checkIfUnknownTableInherits(ctx context.Context,
relID uint32,
) (uint32, error) {
) (uint32, byte, error) {
relID = p.getParentRelIDIfPartitioned(relID)
relkinds := "'p'"
if p.handleInheritanceForNonPartitionedTables {
Expand All @@ -1303,18 +1340,18 @@ func (p *PostgresCDCSource) checkIfUnknownTableInherits(ctx context.Context,
relID,
).Scan(&parentRelID); err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return relID, nil
return relID, 0, nil
}
return 0, fmt.Errorf("failed to query pg_inherits: %w", err)
return 0, 0, fmt.Errorf("failed to query pg_inherits: %w", err)
}
p.childToParentRelIDMapping[relID] = parentRelID
p.hushWarnUnknownTableDetected[relID] = struct{}{}
p.logger.Info("Detected new child table in CDC stream, remapping to parent table",
slog.Uint64("childRelID", uint64(relID)),
slog.Uint64("parentRelID", uint64(parentRelID)),
slog.String("parentTableName", p.srcTableIDNameMapping[parentRelID]))
return parentRelID, nil
return parentRelID, p.idToRelKindMap[parentRelID], nil
}

return relID, nil
return relID, p.idToRelKindMap[relID], nil
}
Loading
Loading