From 7171c34bef6afef22ff9ad58d52151847cfd8114 Mon Sep 17 00:00:00 2001 From: Ibrar Ahmed Date: Thu, 15 Jan 2026 19:59:52 +0500 Subject: [PATCH] Add preserve-origin feature for table repair operations Implement preserve-origin feature that maintains original replication origin node ID and LSN when repairing rows during recovery scenarios. This prevents replication conflicts when the origin node returns to the cluster. --- db/queries/queries.go | 96 +++ db/queries/templates.go | 24 + docs/commands/repair/table-repair.md | 40 ++ internal/api/http/handler.go | 5 + internal/cli/cli.go | 6 + internal/consistency/repair/table_repair.go | 756 +++++++++++++++++--- 6 files changed, 817 insertions(+), 110 deletions(-) diff --git a/db/queries/queries.go b/db/queries/queries.go index 519e4fa..244ce54 100644 --- a/db/queries/queries.go +++ b/db/queries/queries.go @@ -2925,3 +2925,99 @@ func RemoveTableFromCDCMetadata(ctx context.Context, db DBQuerier, tableName, pu return nil } + +func GetReplicationOriginByName(ctx context.Context, db DBQuerier, originName string) (*uint32, error) { + sql, err := RenderSQL(SQLTemplates.GetReplicationOriginByName, nil) + if err != nil { + return nil, fmt.Errorf("failed to render GetReplicationOriginByName SQL: %w", err) + } + + var originID uint32 + err = db.QueryRow(ctx, sql, originName).Scan(&originID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, nil + } + return nil, fmt.Errorf("query to get replication origin by name '%s' failed: %w", originName, err) + } + + return &originID, nil +} + +func CreateReplicationOrigin(ctx context.Context, db DBQuerier, originName string) (uint32, error) { + sql, err := RenderSQL(SQLTemplates.CreateReplicationOrigin, nil) + if err != nil { + return 0, fmt.Errorf("failed to render CreateReplicationOrigin SQL: %w", err) + } + + var originID uint32 + err = db.QueryRow(ctx, sql, originName).Scan(&originID) + if err != nil { + return 0, fmt.Errorf("query to create replication origin '%s' 
failed: %w", originName, err) + } + + return originID, nil +} + +func SetupReplicationOriginSession(ctx context.Context, db DBQuerier, originName string) error { + sql, err := RenderSQL(SQLTemplates.SetupReplicationOriginSession, nil) + if err != nil { + return fmt.Errorf("failed to render SetupReplicationOriginSession SQL: %w", err) + } + + _, err = db.Exec(ctx, sql, originName) + if err != nil { + return fmt.Errorf("query to setup replication origin session for origin '%s' failed: %w", originName, err) + } + + return nil +} + +func ResetReplicationOriginSession(ctx context.Context, db DBQuerier) error { + sql, err := RenderSQL(SQLTemplates.ResetReplicationOriginSession, nil) + if err != nil { + return fmt.Errorf("failed to render ResetReplicationOriginSession SQL: %w", err) + } + + _, err = db.Exec(ctx, sql) + if err != nil { + return fmt.Errorf("query to reset replication origin session failed: %w", err) + } + + return nil +} + +func SetupReplicationOriginXact(ctx context.Context, db DBQuerier, originLSN string, originTimestamp *time.Time) error { + sql, err := RenderSQL(SQLTemplates.SetupReplicationOriginXact, nil) + if err != nil { + return fmt.Errorf("failed to render SetupReplicationOriginXact SQL: %w", err) + } + + var timestampParam any + if originTimestamp != nil { + timestampParam = originTimestamp.Format(time.RFC3339) + } else { + timestampParam = nil + } + + _, err = db.Exec(ctx, sql, originLSN, timestampParam) + if err != nil { + return fmt.Errorf("query to setup replication origin xact with LSN %s failed: %w", originLSN, err) + } + + return nil +} + +func ResetReplicationOriginXact(ctx context.Context, db DBQuerier) error { + sql, err := RenderSQL(SQLTemplates.ResetReplicationOriginXact, nil) + if err != nil { + return fmt.Errorf("failed to render ResetReplicationOriginXact SQL: %w", err) + } + + _, err = db.Exec(ctx, sql) + if err != nil { + return fmt.Errorf("query to reset replication origin xact failed: %w", err) + } + + return nil +} diff --git 
a/db/queries/templates.go b/db/queries/templates.go index b93bd23..4298ccd 100644 --- a/db/queries/templates.go +++ b/db/queries/templates.go @@ -120,6 +120,12 @@ type Templates struct { RemoveTableFromCDCMetadata *template.Template GetSpockOriginLSNForNode *template.Template GetSpockSlotLSNForNode *template.Template + GetReplicationOriginByName *template.Template + CreateReplicationOrigin *template.Template + SetupReplicationOriginSession *template.Template + ResetReplicationOriginSession *template.Template + SetupReplicationOriginXact *template.Template + ResetReplicationOriginXact *template.Template } var SQLTemplates = Templates{ @@ -1543,4 +1549,22 @@ var SQLTemplates = Templates{ ORDER BY rs.confirmed_flush_lsn DESC LIMIT 1 `)), + GetReplicationOriginByName: template.Must(template.New("getReplicationOriginByName").Parse(` + SELECT roident FROM pg_replication_origin WHERE roname = $1 + `)), + CreateReplicationOrigin: template.Must(template.New("createReplicationOrigin").Parse(` + SELECT pg_replication_origin_create($1) + `)), + SetupReplicationOriginSession: template.Must(template.New("setupReplicationOriginSession").Parse(` + SELECT pg_replication_origin_session_setup($1) + `)), + ResetReplicationOriginSession: template.Must(template.New("resetReplicationOriginSession").Parse(` + SELECT pg_replication_origin_session_reset() + `)), + SetupReplicationOriginXact: template.Must(template.New("setupReplicationOriginXact").Parse(` + SELECT pg_replication_origin_xact_setup($1, $2) + `)), + ResetReplicationOriginXact: template.Must(template.New("resetReplicationOriginXact").Parse(` + SELECT pg_replication_origin_xact_reset() + `)), } diff --git a/docs/commands/repair/table-repair.md b/docs/commands/repair/table-repair.md index eec49a6..64de026 100644 --- a/docs/commands/repair/table-repair.md +++ b/docs/commands/repair/table-repair.md @@ -30,6 +30,7 @@ Performs repairs on tables of divergent nodes based on the diff report generated | `--bidirectional` | `-Z` | Perform 
insert-only repairs in both directions | `false` | | `--fire-triggers` | `-t` | Execute triggers (otherwise runs with `session_replication_role='replica'`) | `false` | | `--recovery-mode` | | Enable recovery-mode repair when the diff was generated with `--against-origin`; can auto-select a source of truth using Spock LSNs | `false` | +| `--preserve-origin` | | Preserve replication origin node ID and LSN for repaired rows. When enabled, repaired rows will have commits with the original node's origin ID instead of the local node ID. Requires LSN to be available from a survivor node. | `true` | | `--quiet` | `-q` | Suppress non-essential logging | `false` | | `--debug` | `-v` | Enable verbose logging | `false` | @@ -69,3 +70,42 @@ Diff reports share the same prefix generated by `table-diff` (for example `publi ## Fixing null-only drifts (`--fix-nulls`) Replication hiccups can leave some columns NULL on one node while populated on another. The `--fix-nulls` mode cross-fills those NULLs in both directions using values from the paired node(s); it does **not** require a source-of-truth. Use it when the diff shows only NULL/NOT NULL mismatches and you want to reconcile columns without preferring a single node. + +## Preserving replication origin (`--preserve-origin`) + +By default, `--preserve-origin` is enabled. When repairing rows, this ensures that the repaired rows maintain the correct replication origin node ID and LSN from the original transaction, rather than using the local node's ID. This is particularly important in recovery scenarios where: + +- A node fails and rows are repaired from a survivor +- The failed node may come back online +- Without origin tracking, the repaired rows would have the local node's origin ID, which could cause conflicts when the original node resumes replication + +### How it works + +1. **Origin extraction**: ACE extracts the `node_origin` and `commit_ts` from the diff file metadata for each row being repaired. + +2. 
**LSN retrieval**: For each origin node, ACE queries a survivor node to obtain the origin LSN. This LSN must be available - if it's not, the repair will fail (as required for data consistency). + +3. **Replication origin session**: Before executing repairs for each origin group, ACE: + - Gets or creates a replication origin for the origin node + - Sets up a replication origin session + - Configures the session with the origin LSN and timestamp + - Executes the repairs + - Resets the session + +4. **Grouping**: Rows are automatically grouped by origin node to minimize session setup overhead. + +### Requirements and limitations + +- **LSN availability**: The origin LSN must be available from at least one survivor node. If not available, the repair will fail with an error. +- **Survivor nodes**: At least one survivor node must be accessible to fetch the origin LSN. +- **Privileges**: Replication origin functions require superuser or replication privileges on the target database. +- **Missing metadata**: If origin metadata is missing from the diff file for some rows, those rows will be repaired without origin tracking (a warning will be logged). + +### When to disable + +You may want to disable `--preserve-origin` with `--preserve-origin=false` if: +- You're certain the origin node will not come back online +- You've permanently removed the origin node from the cluster +- You want repaired rows to be treated as local writes + +**Note**: Disabling origin preservation should only be done when you're certain about the node's status, as it can cause replication conflicts if the origin node returns. 
diff --git a/internal/api/http/handler.go b/internal/api/http/handler.go index 3f176cb..ef5dc96 100644 --- a/internal/api/http/handler.go +++ b/internal/api/http/handler.go @@ -55,6 +55,7 @@ type tableRepairRequest struct { GenerateReport bool `json:"generate_report"` FixNulls bool `json:"fix_nulls"` Bidirectional bool `json:"bidirectional"` + PreserveOrigin *bool `json:"preserve_origin,omitempty"` } type spockDiffRequest struct { @@ -434,6 +435,10 @@ func (s *APIServer) handleTableRepair(w http.ResponseWriter, r *http.Request) { task.GenerateReport = req.GenerateReport task.FixNulls = req.FixNulls task.Bidirectional = req.Bidirectional + // PreserveOrigin defaults to true if not explicitly set + if req.PreserveOrigin != nil { + task.PreserveOrigin = *req.PreserveOrigin + } task.Ctx = r.Context() task.ClientRole = clientInfo.role task.InvokeMethod = "api" diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 3c0b214..6bd8bf1 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -223,6 +223,11 @@ func SetupCLI() *cli.App { Usage: "Enable recovery-mode repair using origin-only diffs", Value: false, }, + &cli.BoolFlag{ + Name: "preserve-origin", + Usage: "Preserve replication origin node ID and LSN for repaired rows (default: true)", + Value: true, + }, &cli.BoolFlag{ Name: "fix-nulls", Aliases: []string{"X"}, @@ -1199,6 +1204,7 @@ func TableRepairCLI(ctx *cli.Context) error { task.Bidirectional = ctx.Bool("bidirectional") task.GenerateReport = ctx.Bool("generate-report") task.RecoveryMode = ctx.Bool("recovery-mode") + task.PreserveOrigin = ctx.Bool("preserve-origin") if err := task.ValidateAndPrepare(); err != nil { return fmt.Errorf("validation failed: %w", err) diff --git a/internal/consistency/repair/table_repair.go b/internal/consistency/repair/table_repair.go index b49d9ed..d375f39 100644 --- a/internal/consistency/repair/table_repair.go +++ b/internal/consistency/repair/table_repair.go @@ -70,6 +70,7 @@ type TableRepairTask struct { FixNulls bool 
// TBD Bidirectional bool RecoveryMode bool + PreserveOrigin bool InvokeMethod string // TBD ClientRole string // TBD @@ -113,8 +114,9 @@ func NewTableRepairTask() *TableRepairTask { TaskType: taskstore.TaskTypeTableRepair, TaskStatus: taskstore.StatusPending, }, - InvokeMethod: "cli", - Pools: make(map[string]*pgxpool.Pool), + InvokeMethod: "cli", + PreserveOrigin: true, + Pools: make(map[string]*pgxpool.Pool), DerivedFields: types.DerivedFields{ HostMap: make(map[string]string), }, @@ -711,9 +713,10 @@ type rowData struct { } type nullUpdate struct { - pkValues []any - pkMap map[string]any - columns map[string]any + pkValues []any + pkMap map[string]any + columns map[string]any + sourceRow types.OrderedMap // Source row providing the non-null value (for origin tracking) } func (t *TableRepairTask) runFixNulls(startTime time.Time) error { @@ -851,7 +854,27 @@ func (t *TableRepairTask) runFixNulls(startTime time.Time) error { continue } - updatedCount, err := t.applyFixNullsUpdates(tx, col, colType, rowsForCol, colTypes) + // Extract origin information for fix-nulls updates if preserve-origin is enabled + var originInfoMap map[string]*rowOriginInfo + if t.PreserveOrigin { + originInfoMap = make(map[string]*rowOriginInfo) + for _, nu := range rowsForCol { + if nu.sourceRow != nil { + pkeyStr, err := utils.StringifyKey(nu.pkMap, t.Key) + if err != nil { + // Try alternative method + pkeyStr, err = utils.StringifyOrderedMapKey(nu.sourceRow, t.Key) + } + if err == nil { + if originInfo := extractOriginInfoFromRow(nu.sourceRow); originInfo != nil { + originInfoMap[pkeyStr] = originInfo + } + } + } + } + } + + updatedCount, err := t.applyFixNullsUpdates(tx, col, colType, rowsForCol, colTypes, originInfoMap, nodeName) if err != nil { nodeFailed = true tx.Rollback(t.Ctx) @@ -977,9 +1000,27 @@ func (t *TableRepairTask) buildNullUpdates() (map[string]map[string]*nullUpdate, val2 := row2.data[col] if val1 == nil && val2 != nil { - addNullUpdate(updatesByNode, node1Name, row1, 
col, val2) + // Update node1 with value from node2 - find the source row from node2 + var sourceRow types.OrderedMap + for _, r := range node2Rows { + pkeyStr, err := utils.StringifyOrderedMapKey(r, t.Key) + if err == nil && pkeyStr == pkKey { + sourceRow = r + break + } + } + addNullUpdate(updatesByNode, node1Name, row1, col, val2, sourceRow) } else if val2 == nil && val1 != nil { - addNullUpdate(updatesByNode, node2Name, row2, col, val1) + // Update node2 with value from node1 - find the source row from node1 + var sourceRow types.OrderedMap + for _, r := range node1Rows { + pkeyStr, err := utils.StringifyOrderedMapKey(r, t.Key) + if err == nil && pkeyStr == pkKey { + sourceRow = r + break + } + } + addNullUpdate(updatesByNode, node2Name, row2, col, val1, sourceRow) } } } @@ -1019,7 +1060,7 @@ func buildRowIndex(rows []types.OrderedMap, keyCols []string) (map[string]rowDat return index, nil } -func addNullUpdate(updates map[string]map[string]*nullUpdate, nodeName string, row rowData, col string, value any) { +func addNullUpdate(updates map[string]map[string]*nullUpdate, nodeName string, row rowData, col string, value any, sourceRow types.OrderedMap) { if value == nil { return } @@ -1032,9 +1073,10 @@ func addNullUpdate(updates map[string]map[string]*nullUpdate, nodeName string, r nu, ok := nodeUpdates[row.pkKey] if !ok { nu = &nullUpdate{ - pkValues: row.pkValues, - pkMap: row.pkMap, - columns: make(map[string]any), + pkValues: row.pkValues, + pkMap: row.pkMap, + columns: make(map[string]any), + sourceRow: sourceRow, } nodeUpdates[row.pkKey] = nu } @@ -1042,6 +1084,11 @@ func addNullUpdate(updates map[string]map[string]*nullUpdate, nodeName string, r if _, exists := nu.columns[col]; !exists { nu.columns[col] = value } + + // Update source row if not set or if this is a newer source + if nu.sourceRow == nil || len(sourceRow) > 0 { + nu.sourceRow = sourceRow + } } func (t *TableRepairTask) getFixNullsDryRunOutput(updates map[string]map[string]*nullUpdate) (string, 
error) { @@ -1122,30 +1169,143 @@ func (t *TableRepairTask) populateFixNullsReport(nodeName string, nodeUpdates ma t.report.Changes[nodeName].(map[string]any)[field] = rows } -func (t *TableRepairTask) applyFixNullsUpdates(tx pgx.Tx, column string, columnType string, updates []*nullUpdate, colTypes map[string]string) (int, error) { +func (t *TableRepairTask) applyFixNullsUpdates(tx pgx.Tx, column string, columnType string, updates []*nullUpdate, colTypes map[string]string, originInfoMap map[string]*rowOriginInfo, nodeName string) (int, error) { if len(updates) == 0 { return 0, nil } + // Group updates by origin if preserve-origin is enabled + var originGroups map[string][]*nullUpdate + if t.PreserveOrigin && originInfoMap != nil && len(originInfoMap) > 0 { + originGroups = make(map[string][]*nullUpdate) + for _, nu := range updates { + pkeyStr, err := utils.StringifyKey(nu.pkMap, t.Key) + if err != nil { + // Try alternative method if available + if nu.sourceRow != nil { + pkeyStr, err = utils.StringifyOrderedMapKey(nu.sourceRow, t.Key) + } + } + if err != nil { + continue + } + originInfo, hasOrigin := originInfoMap[pkeyStr] + originNode := "" + if hasOrigin && originInfo != nil && originInfo.nodeOrigin != "" { + originNode = originInfo.nodeOrigin + } + originGroups[originNode] = append(originGroups[originNode], nu) + } + } else { + // No origin tracking - process all updates together + originGroups = map[string][]*nullUpdate{ + "": updates, + } + } + totalUpdated := 0 - batchSize := 500 - for i := 0; i < len(updates); i += batchSize { - end := i + batchSize - if end > len(updates) { - end = len(updates) + + // Process each origin group separately + for originNode, originUpdates := range originGroups { + if len(originUpdates) == 0 { + continue } - batch := updates[i:end] - updateSQL, args, err := t.buildFixNullsBatchSQL(column, columnType, batch, colTypes) - if err != nil { - return totalUpdated, err + // Set up replication origin session if we have origin info 
and preserve-origin is enabled + if t.PreserveOrigin && originNode != "" { + // Get origin info for this group + var groupOriginInfo *rowOriginInfo + for _, nu := range originUpdates { + pkeyStr, err := utils.StringifyKey(nu.pkMap, t.Key) + if err != nil { + // Try alternative method if available + if nu.sourceRow != nil { + pkeyStr, err = utils.StringifyOrderedMapKey(nu.sourceRow, t.Key) + } + } + if err == nil { + if info, ok := originInfoMap[pkeyStr]; ok && info != nil { + groupOriginInfo = info + break + } + } + } + + if groupOriginInfo != nil && groupOriginInfo.nodeOrigin != "" { + // Get LSN from a survivor node if not already set + if groupOriginInfo.lsn == nil { + var survivorNode string + for poolNode := range t.Pools { + if poolNode != groupOriginInfo.nodeOrigin && poolNode != nodeName { + survivorNode = poolNode + break + } + } + if survivorNode == "" && t.SourceOfTruth != "" && t.SourceOfTruth != groupOriginInfo.nodeOrigin { + survivorNode = t.SourceOfTruth + } + + if survivorNode != "" { + lsn, err := t.getOriginLSNForNode(groupOriginInfo.nodeOrigin, survivorNode) + if err != nil { + return totalUpdated, fmt.Errorf("failed to get origin LSN for node %s: %w", groupOriginInfo.nodeOrigin, err) + } + groupOriginInfo.lsn = lsn + } else { + return totalUpdated, fmt.Errorf("no survivor node available to fetch LSN for origin node %s", groupOriginInfo.nodeOrigin) + } + } + + // Step 1: Setup session + _, err := t.setupReplicationOriginSession(tx, groupOriginInfo.nodeOrigin) + if err != nil { + return totalUpdated, fmt.Errorf("failed to setup replication origin session for node %s: %w", groupOriginInfo.nodeOrigin, err) + } + + // Step 2: Setup xact with LSN and timestamp (BEFORE DML) + if err := t.setupReplicationOriginXact(tx, groupOriginInfo.lsn, groupOriginInfo.commitTS); err != nil { + t.resetReplicationOriginSession(tx) // Cleanup on error + return totalUpdated, fmt.Errorf("failed to setup replication origin xact for node %s: %w", 
groupOriginInfo.nodeOrigin, err) + } + } } - tag, err := tx.Exec(t.Ctx, updateSQL, args...) - if err != nil { - return totalUpdated, fmt.Errorf("error executing fix-nulls batch for column %s: %w", column, err) + // Process batches for this origin group + batchSize := 500 + for i := 0; i < len(originUpdates); i += batchSize { + end := i + batchSize + if end > len(originUpdates) { + end = len(originUpdates) + } + batch := originUpdates[i:end] + + updateSQL, args, err := t.buildFixNullsBatchSQL(column, columnType, batch, colTypes) + if err != nil { + if t.PreserveOrigin && originNode != "" { + t.resetReplicationOriginXact(tx) + t.resetReplicationOriginSession(tx) + } + return totalUpdated, err + } + + tag, err := tx.Exec(t.Ctx, updateSQL, args...) + if err != nil { + if t.PreserveOrigin && originNode != "" { + t.resetReplicationOriginXact(tx) + t.resetReplicationOriginSession(tx) + } + return totalUpdated, fmt.Errorf("error executing fix-nulls batch for column %s: %w", column, err) + } + totalUpdated += int(tag.RowsAffected()) + } + + // Step 3: Reset xact (BEFORE commit, within transaction) + if t.PreserveOrigin && originNode != "" { + if err := t.resetReplicationOriginXact(tx); err != nil { + return totalUpdated, fmt.Errorf("failed to reset replication origin xact: %w", err) + } } - totalUpdated += int(tag.RowsAffected()) + // Note: Session reset happens AFTER commit in the calling function } return totalUpdated, nil @@ -1387,7 +1547,52 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { continue } - upsertedCount, err := executeUpserts(tx, t, nodeName, nodeUpserts, targetNodeColTypes) + // Extract origin information from source rows if preserve-origin is enabled + var originInfoMap map[string]*rowOriginInfo + if t.PreserveOrigin { + originInfoMap = make(map[string]*rowOriginInfo) + // Extract origin info from all source rows in the diff + // For repair plans, we'll extract from both nodes and use the appropriate one + // For 
source-of-truth repairs, we extract from the source of truth + for nodePair, diffs := range t.RawDiffs.NodeDiffs { + nodes := strings.Split(nodePair, "/") + if len(nodes) != 2 { + continue + } + node1Name, node2Name := nodes[0], nodes[1] + + // Extract from both nodes - we'll use the one that matches the source + for _, sourceNode := range []string{node1Name, node2Name} { + sourceRows := diffs.Rows[sourceNode] + for _, row := range sourceRows { + pkeyStr, err := utils.StringifyOrderedMapKey(row, t.Key) + if err != nil { + continue + } + // Only add if this row is being upserted to the target node + if _, isBeingUpserted := nodeUpserts[pkeyStr]; isBeingUpserted { + if originInfo := extractOriginInfoFromRow(row); originInfo != nil { + // For repair plans, prefer the source node that's providing the data + // For source-of-truth, prefer the source of truth + if t.RepairPlan == nil { + // Source-of-truth: only use if it's the source of truth + if sourceNode == t.SourceOfTruth { + originInfoMap[pkeyStr] = originInfo + } + } else { + // Repair plan: use the first one we find (will be overridden if needed) + if _, exists := originInfoMap[pkeyStr]; !exists { + originInfoMap[pkeyStr] = originInfo + } + } + } + } + } + } + } + } + + upsertedCount, err := executeUpserts(tx, t, nodeName, nodeUpserts, targetNodeColTypes, originInfoMap) if err != nil { tx.Rollback(t.Ctx) logger.Error("executing upserts on node %s: %v", nodeName, err) @@ -1432,6 +1637,14 @@ func (t *TableRepairTask) runUnidirectionalRepair(startTime time.Time) error { continue } logger.Debug("Transaction committed successfully on %s", nodeName) + + // Reset replication origin session AFTER commit (session-level, not transaction-level) + if t.PreserveOrigin { + if err := t.resetReplicationOriginSessionOnConnection(divergentPool); err != nil { + logger.Warn("failed to reset replication origin session on %s after commit: %v", nodeName, err) + // Don't fail the repair for this - it's cleanup + } + } } t.FinishedAt = 
time.Now() @@ -1690,7 +1903,42 @@ func (t *TableRepairTask) performBirectionalInserts(nodeName string, inserts map // Bidirectional is always insert only originalInsertOnly := t.InsertOnly t.InsertOnly = true - insertedCount, err := executeUpserts(tx, t, nodeName, inserts, targetNodeColTypes) + // Extract origin information from source rows for bidirectional repair + var originInfoMap map[string]*rowOriginInfo + if t.PreserveOrigin { + originInfoMap = make(map[string]*rowOriginInfo) + // For bidirectional, origin is the node providing the data + // We need to find which node pair this insert came from + for nodePairKey, diffs := range t.RawDiffs.NodeDiffs { + nodes := strings.Split(nodePairKey, "/") + if len(nodes) != 2 { + continue + } + var sourceNode string + if nodes[0] == nodeName { + sourceNode = nodes[1] // Data coming from the other node + } else if nodes[1] == nodeName { + sourceNode = nodes[0] // Data coming from the other node + } else { + continue + } + + sourceRows := diffs.Rows[sourceNode] + for _, row := range sourceRows { + pkeyStr, err := utils.StringifyOrderedMapKey(row, t.Key) + if err != nil { + continue + } + if _, exists := inserts[pkeyStr]; exists { + if originInfo := extractOriginInfoFromRow(row); originInfo != nil { + originInfoMap[pkeyStr] = originInfo + } + } + } + } + } + + insertedCount, err := executeUpserts(tx, t, nodeName, inserts, targetNodeColTypes, originInfoMap) t.InsertOnly = originalInsertOnly if err != nil { @@ -1818,122 +2066,301 @@ func executeDeletes(ctx context.Context, tx pgx.Tx, task *TableRepairTask, nodeN return totalDeletedCount, nil } +// rowOriginInfo holds origin metadata for a row +type rowOriginInfo struct { + nodeOrigin string + commitTS *time.Time + lsn *uint64 +} + +// extractOriginInfoFromRow extracts origin information from a row's metadata. +// Returns nil if no origin information is available. 
+func extractOriginInfoFromRow(row types.OrderedMap) *rowOriginInfo { + rowMap := utils.OrderedMapToMap(row) + + // Check for metadata in _spock_metadata_ field + var meta map[string]any + if rawMeta, ok := rowMap["_spock_metadata_"].(map[string]any); ok { + meta = rawMeta + } else { + meta = make(map[string]any) + } + + // Also check for direct fields (for backward compatibility) + if val, ok := rowMap["node_origin"]; ok { + meta["node_origin"] = val + } + if val, ok := rowMap["commit_ts"]; ok { + meta["commit_ts"] = val + } + + var nodeOrigin string + var commitTS *time.Time + + if originVal, ok := meta["node_origin"]; ok && originVal != nil { + originStr := strings.TrimSpace(fmt.Sprintf("%v", originVal)) + if originStr != "" && originStr != "0" && originStr != "local" { + nodeOrigin = originStr + } + } + + if tsVal, ok := meta["commit_ts"]; ok && tsVal != nil { + var ts time.Time + var err error + switch v := tsVal.(type) { + case time.Time: + ts = v + case string: + ts, err = time.Parse(time.RFC3339, v) + if err != nil { + // Try other formats + ts, err = time.Parse("2006-01-02 15:04:05.999999-07", v) + } + } + if err == nil { + commitTS = &ts + } + } + + if nodeOrigin == "" { + return nil + } + + return &rowOriginInfo{ + nodeOrigin: nodeOrigin, + commitTS: commitTS, + } +} + // executeUpserts handles upserting rows in batches. -func executeUpserts(tx pgx.Tx, task *TableRepairTask, nodeName string, upserts map[string]map[string]any, colTypes map[string]string) (int, error) { +// originInfoMap maps primary key strings to their origin information. +// If originInfoMap is nil or empty, origin tracking is skipped. 
+func executeUpserts(tx pgx.Tx, task *TableRepairTask, nodeName string, upserts map[string]map[string]any, colTypes map[string]string, originInfoMap map[string]*rowOriginInfo) (int, error) { if err := task.filterStaleRepairs(task.Ctx, tx, nodeName, upserts, colTypes, "upsert"); err != nil { return 0, err } - rowsToUpsert := make([][]any, 0, len(upserts)) + // Group rows by origin if preserve-origin is enabled and we have origin info + var originGroups map[string]map[string]map[string]any // origin -> pkey -> row + rowsWithoutOrigin := 0 + if task.PreserveOrigin && originInfoMap != nil && len(originInfoMap) > 0 { + originGroups = make(map[string]map[string]map[string]any) + for pkey, row := range upserts { + originInfo, hasOrigin := originInfoMap[pkey] + if hasOrigin && originInfo != nil && originInfo.nodeOrigin != "" { + originNode := originInfo.nodeOrigin + if originGroups[originNode] == nil { + originGroups[originNode] = make(map[string]map[string]any) + } + originGroups[originNode][pkey] = row + } else { + // Rows without origin info go into a special group + rowsWithoutOrigin++ + if originGroups[""] == nil { + originGroups[""] = make(map[string]map[string]any) + } + originGroups[""][pkey] = row + } + } + if rowsWithoutOrigin > 0 { + logger.Warn("preserve-origin enabled but %d rows missing origin metadata - these will be repaired without origin tracking", rowsWithoutOrigin) + } + } else { + // No origin tracking - process all rows together + originGroups = map[string]map[string]map[string]any{ + "": upserts, + } + } + + totalUpsertedCount := 0 orderedCols := task.Cols - for _, rowMap := range upserts { - typedRow := make([]any, len(orderedCols)) - for i, colName := range orderedCols { - val, valExists := rowMap[colName] - pgType, typeExists := colTypes[colName] + // Process each origin group separately + for originNode, originUpserts := range originGroups { + if len(originUpserts) == 0 { + continue + } - if !valExists { - typedRow[i] = nil - continue - } - if 
!typeExists { - return 0, fmt.Errorf("type for column %s not found in target node's colTypes", colName) + // Set up replication origin session if we have origin info and preserve-origin is enabled + if task.PreserveOrigin && originNode != "" { + // Get origin info for this group (all rows in group should have same origin) + var groupOriginInfo *rowOriginInfo + for pkey := range originUpserts { + if info, ok := originInfoMap[pkey]; ok && info != nil { + groupOriginInfo = info + break + } } - convertedVal, err := utils.ConvertToPgxType(val, pgType) - if err != nil { - return 0, fmt.Errorf("error converting value for column %s (value: %v, type: %s): %w", colName, val, pgType, err) + if groupOriginInfo != nil && groupOriginInfo.nodeOrigin != "" { + // Get LSN from a survivor node if not already set + if groupOriginInfo.lsn == nil { + // Find a survivor node (any node that's not the origin and is in our pools) + var survivorNode string + for poolNode := range task.Pools { + if poolNode != groupOriginInfo.nodeOrigin && poolNode != nodeName { + survivorNode = poolNode + break + } + } + // If no other node, try using source of truth + if survivorNode == "" && task.SourceOfTruth != "" && task.SourceOfTruth != groupOriginInfo.nodeOrigin { + survivorNode = task.SourceOfTruth + } + + if survivorNode != "" { + lsn, err := task.getOriginLSNForNode(groupOriginInfo.nodeOrigin, survivorNode) + if err != nil { + return totalUpsertedCount, fmt.Errorf("failed to get origin LSN for node %s: %w", groupOriginInfo.nodeOrigin, err) + } + groupOriginInfo.lsn = lsn + } else { + return totalUpsertedCount, fmt.Errorf("no survivor node available to fetch LSN for origin node %s", groupOriginInfo.nodeOrigin) + } + } + + // Step 1: Setup session + _, err := task.setupReplicationOriginSession(tx, groupOriginInfo.nodeOrigin) + if err != nil { + return totalUpsertedCount, fmt.Errorf("failed to setup replication origin session for node %s: %w", groupOriginInfo.nodeOrigin, err) + } + + // Step 2: Setup 
xact with LSN and timestamp (BEFORE DML) + if err := task.setupReplicationOriginXact(tx, groupOriginInfo.lsn, groupOriginInfo.commitTS); err != nil { + task.resetReplicationOriginSession(tx) // Cleanup on error + return totalUpsertedCount, fmt.Errorf("failed to setup replication origin xact for node %s: %w", groupOriginInfo.nodeOrigin, err) + } } - typedRow[i] = convertedVal } - rowsToUpsert = append(rowsToUpsert, typedRow) - } - if len(rowsToUpsert) == 0 { - return 0, nil - } + // Convert rows to typed format + rowsToUpsert := make([][]any, 0, len(originUpserts)) + for _, rowMap := range originUpserts { + typedRow := make([]any, len(orderedCols)) + for i, colName := range orderedCols { + val, valExists := rowMap[colName] + pgType, typeExists := colTypes[colName] - totalUpsertedCount := 0 - // TODO: Make this configurable - batchSize := 1000 + if !valExists { + typedRow[i] = nil + continue + } + if !typeExists { + return totalUpsertedCount, fmt.Errorf("type for column %s not found in target node's colTypes", colName) + } - // For the max placeholders issue - if len(orderedCols) > 0 && batchSize*len(orderedCols) > 65500 { - batchSize = 65500 / len(orderedCols) - if batchSize == 0 { - batchSize = 1 + convertedVal, err := utils.ConvertToPgxType(val, pgType) + if err != nil { + return totalUpsertedCount, fmt.Errorf("error converting value for column %s (value: %v, type: %s): %w", colName, val, pgType, err) + } + typedRow[i] = convertedVal + } + rowsToUpsert = append(rowsToUpsert, typedRow) } - } - tableIdent := pgx.Identifier{task.Schema, task.Table}.Sanitize() - colIdents := make([]string, len(orderedCols)) - for i, col := range orderedCols { - colIdents[i] = pgx.Identifier{col}.Sanitize() - } - colsSQL := strings.Join(colIdents, ", ") + if len(rowsToUpsert) == 0 { + // Reset xact and session if we set them up + if task.PreserveOrigin && originNode != "" { + task.resetReplicationOriginXact(tx) + task.resetReplicationOriginSession(tx) + } + continue + } - pkColIdents 
:= make([]string, len(task.Key)) - for i, pkCol := range task.Key { - pkColIdents[i] = pgx.Identifier{pkCol}.Sanitize() - } - pkSQL := strings.Join(pkColIdents, ", ") + // Process batches for this origin group + // TODO: Make this configurable + batchSize := 1000 - for i := 0; i < len(rowsToUpsert); i += batchSize { - end := i + batchSize - if end > len(rowsToUpsert) { - end = len(rowsToUpsert) + // For the max placeholders issue + if len(orderedCols) > 0 && batchSize*len(orderedCols) > 65500 { + batchSize = 65500 / len(orderedCols) + if batchSize == 0 { + batchSize = 1 + } } - batchRows := rowsToUpsert[i:end] - var upsertSQL strings.Builder - args := []any{} - paramIdx := 1 + tableIdent := pgx.Identifier{task.Schema, task.Table}.Sanitize() + colIdents := make([]string, len(orderedCols)) + for i, col := range orderedCols { + colIdents[i] = pgx.Identifier{col}.Sanitize() + } + colsSQL := strings.Join(colIdents, ", ") - upsertSQL.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES ", tableIdent, colsSQL)) - for j, row := range batchRows { - if j > 0 { - upsertSQL.WriteString(", ") + pkColIdents := make([]string, len(task.Key)) + for i, pkCol := range task.Key { + pkColIdents[i] = pgx.Identifier{pkCol}.Sanitize() + } + pkSQL := strings.Join(pkColIdents, ", ") + + for i := 0; i < len(rowsToUpsert); i += batchSize { + end := i + batchSize + if end > len(rowsToUpsert) { + end = len(rowsToUpsert) } - upsertSQL.WriteString("(") - for k, val := range row { - if k > 0 { + batchRows := rowsToUpsert[i:end] + + var upsertSQL strings.Builder + args := []any{} + paramIdx := 1 + + upsertSQL.WriteString(fmt.Sprintf("INSERT INTO %s (%s) VALUES ", tableIdent, colsSQL)) + for j, row := range batchRows { + if j > 0 { upsertSQL.WriteString(", ") } - upsertSQL.WriteString(fmt.Sprintf("$%d", paramIdx)) - args = append(args, val) - paramIdx++ + upsertSQL.WriteString("(") + for k, val := range row { + if k > 0 { + upsertSQL.WriteString(", ") + } + upsertSQL.WriteString(fmt.Sprintf("$%d", 
paramIdx)) + args = append(args, val) + paramIdx++ + } + upsertSQL.WriteString(")") } - upsertSQL.WriteString(")") - } - upsertSQL.WriteString(fmt.Sprintf(" ON CONFLICT (%s) ", pkSQL)) - if task.InsertOnly { - upsertSQL.WriteString("DO NOTHING") - } else { - upsertSQL.WriteString("DO UPDATE SET ") - setClauses := make([]string, 0, len(orderedCols)) - for _, col := range orderedCols { - isPkCol := false - for _, pk := range task.Key { - if col == pk { - isPkCol = true - break + upsertSQL.WriteString(fmt.Sprintf(" ON CONFLICT (%s) ", pkSQL)) + if task.InsertOnly { + upsertSQL.WriteString("DO NOTHING") + } else { + upsertSQL.WriteString("DO UPDATE SET ") + setClauses := make([]string, 0, len(orderedCols)) + for _, col := range orderedCols { + isPkCol := false + for _, pk := range task.Key { + if col == pk { + isPkCol = true + break + } + } + if !isPkCol { + sanitisedCol := pgx.Identifier{col}.Sanitize() + setClauses = append(setClauses, fmt.Sprintf("%s = EXCLUDED.%s", sanitisedCol, sanitisedCol)) } } - if !isPkCol { - sanitisedCol := pgx.Identifier{col}.Sanitize() - setClauses = append(setClauses, fmt.Sprintf("%s = EXCLUDED.%s", sanitisedCol, sanitisedCol)) + upsertSQL.WriteString(strings.Join(setClauses, ", ")) + } + + cmdTag, err := tx.Exec(task.Ctx, upsertSQL.String(), args...) + if err != nil { + // Reset xact and session before returning error + if task.PreserveOrigin && originNode != "" { + task.resetReplicationOriginXact(tx) + task.resetReplicationOriginSession(tx) } + return totalUpsertedCount, fmt.Errorf("error executing upsert batch: %w (SQL: %s, Args: %v)", err, upsertSQL.String(), args) } - upsertSQL.WriteString(strings.Join(setClauses, ", ")) + totalUpsertedCount += int(cmdTag.RowsAffected()) } - cmdTag, err := tx.Exec(task.Ctx, upsertSQL.String(), args...) 
- if err != nil { - return totalUpsertedCount, fmt.Errorf("error executing upsert batch: %w (SQL: %s, Args: %v)", err, upsertSQL.String(), args) + // Step 3: Reset xact (BEFORE commit, within transaction) + if task.PreserveOrigin && originNode != "" { + if err := task.resetReplicationOriginXact(tx); err != nil { + return totalUpsertedCount, fmt.Errorf("failed to reset replication origin xact: %w", err) + } } - totalUpsertedCount += int(cmdTag.RowsAffected()) + // Note: Session reset happens AFTER commit in the calling function } return totalUpsertedCount, nil @@ -2292,3 +2719,112 @@ func (t *TableRepairTask) autoSelectSourceOfTruth(failedNode string, involved ma return best.node, lsnDetails, nil } + +// getOriginLSNForNode fetches the origin LSN for a given origin node from a survivor node. +// Returns error if LSN is not available (per requirement: require LSN). +func (t *TableRepairTask) getOriginLSNForNode(originNodeName, survivorNodeName string) (*uint64, error) { + survivorPool, ok := t.Pools[survivorNodeName] + if !ok || survivorPool == nil { + return nil, fmt.Errorf("no connection pool for survivor node %s", survivorNodeName) + } + + originLSN, _, err := t.fetchLSNsForNode(survivorPool, originNodeName, survivorNodeName) + if err != nil { + return nil, fmt.Errorf("failed to fetch origin LSN for node %s from survivor %s: %w", originNodeName, survivorNodeName, err) + } + + if originLSN == nil { + return nil, fmt.Errorf("origin LSN not available for node %s on survivor %s (required for preserve-origin)", originNodeName, survivorNodeName) + } + + return originLSN, nil +} + +// setupReplicationOriginSession sets up the replication origin for the session. +// This should be called before starting the transaction or at the very start. +// Returns the origin ID for use in xact setup. 
+func (t *TableRepairTask) setupReplicationOriginSession(tx pgx.Tx, originNodeName string) (uint32, error) {
+	// Replication origins are looked up and created under the "node_<name>" form.
+	name := fmt.Sprintf("node_%s", originNodeName)
+
+	// Look the origin up first; the query helper reports "not found" as a nil
+	// ID together with a nil error.
+	existing, lookupErr := queries.GetReplicationOriginByName(t.Ctx, tx, name)
+	if lookupErr != nil {
+		return 0, fmt.Errorf("failed to get replication origin by name '%s': %w", name, lookupErr)
+	}
+
+	var id uint32
+	if existing != nil {
+		id = *existing
+		logger.Debug("Found existing replication origin '%s' with ID %d", name, id)
+	} else {
+		// Nothing registered under this name yet: create it on the same tx.
+		created, createErr := queries.CreateReplicationOrigin(t.Ctx, tx, name)
+		if createErr != nil {
+			return 0, fmt.Errorf("failed to create replication origin '%s': %w", name, createErr)
+		}
+		id = created
+		logger.Debug("Created replication origin '%s' with ID %d", name, id)
+	}
+
+	// Bind the session to the origin by name.
+	// NOTE(review): this function never resets an origin that is already bound
+	// to the session; presumably the caller must reset between two setups on
+	// the same connection - confirm against the call sites.
+	setupErr := queries.SetupReplicationOriginSession(t.Ctx, tx, name)
+	if setupErr != nil {
+		return 0, fmt.Errorf("failed to setup replication origin session for '%s': %w", name, setupErr)
+	}
+
+	return id, nil
+}
+
+// setupReplicationOriginXact sets up the transaction-level LSN and timestamp.
+// This must be called within the transaction, before any DML operations.
+func (t *TableRepairTask) setupReplicationOriginXact(tx pgx.Tx, originLSN *uint64, originTimestamp *time.Time) error {
+	// The LSN is mandatory; the timestamp is optional and forwarded unchanged.
+	if originLSN == nil {
+		return fmt.Errorf("origin LSN is required for xact setup")
+	}
+
+	// The query layer takes the LSN in its textual form.
+	lsnText := pglogrepl.LSN(*originLSN).String()
+
+	// NOTE(review): a nil originTimestamp is passed through and reaches the SQL
+	// as NULL; confirm the SetupReplicationOriginXact template tolerates that.
+	setupErr := queries.SetupReplicationOriginXact(t.Ctx, tx, lsnText, originTimestamp)
+	if setupErr != nil {
+		return fmt.Errorf("failed to setup replication origin xact with LSN %s: %w", lsnText, setupErr)
+	}
+
+	logger.Debug("Set replication origin xact LSN to %s", lsnText)
+	if originTimestamp == nil {
+		return nil
+	}
+	logger.Debug("Set replication origin xact timestamp to %s", originTimestamp.Format(time.RFC3339))
+	return nil
+}
+
+// resetReplicationOriginXact clears the transaction-level replication origin
+// state by running queries.ResetReplicationOriginXact on tx. It is meant to be
+// called inside the transaction, before it commits.
+func (t *TableRepairTask) resetReplicationOriginXact(tx pgx.Tx) error {
+	resetErr := queries.ResetReplicationOriginXact(t.Ctx, tx)
+	if resetErr != nil {
+		return fmt.Errorf("failed to reset replication origin xact: %w", resetErr)
+	}
+
+	logger.Debug("Reset replication origin xact")
+	return nil
+}
+
+// resetReplicationOriginSessionOnConnection acquires a connection from pool,
+// runs the session-level replication origin reset on it, and releases the
+// connection again. It is meant to be used after commit, outside a transaction.
+//
+// NOTE(review): the connection comes from pool.Acquire, so it is presumably not
+// guaranteed to be the one that executed the repair transaction - confirm the
+// reset actually reaches the session whose origin was set up.
+func (t *TableRepairTask) resetReplicationOriginSessionOnConnection(pool *pgxpool.Pool) error {
+	pooled, acquireErr := pool.Acquire(t.Ctx)
+	if acquireErr != nil {
+		return fmt.Errorf("failed to acquire connection for session reset: %w", acquireErr)
+	}
+	// Return the connection to the pool on every exit path.
+	defer pooled.Release()
+
+	resetErr := queries.ResetReplicationOriginSession(t.Ctx, pooled.Conn())
+	if resetErr != nil {
+		return fmt.Errorf("failed to reset replication origin session: %w", resetErr)
+	}
+
+	logger.Debug("Reset replication origin session")
+	return nil
+}
+
+// resetReplicationOriginSession resets the session-level replication origin state (for error cleanup within transaction).
+func (t *TableRepairTask) resetReplicationOriginSession(tx pgx.Tx) error {
+	// Runs the session-level reset through the transaction handle, so it
+	// executes on the same connection as the surrounding repair work.
+	// NOTE(review): the cleanup call sites visible in this patch discard the
+	// returned error; a reset issued after a failed statement presumably fails
+	// as well - confirm whether that is acceptable.
+	resetErr := queries.ResetReplicationOriginSession(t.Ctx, tx)
+	if resetErr != nil {
+		return fmt.Errorf("failed to reset replication origin session: %w", resetErr)
+	}
+
+	logger.Debug("Reset replication origin session")
+	return nil
+}