From 8d45a011d3631ed04f359cdf3136803ad570735c Mon Sep 17 00:00:00 2001 From: Krishnanand V P Date: Mon, 15 Dec 2025 14:10:22 +0400 Subject: [PATCH 1/2] Fix incorrect duplicate batch skipping, Fix pending batch being skipped --- src/amp/streaming/reorg.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/amp/streaming/reorg.py b/src/amp/streaming/reorg.py index 9083db7..f250631 100644 --- a/src/amp/streaming/reorg.py +++ b/src/amp/streaming/reorg.py @@ -46,16 +46,15 @@ def __next__(self) -> ResponseBatch: KeyboardInterrupt: When user cancels the stream """ try: + # Check if we have a pending batch from a previous reorg detection + if hasattr(self, '_pending_batch'): + pending = self._pending_batch + delattr(self, '_pending_batch') + return pending + # Get next batch from underlying stream batch = next(self.stream_iterator) - # Note: ranges_complete flag is handled by CheckpointStore in load_stream_continuous - # Check if this batch contains only duplicate ranges - if self._is_duplicate_batch(batch.metadata.ranges): - self.logger.debug(f'Skipping duplicate batch with ranges: {batch.metadata.ranges}') - # Recursively call to get the next non-duplicate batch - return self.__next__() - # Detect reorgs by comparing with previous ranges invalidation_ranges = self._detect_reorg(batch.metadata.ranges) @@ -70,13 +69,6 @@ def __next__(self) -> ResponseBatch: self._pending_batch = batch return ResponseBatch.reorg_batch(invalidation_ranges) - # Check if we have a pending batch from a previous reorg detection - # REVIEW: I think we should remove this - if hasattr(self, '_pending_batch'): - pending = self._pending_batch - delattr(self, '_pending_batch') - return pending - # Normal case - just return the data batch return batch From 2cfe41b8635af0f81cbb39093c829695861a8039 Mon Sep 17 00:00:00 2001 From: Krishnanand V P Date: Tue, 16 Dec 2025 18:00:13 +0400 Subject: [PATCH 2/2] Better crash recovery with state store --- src/amp/loaders/base.py | 94 +++++++++++++++++++++++- tests/unit/test_crash_recovery.py | 115 ++++++++++++++++++++++++++++++ 2 files changed, 208 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_crash_recovery.py diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py index 6d8a736..abbc635 100644 --- a/src/amp/loaders/base.py +++ b/src/amp/loaders/base.py @@ -77,11 +77,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None: else: self.state_store = NullStreamStateStore() + # Track tables that have undergone crash recovery + self._crash_recovery_done: set[str] = set() + @property def is_connected(self) -> bool: """Check if the loader is connected to the target system.""" return self._is_connected + @property + def loader_type(self) -> str: + """Get the loader type identifier (e.g., 'postgresql', 'redis').""" + return self.__class__.__name__.replace('Loader', '').lower() + def _parse_config(self, config: Dict[str, Any]) -> TConfig: """ Parse configuration into loader-specific format. 
@@ -446,11 +454,21 @@ def load_stream_continuous(
         if not self._is_connected:
             self.connect()

+        connection_name = kwargs.get('connection_name')
+        if connection_name is None:
+            connection_name = self.loader_type
+
+        if table_name not in self._crash_recovery_done:
+            self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})')
+            self._rewind_to_watermark(table_name, connection_name)
+            self._crash_recovery_done.add(table_name)
+        else:
+            self.logger.info(f'Crash recovery already done for table {table_name}')
+
         rows_loaded = 0
         start_time = time.time()
         batch_count = 0
         reorg_count = 0
-        connection_name = kwargs.get('connection_name', 'unknown')
         worker_id = kwargs.get('worker_id', 0)

         try:
@@ -748,6 +766,80 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
             'Streaming with reorg detection requires implementing this method.'
         )

+    def _rewind_to_watermark(self, table_name: Optional[str] = None, connection_name: Optional[str] = None) -> None:
+        """
+        Reset state and data to the last checkpointed watermark.
+
+        Removes any data written after the last completed watermark,
+        ensuring resumable streams start from a consistent state.
+
+        This handles crash recovery by removing uncommitted data from
+        incomplete microbatches between watermarks.
+
+        Args:
+            table_name: Table to clean up. Cleanup of all tables (None) is not yet implemented and is skipped.
+            connection_name: Connection identifier. If None, uses default.
+
+        Example:
+            def connect(self):
+                # Connect to database
+                self._establish_connection()
+                self._is_connected = True
+
+                # Crash recovery - clean up uncommitted data
+                self._rewind_to_watermark('my_table')
+        """
+        if not self.state_enabled:
+            self.logger.debug('State tracking disabled, skipping crash recovery')
+            return
+
+        if connection_name is None:
+            connection_name = self.loader_type
+
+        tables_to_process = []
+        if table_name is None:
+            self.logger.debug('table_name=None not yet implemented, skipping crash recovery')
+            return
+        else:
+            tables_to_process = [table_name]
+
+        for table in tables_to_process:
+            resume_pos = self.state_store.get_resume_position(connection_name, table)
+            if not resume_pos:
+                self.logger.debug(f'No watermark found for {table}, skipping crash recovery')
+                continue
+
+            for range_obj in resume_pos.ranges:
+                from_block = range_obj.end + 1
+
+                self.logger.info(
+                    f'Crash recovery: Cleaning up {table} data for {range_obj.network} from block {from_block} onwards'
+                )
+
+                # Create invalidation range for _handle_reorg()
+                # Note: BlockRange requires an 'end' field, but reorg handling treats the range as open-ended from 'start',
+                # so end=from_block is only a placeholder
+                invalidation_ranges = [BlockRange(network=range_obj.network, start=from_block, end=from_block)]
+
+                try:
+                    self._handle_reorg(invalidation_ranges, table, connection_name)
+                    self.logger.info(f'Crash recovery completed for {range_obj.network} in {table}')
+
+                except NotImplementedError:
+                    invalidated = self.state_store.invalidate_from_block(
+                        connection_name, table, range_obj.network, from_block
+                    )
+
+                    if invalidated:
+                        self.logger.warning(
+                            f'Crash recovery: Cleared {len(invalidated)} batches from state '
+                            f'for {range_obj.network} but cannot delete data from {table}. '
+                            f'{self.__class__.__name__} does not support data deletion. '
+                            f'Duplicates may occur on resume.'
+ ) + else: + self.logger.debug(f'No uncommitted batches found for {range_obj.network}') + def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch: """ Add metadata columns for streaming data with compact batch identification. diff --git a/tests/unit/test_crash_recovery.py b/tests/unit/test_crash_recovery.py new file mode 100644 index 0000000..4875114 --- /dev/null +++ b/tests/unit/test_crash_recovery.py @@ -0,0 +1,115 @@ +""" +Unit tests for crash recovery via _rewind_to_watermark() method. + +These tests verify the crash recovery logic works correctly in isolation. +""" + +from unittest.mock import Mock + +import pytest + +from src.amp.loaders.base import LoadResult +from src.amp.streaming.types import BlockRange, ResumeWatermark +from tests.fixtures.mock_clients import MockDataLoader + + +@pytest.fixture +def mock_loader() -> MockDataLoader: + """Create a mock loader with state store""" + loader = MockDataLoader({'test': 'config'}) + loader.connect() + + loader.state_store = Mock() + loader.state_enabled = True + + return loader + + +@pytest.mark.unit +class TestCrashRecovery: + """Test _rewind_to_watermark() crash recovery method""" + + def test_rewind_with_no_state(self, mock_loader): + """Should return early if state_enabled=False""" + mock_loader.state_enabled = False + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.get_resume_position.assert_not_called() + + def test_rewind_with_no_watermark(self, mock_loader): + """Should return early if no watermark exists""" + mock_loader.state_store.get_resume_position = Mock(return_value=None) + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.get_resume_position.assert_called_once_with('test_conn', 'test_table') + + def test_rewind_calls_handle_reorg(self, mock_loader): + """Should call _handle_reorg with correct invalidation ranges""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader._handle_reorg.assert_called_once() + call_args = mock_loader._handle_reorg.call_args + invalidation_ranges = call_args[0][0] + assert len(invalidation_ranges) == 1 + assert invalidation_ranges[0].network == 'ethereum' + assert invalidation_ranges[0].start == 1011 + assert call_args[0][1] == 'test_table' + assert call_args[0][2] == 'test_conn' + + def test_rewind_handles_not_implemented(self, mock_loader): + """Should gracefully handle loaders without _handle_reorg""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock(side_effect=NotImplementedError()) + mock_loader.state_store.invalidate_from_block = Mock(return_value=[]) + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + mock_loader.state_store.invalidate_from_block.assert_called_once_with( + 'test_conn', 'test_table', 'ethereum', 1011 + ) + + def test_rewind_with_multiple_networks(self, mock_loader): + """Should process ethereum and polygon separately""" + watermark = ResumeWatermark( + ranges=[ + BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'), + BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'), + ] + ) + 
mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', 'test_conn') + + assert mock_loader._handle_reorg.call_count == 2 + + first_call = mock_loader._handle_reorg.call_args_list[0] + assert first_call[0][0][0].network == 'ethereum' + assert first_call[0][0][0].start == 1011 + + second_call = mock_loader._handle_reorg.call_args_list[1] + assert second_call[0][0][0].network == 'polygon' + assert second_call[0][0][0].start == 2011 + + def test_rewind_with_table_name_none(self, mock_loader): + """Should return early when table_name=None (not yet implemented)""" + mock_loader._rewind_to_watermark(table_name=None, connection_name='test_conn') + + mock_loader.state_store.get_resume_position.assert_not_called() + + def test_rewind_uses_default_connection_name(self, mock_loader): + """Should use default connection name from loader class""" + watermark = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')]) + mock_loader.state_store.get_resume_position = Mock(return_value=watermark) + mock_loader._handle_reorg = Mock() + + mock_loader._rewind_to_watermark('test_table', connection_name=None) + + mock_loader.state_store.get_resume_position.assert_called_once_with('mockdata', 'test_table')
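
Note on data deletion during crash recovery: the NotImplementedError fallback in
_rewind_to_watermark() can only clear state, so loaders that do not implement
_handle_reorg() may re-ingest duplicates after a crash. Below is a minimal,
hypothetical sketch of what a SQL-backed loader might implement, assuming a
DB-API style connection in self._conn and metadata columns named _amp_network
and _amp_block_number (all assumptions for illustration, not names defined by
these patches):

    def _handle_reorg(self, invalidation_ranges, table_name, connection_name):
        # Hypothetical example, not part of the patches above.
        with self._conn.cursor() as cur:
            for r in invalidation_ranges:
                # Each invalidation range is treated as open-ended: remove every
                # row for this network at or above the range's start block.
                cur.execute(
                    f'DELETE FROM {table_name} '
                    'WHERE _amp_network = %s AND _amp_block_number >= %s',
                    (r.network, r.start),
                )
        self._conn.commit()
        # Keep the state store consistent with the rows that were just removed.
        for r in invalidation_ranges:
            self.state_store.invalidate_from_block(
                connection_name, table_name, r.network, r.start
            )

With an implementation along these lines, _rewind_to_watermark() takes the try
branch, and resuming after a crash neither loses committed data nor replays
uncommitted rows as duplicates.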