From 7882a4bb78c8a684873ea1b4a16d8ee7fe6b1685 Mon Sep 17 00:00:00 2001 From: Ford Date: Thu, 11 Dec 2025 15:02:38 -0800 Subject: [PATCH] base loader: fix micro batch is_processed marking, add tests --- src/amp/loaders/base.py | 32 +++-- .../implementations/postgresql_loader.py | 15 ++- tests/integration/test_postgresql_loader.py | 123 ++++++++++++++++++ tests/integration/test_snowflake_loader.py | 108 +++++++++++++++ 4 files changed, 265 insertions(+), 13 deletions(-) diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py index 3097feb..cc8a9a9 100644 --- a/src/amp/loaders/base.py +++ b/src/amp/loaders/base.py @@ -484,6 +484,7 @@ def load_stream_continuous( table_name, connection_name, response.metadata.ranges, + ranges_complete=response.metadata.ranges_complete, ) else: # Non-transactional loading (separate check, load, mark) @@ -494,6 +495,7 @@ def load_stream_continuous( table_name, connection_name, response.metadata.ranges, + ranges_complete=response.metadata.ranges_complete, **filtered_kwargs, ) @@ -611,6 +613,7 @@ def _process_batch_transactional( table_name: str, connection_name: str, ranges: List[BlockRange], + ranges_complete: bool = False, ) -> LoadResult: """ Process a data batch using transactional exactly-once semantics. @@ -622,6 +625,7 @@ def _process_batch_transactional( table_name: Target table name connection_name: Connection identifier ranges: Block ranges for this batch + ranges_complete: True when this RecordBatch completes a microbatch (streaming only) Returns: LoadResult with operation outcome @@ -630,13 +634,17 @@ def _process_batch_transactional( try: # Delegate to loader-specific transactional implementation # Loaders that support transactions implement load_batch_transactional() - rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges) + rows_loaded_batch = self.load_batch_transactional( + batch_data, table_name, connection_name, ranges, ranges_complete + ) duration = time.time() - start_time - # Mark batches as processed in state store after successful transaction - if ranges: + # Mark batches as processed ONLY when microbatch is complete + # multiple RecordBatches can share the same microbatch ID + if ranges and ranges_complete: batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges] self.state_store.mark_processed(connection_name, table_name, batch_ids) + self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs') return LoadResult( rows_loaded=rows_loaded_batch, @@ -648,6 +656,7 @@ def _process_batch_transactional( metadata={ 'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate', 'ranges': [r.to_dict() for r in ranges], + 'ranges_complete': ranges_complete, }, ) @@ -670,6 +679,7 @@ def _process_batch_non_transactional( table_name: str, connection_name: str, ranges: Optional[List[BlockRange]], + ranges_complete: bool = False, **kwargs, ) -> Optional[LoadResult]: """ @@ -682,13 +692,17 @@ def _process_batch_non_transactional( table_name: Target table name connection_name: Connection identifier ranges: Block ranges for this batch (if available) + ranges_complete: True when this RecordBatch completes a microbatch (streaming only) **kwargs: Additional options passed to load_batch Returns: LoadResult, or None if batch was skipped as duplicate """ # Check if batch already processed (idempotency / exactly-once) - if ranges and self.state_enabled: + # For streaming: only check when ranges_complete=True (end of microbatch) + # Multiple RecordBatches can share 
the same microbatch ID, so we must wait + # until the entire microbatch is delivered before checking/marking as processed + if ranges and self.state_enabled and ranges_complete: try: batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges] is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids) @@ -696,7 +710,7 @@ def _process_batch_non_transactional( if is_duplicate: # Skip this batch - already processed self.logger.info( - f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}' + f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}' ) return LoadResult( rows_loaded=0, @@ -711,14 +725,16 @@ def _process_batch_non_transactional( # BlockRange missing hash - log and continue without idempotency check self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.') - # Load batch + # Load batch (always load, even if part of larger microbatch) result = self.load_batch(batch_data, table_name, **kwargs) - if result.success and ranges and self.state_enabled: - # Mark batch as processed (for exactly-once semantics) + # Mark batch as processed ONLY when microbatch is complete + # This ensures we don't skip subsequent RecordBatches within the same microbatch + if result.success and ranges and self.state_enabled and ranges_complete: try: batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges] self.state_store.mark_processed(connection_name, table_name, batch_ids) + self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs') except Exception as e: self.logger.error(f'Failed to mark batches as processed: {e}') # Continue anyway - state store provides resume capability diff --git a/src/amp/loaders/implementations/postgresql_loader.py b/src/amp/loaders/implementations/postgresql_loader.py index 6e84703..7bae9f1 100644 --- a/src/amp/loaders/implementations/postgresql_loader.py +++ b/src/amp/loaders/implementations/postgresql_loader.py @@ -119,6 +119,7 @@ def load_batch_transactional( table_name: str, connection_name: str, ranges: List[BlockRange], + ranges_complete: bool = False, ) -> int: """ Load a batch with transactional exactly-once semantics using in-memory state. @@ -135,6 +136,7 @@ def load_batch_transactional( table_name: Target table name connection_name: Connection identifier for tracking ranges: Block ranges covered by this batch + ranges_complete: True when this RecordBatch completes a microbatch (streaming only) Returns: Number of rows loaded (0 if duplicate) @@ -149,24 +151,27 @@ def load_batch_transactional( self.logger.warning(f'Cannot create batch identifiers: {e}. 
Loading without duplicate check.') batch_ids = [] - # Check if already processed (using in-memory state) - if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids): + # Check if already processed ONLY when microbatch is complete + # Multiple RecordBatches can share the same microbatch ID (BlockRange) + if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids): self.logger.info( f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), ' f'skipping (state check)' ) return 0 - # Load data + # Load data (always load, even if part of larger microbatch) conn = self.pool.getconn() try: with conn.cursor() as cur: self._copy_arrow_data(cur, batch, table_name) conn.commit() - # Mark as processed after successful load - if batch_ids: + # Mark as processed ONLY when microbatch is complete + # This ensures we don't skip subsequent RecordBatches within the same microbatch + if batch_ids and ranges_complete: self.state_store.mark_processed(connection_name, table_name, batch_ids) + self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs') self.logger.debug( f'Batch load committed: {batch.num_rows} rows, ' diff --git a/tests/integration/test_postgresql_loader.py b/tests/integration/test_postgresql_loader.py index 8b68186..61d379e 100644 --- a/tests/integration/test_postgresql_loader.py +++ b/tests/integration/test_postgresql_loader.py @@ -692,3 +692,126 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t finally: loader.pool.putconn(conn) + + def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables): + """ + Test that multiple RecordBatches within the same microbatch are all loaded, + and deduplication only happens at microbatch boundaries when ranges_complete=True. + + This test verifies the fix for the critical bug where we were marking batches + as processed after every RecordBatch instead of waiting for ranges_complete=True. + """ + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + cleanup_tables.append(test_table_name) + + # Enable state management to test deduplication + config_with_state = { + **postgresql_test_config, + 'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True}, + } + loader = PostgreSQLLoader(config_with_state) + + with loader: + # Create table first from the schema + batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) + loader._create_table_from_schema(batch1_data.schema, test_table_name) + + # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange + # This happens when the server sends large microbatches in smaller chunks + + # First RecordBatch of the microbatch (ranges_complete=False) + response1 = ResponseBatch.data_batch( + data=batch1_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, # Not the last batch in this microbatch + ), + ) + + # Second RecordBatch of the microbatch (ranges_complete=False) + batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) + response2 = ResponseBatch.data_batch( + data=batch2_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! 
+ ranges_complete=False, # Still not the last batch + ), + ) + + # Third RecordBatch of the microbatch (ranges_complete=True) + batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) + response3 = ResponseBatch.data_batch( + data=batch3_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! + ranges_complete=True, # Last batch in this microbatch - safe to mark as processed + ), + ) + + # Process the microbatch stream + stream = [response1, response2, response3] + results = list( + loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') + ) + + # CRITICAL: All 3 RecordBatches should be loaded successfully + # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates") + assert len(results) == 3, 'All RecordBatches within microbatch should be processed' + assert all(r.success for r in results), 'All batches should succeed' + assert results[0].rows_loaded == 2, 'First batch should load 2 rows' + assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' + assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' + + # Verify total rows in table (all batches loaded) + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + total_count = cur.fetchone()[0] + assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' + + # Verify the actual IDs are present + cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id') + all_ids = [row[0] for row in cur.fetchall()] + assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present' + + finally: + loader.pool.putconn(conn) + + # Now test that re-sending the complete microbatch is properly deduplicated + # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch) + duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) + duplicate_response = ResponseBatch.data_batch( + data=duplicate_batch, + metadata=BatchMetadata( + ranges=[ + BlockRange(network='ethereum', start=100, end=110, hash='0xabc123') + ], # Same range as before! 
+ ranges_complete=True, # Complete microbatch + ), + ) + + # Process duplicate microbatch + duplicate_results = list( + loader.load_stream_continuous( + iter([duplicate_response]), test_table_name, connection_name='test_connection' + ) + ) + + # The duplicate microbatch should be skipped (already processed) + assert len(duplicate_results) == 1 + assert duplicate_results[0].success is True + assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' + assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate' + + # Verify row count unchanged (duplicate was skipped) + conn = loader.pool.getconn() + try: + with conn.cursor() as cur: + cur.execute(f'SELECT COUNT(*) FROM {test_table_name}') + final_count = cur.fetchone()[0] + assert final_count == 6, 'Row count should not increase after duplicate microbatch' + + finally: + loader.pool.putconn(conn) diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py index 78c2c17..50f96c7 100644 --- a/tests/integration/test_snowflake_loader.py +++ b/tests/integration/test_snowflake_loader.py @@ -1105,3 +1105,111 @@ def test_streaming_error_handling(self, snowflake_streaming_config, test_table_n # and ignores columns that don't exist in the table assert result.success is True assert result.rows_loaded == 2 + + def test_microbatch_deduplication(self, snowflake_config, test_table_name, cleanup_tables): + """ + Test that multiple RecordBatches within the same microbatch are all loaded, + and deduplication only happens at microbatch boundaries when ranges_complete=True. + + This test verifies the fix for the critical bug where we were marking batches + as processed after every RecordBatch instead of waiting for ranges_complete=True. + """ + from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch + + cleanup_tables.append(test_table_name) + + # Enable state management to test deduplication + config_with_state = { + **snowflake_config, + 'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True}, + } + loader = SnowflakeLoader(config_with_state) + + with loader: + # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange + # This happens when the server sends large microbatches in smaller chunks + + # First RecordBatch of the microbatch (ranges_complete=False) + batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}) + response1 = ResponseBatch.data_batch( + data=batch1_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], + ranges_complete=False, # Not the last batch in this microbatch + ), + ) + + # Second RecordBatch of the microbatch (ranges_complete=False) + batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}) + response2 = ResponseBatch.data_batch( + data=batch2_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! + ranges_complete=False, # Still not the last batch + ), + ) + + # Third RecordBatch of the microbatch (ranges_complete=True) + batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}) + response3 = ResponseBatch.data_batch( + data=batch3_data, + metadata=BatchMetadata( + ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange! 
+ ranges_complete=True, # Last batch in this microbatch - safe to mark as processed + ), + ) + + # Process the microbatch stream + stream = [response1, response2, response3] + results = list( + loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection') + ) + + # CRITICAL: All 3 RecordBatches should be loaded successfully + # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates") + assert len(results) == 3, 'All RecordBatches within microbatch should be processed' + assert all(r.success for r in results), 'All batches should succeed' + assert results[0].rows_loaded == 2, 'First batch should load 2 rows' + assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)' + assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)' + + # Verify total rows in table (all batches loaded) + loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}') + total_count = loader.cursor.fetchone()['COUNT'] + assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table' + + # Verify the actual IDs are present + loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"') + all_ids = [row['id'] for row in loader.cursor.fetchall()] + assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present' + + # Now test that re-sending the complete microbatch is properly deduplicated + # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch) + duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]}) + duplicate_response = ResponseBatch.data_batch( + data=duplicate_batch, + metadata=BatchMetadata( + ranges=[ + BlockRange(network='ethereum', start=100, end=110, hash='0xabc123') + ], # Same range as before! + ranges_complete=True, # Complete microbatch + ), + ) + + # Process duplicate microbatch + duplicate_results = list( + loader.load_stream_continuous( + iter([duplicate_response]), test_table_name, connection_name='test_connection' + ) + ) + + # The duplicate microbatch should be skipped (already processed) + assert len(duplicate_results) == 1 + assert duplicate_results[0].success is True + assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped' + assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate' + + # Verify row count unchanged (duplicate was skipped) + loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}') + final_count = loader.cursor.fetchone()['COUNT'] + assert final_count == 6, 'Row count should not increase after duplicate microbatch'
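
For reviewers: a minimal, self-contained sketch of the deduplication contract this patch enforces. The Range, InMemoryStateStore, and process_record_batch names below are illustrative stand-ins, not the project's actual BlockRange, state-store, or loader APIs; the sketch only demonstrates that is_processed / mark_processed are consulted solely when ranges_complete is True, so RecordBatches that share a BlockRange are never skipped partway through a microbatch.

    # Illustrative sketch only -- not part of the patch and not the project's real API.
    from dataclasses import dataclass
    from typing import List, Set, Tuple


    @dataclass(frozen=True)
    class Range:
        """Stand-in for a BlockRange: one microbatch identifier."""
        network: str
        start: int
        end: int
        hash: str


    class InMemoryStateStore:
        """Tracks which (connection, table, range) triples have been processed."""

        def __init__(self) -> None:
            self._processed: Set[Tuple[str, str, Range]] = set()

        def is_processed(self, connection: str, table: str, ranges: List[Range]) -> bool:
            return all((connection, table, r) in self._processed for r in ranges)

        def mark_processed(self, connection: str, table: str, ranges: List[Range]) -> None:
            self._processed.update((connection, table, r) for r in ranges)


    def process_record_batch(store, connection, table, rows, ranges, ranges_complete, sink):
        """Load one RecordBatch; deduplicate only at microbatch boundaries.

        Checking and marking only when ranges_complete is True is the core of the
        fix: several RecordBatches can share one BlockRange, so marking earlier
        would cause later chunks of the same microbatch to be skipped as duplicates.
        """
        if ranges and ranges_complete and store.is_processed(connection, table, ranges):
            return 0  # whole microbatch was already loaded earlier
        sink.extend(rows)  # stand-in for the real COPY / INSERT path
        if ranges and ranges_complete:
            store.mark_processed(connection, table, ranges)
        return len(rows)


    if __name__ == '__main__':
        store, sink = InMemoryStateStore(), []
        rng = [Range('ethereum', 100, 110, '0xabc123')]
        # Three chunks of one microbatch: all three must load.
        assert process_record_batch(store, 'conn', 'tbl', [1, 2], rng, False, sink) == 2
        assert process_record_batch(store, 'conn', 'tbl', [3, 4], rng, False, sink) == 2
        assert process_record_batch(store, 'conn', 'tbl', [5, 6], rng, True, sink) == 2
        # Re-delivery of the completed microbatch is skipped.
        assert process_record_batch(store, 'conn', 'tbl', [7, 8], rng, True, sink) == 0
        assert len(sink) == 6

Running the sketch loads all three chunks (6 rows) and skips the re-delivered microbatch, which mirrors the assertions in the new PostgreSQL and Snowflake integration tests above.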