Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 93 additions & 1 deletion src/amp/loaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,19 @@ def __init__(self, config: Dict[str, Any], label_manager=None) -> None:
else:
self.state_store = NullStreamStateStore()

# Track tables that have undergone crash recovery
self._crash_recovery_done: set[str] = set()

@property
def is_connected(self) -> bool:
    """Report whether this loader currently holds a live connection."""
    connected = self._is_connected
    return connected

@property
def loader_type(self) -> str:
    """Loader identifier derived from the class name (e.g., 'postgresql', 'redis')."""
    class_name = type(self).__name__
    return class_name.replace('Loader', '').lower()

def _parse_config(self, config: Dict[str, Any]) -> TConfig:
"""
Parse configuration into loader-specific format.
Expand Down Expand Up @@ -446,11 +454,21 @@ def load_stream_continuous(
if not self._is_connected:
self.connect()

connection_name = kwargs.get('connection_name')
if connection_name is None:
connection_name = self.loader_type

if table_name not in self._crash_recovery_done:
self.logger.info(f'Running crash recovery for table {table_name} (connection: {connection_name})')
self._rewind_to_watermark(table_name, connection_name)
self._crash_recovery_done.add(table_name)
else:
self.logger.info(f'Crash recovery already done for table {table_name}')

rows_loaded = 0
start_time = time.time()
batch_count = 0
reorg_count = 0
connection_name = kwargs.get('connection_name', 'unknown')
worker_id = kwargs.get('worker_id', 0)

try:
Expand Down Expand Up @@ -748,6 +766,80 @@ def _handle_reorg(self, invalidation_ranges: List[BlockRange], table_name: str,
'Streaming with reorg detection requires implementing this method.'
)

def _rewind_to_watermark(self, table_name: Optional[str] = None, connection_name: Optional[str] = None) -> None:
    """
    Reset state and data to the last checkpointed watermark.

    Removes any data written after the last completed watermark,
    ensuring resumable streams start from a consistent state.

    This handles crash recovery by removing uncommitted data from
    incomplete microbatches between watermarks.

    Args:
        table_name: Table to clean up. If None, processes all tables
            (not yet implemented - currently a deliberate no-op).
        connection_name: Connection identifier. If None, defaults to
            ``self.loader_type``.

    Example:
        def connect(self):
            # Connect to database
            self._establish_connection()
            self._is_connected = True

            # Crash recovery - clean up uncommitted data
            self._rewind_to_watermark()
    """
    # Guard clauses keep the recovery path flat and readable.
    if not self.state_enabled:
        self.logger.debug('State tracking disabled, skipping crash recovery')
        return

    if table_name is None:
        # Enumerating every tracked table needs state-store support that
        # does not exist yet, so the all-tables path is a deliberate no-op.
        self.logger.debug('table_name=None not yet implemented, skipping crash recovery')
        return

    if connection_name is None:
        connection_name = self.loader_type

    resume_pos = self.state_store.get_resume_position(connection_name, table_name)
    if not resume_pos:
        self.logger.debug(f'No watermark found for {table_name}, skipping crash recovery')
        return

    # Each network in the watermark is rewound independently.
    for range_obj in resume_pos.ranges:
        self._rewind_network_range(table_name, connection_name, range_obj)

def _rewind_network_range(self, table: str, connection_name: str, range_obj) -> None:
    """
    Invalidate data and state for one network past its checkpointed range.

    Prefers the loader's `_handle_reorg()` (which can delete data); falls
    back to state-store invalidation when the loader cannot delete.

    Args:
        table: Table being recovered.
        connection_name: Connection identifier used for state lookups.
        range_obj: Checkpointed BlockRange; cleanup starts at its end + 1.
    """
    from_block = range_obj.end + 1

    self.logger.info(
        f'Crash recovery: Cleaning up {table} data for {range_obj.network} from block {from_block} onwards'
    )

    # Create invalidation range for _handle_reorg()
    # Note: BlockRange requires 'end' field, but invalidate_from_block() only uses 'start'
    # Setting end=from_block is a valid placeholder since the actual range is open-ended
    invalidation_ranges = [BlockRange(network=range_obj.network, start=from_block, end=from_block)]

    try:
        self._handle_reorg(invalidation_ranges, table, connection_name)
        self.logger.info(f'Crash recovery completed for {range_obj.network} in {table}')

    except NotImplementedError:
        # Loader cannot delete data; clear the state-store batches so the
        # stream resumes from the watermark (duplicates may be re-written).
        invalidated = self.state_store.invalidate_from_block(
            connection_name, table, range_obj.network, from_block
        )

        if invalidated:
            self.logger.warning(
                f'Crash recovery: Cleared {len(invalidated)} batches from state '
                f'for {range_obj.network} but cannot delete data from {table}. '
                f'{self.__class__.__name__} does not support data deletion. '
                f'Duplicates may occur on resume.'
            )
        else:
            self.logger.debug(f'No uncommitted batches found for {range_obj.network}')

def _add_metadata_columns(self, data: pa.RecordBatch, block_ranges: List[BlockRange]) -> pa.RecordBatch:
"""
Add metadata columns for streaming data with compact batch identification.
Expand Down
20 changes: 6 additions & 14 deletions src/amp/streaming/reorg.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,15 @@ def __next__(self) -> ResponseBatch:
KeyboardInterrupt: When user cancels the stream
"""
try:
# Check if we have a pending batch from a previous reorg detection
if hasattr(self, '_pending_batch'):
pending = self._pending_batch
delattr(self, '_pending_batch')
return pending

# Get next batch from underlying stream
batch = next(self.stream_iterator)

# Note: ranges_complete flag is handled by CheckpointStore in load_stream_continuous
# Check if this batch contains only duplicate ranges
if self._is_duplicate_batch(batch.metadata.ranges):
self.logger.debug(f'Skipping duplicate batch with ranges: {batch.metadata.ranges}')
# Recursively call to get the next non-duplicate batch
return self.__next__()

# Detect reorgs by comparing with previous ranges
invalidation_ranges = self._detect_reorg(batch.metadata.ranges)

Expand All @@ -70,13 +69,6 @@ def __next__(self) -> ResponseBatch:
self._pending_batch = batch
return ResponseBatch.reorg_batch(invalidation_ranges)

# Check if we have a pending batch from a previous reorg detection
# REVIEW: I think we should remove this
if hasattr(self, '_pending_batch'):
pending = self._pending_batch
delattr(self, '_pending_batch')
return pending

# Normal case - just return the data batch
return batch

Expand Down
115 changes: 115 additions & 0 deletions tests/unit/test_crash_recovery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
Unit tests for crash recovery via _rewind_to_watermark() method.

These tests verify the crash recovery logic works correctly in isolation.
"""

from unittest.mock import Mock

import pytest

from src.amp.loaders.base import LoadResult

Check failure on line 11 in tests/unit/test_crash_recovery.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

tests/unit/test_crash_recovery.py:11:34: F401 `src.amp.loaders.base.LoadResult` imported but unused
from src.amp.streaming.types import BlockRange, ResumeWatermark
from tests.fixtures.mock_clients import MockDataLoader


@pytest.fixture
def mock_loader() -> MockDataLoader:
    """Build a connected MockDataLoader wired to a mocked state store."""
    instance = MockDataLoader({'test': 'config'})
    instance.connect()

    # Replace the real state store so tests can stub/inspect its calls.
    instance.state_store = Mock()
    instance.state_enabled = True

    return instance


@pytest.mark.unit
class TestCrashRecovery:
    """Unit tests covering the _rewind_to_watermark() crash-recovery path."""

    def test_rewind_with_no_state(self, mock_loader):
        """When state tracking is disabled, recovery must be a no-op."""
        mock_loader.state_enabled = False

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        mock_loader.state_store.get_resume_position.assert_not_called()

    def test_rewind_with_no_watermark(self, mock_loader):
        """A table without a stored watermark is skipped entirely."""
        mock_loader.state_store.get_resume_position = Mock(return_value=None)

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        mock_loader.state_store.get_resume_position.assert_called_once_with('test_conn', 'test_table')

    def test_rewind_calls_handle_reorg(self, mock_loader):
        """Recovery forwards an invalidation range starting at watermark end + 1."""
        wm = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
        mock_loader.state_store.get_resume_position = Mock(return_value=wm)
        mock_loader._handle_reorg = Mock()

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        mock_loader._handle_reorg.assert_called_once()
        positional = mock_loader._handle_reorg.call_args[0]
        ranges_arg = positional[0]
        assert len(ranges_arg) == 1
        assert ranges_arg[0].network == 'ethereum'
        assert ranges_arg[0].start == 1011
        assert positional[1] == 'test_table'
        assert positional[2] == 'test_conn'

    def test_rewind_handles_not_implemented(self, mock_loader):
        """Loaders without _handle_reorg fall back to state-store invalidation."""
        wm = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
        mock_loader.state_store.get_resume_position = Mock(return_value=wm)
        mock_loader._handle_reorg = Mock(side_effect=NotImplementedError())
        mock_loader.state_store.invalidate_from_block = Mock(return_value=[])

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        mock_loader.state_store.invalidate_from_block.assert_called_once_with(
            'test_conn', 'test_table', 'ethereum', 1011
        )

    def test_rewind_with_multiple_networks(self, mock_loader):
        """Each network in the watermark gets its own _handle_reorg call."""
        wm = ResumeWatermark(
            ranges=[
                BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc'),
                BlockRange(network='polygon', start=2000, end=2010, hash='0xdef'),
            ]
        )
        mock_loader.state_store.get_resume_position = Mock(return_value=wm)
        mock_loader._handle_reorg = Mock()

        mock_loader._rewind_to_watermark('test_table', 'test_conn')

        assert mock_loader._handle_reorg.call_count == 2

        eth_call, poly_call = mock_loader._handle_reorg.call_args_list
        assert eth_call[0][0][0].network == 'ethereum'
        assert eth_call[0][0][0].start == 1011
        assert poly_call[0][0][0].network == 'polygon'
        assert poly_call[0][0][0].start == 2011

    def test_rewind_with_table_name_none(self, mock_loader):
        """table_name=None is a documented no-op until multi-table support lands."""
        mock_loader._rewind_to_watermark(table_name=None, connection_name='test_conn')

        mock_loader.state_store.get_resume_position.assert_not_called()

    def test_rewind_uses_default_connection_name(self, mock_loader):
        """connection_name=None falls back to the loader's type identifier."""
        wm = ResumeWatermark(ranges=[BlockRange(network='ethereum', start=1000, end=1010, hash='0xabc')])
        mock_loader.state_store.get_resume_position = Mock(return_value=wm)
        mock_loader._handle_reorg = Mock()

        mock_loader._rewind_to_watermark('test_table', connection_name=None)

        mock_loader.state_store.get_resume_position.assert_called_once_with('mockdata', 'test_table')
Loading