From ac94124e76dd99e18bf9043034acbf4660dc7349 Mon Sep 17 00:00:00 2001
From: Krishnanand V P
Date: Tue, 25 Nov 2025 16:20:13 +0530
Subject: [PATCH] Implement an LMDB stream state store

---
 src/amp/streaming/__init__.py   |   2 +
 src/amp/streaming/lmdb_state.py | 371 +++++++++++++++++++++++++++++++
 tests/unit/test_lmdb_state.py   | 377 ++++++++++++++++++++++++++++++++
 3 files changed, 750 insertions(+)
 create mode 100644 src/amp/streaming/lmdb_state.py
 create mode 100644 tests/unit/test_lmdb_state.py

diff --git a/src/amp/streaming/__init__.py b/src/amp/streaming/__init__.py
index 9361aee..7107764 100644
--- a/src/amp/streaming/__init__.py
+++ b/src/amp/streaming/__init__.py
@@ -1,5 +1,6 @@
 # Streaming module for continuous data loading
 from .iterator import StreamingResultIterator
+from .lmdb_state import LMDBStreamStateStore
 from .parallel import (
     BlockRangePartitionStrategy,
     ParallelConfig,
@@ -35,6 +36,7 @@
     'StreamStateStore',
     'InMemoryStreamStateStore',
     'NullStreamStateStore',
+    'LMDBStreamStateStore',
     'BatchIdentifier',
     'ProcessedBatch',
 ]
diff --git a/src/amp/streaming/lmdb_state.py b/src/amp/streaming/lmdb_state.py
new file mode 100644
index 0000000..28e8da1
--- /dev/null
+++ b/src/amp/streaming/lmdb_state.py
@@ -0,0 +1,371 @@
+"""
+LMDB-based stream state store for durable batch tracking.
+
+This implementation uses LMDB (Lightning Memory-Mapped Database) for fast,
+embedded, durable storage of batch processing state. It can be used with any
+loader (Kafka, PostgreSQL, etc.) to provide crash recovery and idempotency.
+"""
+
+import json
+import logging
+from pathlib import Path
+from typing import Dict, List, Optional
+
+import lmdb
+
+from .state import BatchIdentifier, StreamStateStore
+from .types import BlockRange, ResumeWatermark
+
+
+class LMDBStreamStateStore(StreamStateStore):
+    """
+    Generic LMDB-based state store for tracking processed batches.
+
+    Uses LMDB for fast, durable key-value storage with ACID transactions.
+    Tracks individual batches with unique hash-based IDs to support:
+    - Crash recovery and resume
+    - Idempotency (duplicate detection)
+    - Reorg handling (invalidate by block hash)
+    - Gap detection for parallel loading (not yet implemented; see get_resume_position)
+
+    Uses two LMDB sub-databases for efficient queries:
+    1. "batches" - Individual batch records keyed by batch_id
+    2. "metadata" - Max block metadata per network for fast resume position queries
+
+    Batch database layout:
+    - Key: {connection_name}|{table_name}|{batch_id}
+    - Value: JSON with {network, start_block, end_block, end_hash, start_parent_hash}
+
+    Metadata database layout:
+    - Key: {connection_name}|{table_name}|{network}
+    - Value: JSON with {end_block, end_hash, start_parent_hash} (max processed block)
+    """
+    env: lmdb.Environment
+
+    def __init__(
+        self,
+        connection_name: str,
+        data_dir: str = '.amp_state',
+        map_size: int = 10 * 1024 * 1024 * 1024,
+        sync: bool = True,
+    ):
+        """
+        Initialize LMDB state store with two sub-databases.
+ + Args: + connection_name: Name of the connection (for multi-connection support) + data_dir: Directory to store LMDB database files + map_size: Maximum database size in bytes (default: 10GB) + sync: Whether to sync writes to disk (True for durability, False for speed) + """ + self.connection_name = connection_name + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + self.logger = logging.getLogger(__name__) + + self.env = lmdb.open(str(self.data_dir), map_size=map_size, sync=sync, max_dbs=2) + + self.batches_db = self.env.open_db(b'batches') + self.metadata_db = self.env.open_db(b'metadata') + + self.logger.info(f'Initialized LMDB state store at {self.data_dir} with 2 sub-databases') + + def _make_batch_key(self, connection_name: str, table_name: str, batch_id: str) -> bytes: + """Create composite key for batch database.""" + return f'{connection_name}|{table_name}|{batch_id}'.encode('utf-8') + + def _make_metadata_key(self, connection_name: str, table_name: str, network: str) -> bytes: + """Create composite key for metadata database.""" + return f'{connection_name}|{table_name}|{network}'.encode('utf-8') + + def _parse_key(self, key: bytes) -> tuple[str, str, str]: + """Parse composite key into (connection_name, table_name, batch_id/network).""" + parts = key.decode('utf-8').split('|') + return parts[0], parts[1], parts[2] + + def _serialize_batch(self, batch: BatchIdentifier) -> bytes: + """Serialize BatchIdentifier to JSON bytes.""" + batch_value_dict = { + 'network': batch.network, + 'start_block': batch.start_block, + 'end_block': batch.end_block, + 'end_hash': batch.end_hash, + 'start_parent_hash': batch.start_parent_hash, + } + return json.dumps(batch_value_dict).encode('utf-8') + + def _serialize_metadata(self, end_block: int, end_hash: str, start_parent_hash: str) -> bytes: + """Serialize metadata to JSON bytes.""" + meta_value_dict = { + 'end_block': end_block, + 'end_hash': end_hash, + 'start_parent_hash': start_parent_hash, + } + return json.dumps(meta_value_dict).encode('utf-8') + + def _deserialize_batch(self, value: bytes) -> Dict: + """Deserialize batch data from JSON bytes.""" + return json.loads(value.decode('utf-8')) + + def is_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> bool: + """ + Check if all given batches have already been processed. + + Args: + connection_name: Connection identifier + table_name: Name of the table being loaded + batch_ids: List of batch identifiers to check + + Returns: + True only if ALL batches are already processed + """ + if not batch_ids: + return True + + with self.env.begin(db=self.batches_db) as txn: + for batch_id in batch_ids: + key = self._make_batch_key(connection_name, table_name, batch_id.unique_id) + value = txn.get(key) + if value is None: + return False + + return True + + def mark_processed(self, connection_name: str, table_name: str, batch_ids: List[BatchIdentifier]) -> None: + """ + Mark batches as processed in durable storage. + + Atomically updates both batch records and metadata (max block per network). 
+ + Args: + connection_name: Connection identifier + table_name: Name of the table being loaded + batch_ids: List of batch identifiers to mark as processed + """ + with self.env.begin(write=True) as txn: + for batch in batch_ids: + batch_key = self._make_batch_key(connection_name, table_name, batch.unique_id) + batch_value = self._serialize_batch(batch) + txn.put(batch_key, batch_value, db=self.batches_db) + + meta_key = self._make_metadata_key(connection_name, table_name, batch.network) + current_meta = txn.get(meta_key, db=self.metadata_db) + + should_update = False + if current_meta is None: + should_update = True + else: + current_meta_dict = self._deserialize_batch(current_meta) + if batch.end_block > current_meta_dict['end_block']: + should_update = True + + if should_update: + meta_value = self._serialize_metadata(batch.end_block, batch.end_hash, batch.start_parent_hash) + txn.put(meta_key, meta_value, db=self.metadata_db) + + self.logger.debug(f'Marked {len(batch_ids)} batches as processed in {table_name}') + + def get_resume_position( + self, connection_name: str, table_name: str, detect_gaps: bool = False + ) -> Optional[ResumeWatermark]: + """ + Get the resume watermark (max processed block per network). + + Reads only from metadata database. Does not scan batch records. + + Args: + connection_name: Connection identifier + table_name: Destination table name + detect_gaps: If True, detect gaps. Not implemented - raises error. + + Returns: + ResumeWatermark with max block ranges for all networks, or None if no state exists + + Raises: + NotImplementedError: If detect_gaps=True + """ + if detect_gaps: + raise NotImplementedError('Gap detection not implemented in LMDB state store') + + prefix = f'{connection_name}|{table_name}|'.encode('utf-8') + ranges = [] + + with self.env.begin(db=self.metadata_db) as txn: + cursor = txn.cursor() + + if not cursor.set_range(prefix): + return None + + for key, value in cursor: + if not key.startswith(prefix): + break + + try: + _, _, network = self._parse_key(key) + meta_data = self._deserialize_batch(value) + + ranges.append( + BlockRange( + network=network, + start=meta_data['end_block'], + end=meta_data['end_block'], + hash=meta_data.get('end_hash'), + prev_hash=meta_data.get('start_parent_hash'), + ) + ) + + except (json.JSONDecodeError, KeyError) as e: + self.logger.warning(f'Failed to parse metadata: {e}') + continue + + if not ranges: + return None + + return ResumeWatermark(ranges=ranges) + + def invalidate_from_block( + self, connection_name: str, table_name: str, network: str, from_block: int + ) -> List[BatchIdentifier]: + """ + Invalidate (delete) all batches from a specific block onwards. + + Used for reorg handling to remove invalidated data. Requires full scan + of batches database to find matching batches. 
+ + Args: + connection_name: Connection identifier + table_name: Name of the table + network: Network name + from_block: Block number to invalidate from (inclusive) + + Returns: + List of BatchIdentifier objects that were invalidated + """ + prefix = f'{connection_name}|{table_name}|'.encode('utf-8') + invalidated_batch_ids = [] + keys_to_delete = [] + + with self.env.begin(db=self.batches_db) as txn: + cursor = txn.cursor() + + if not cursor.set_range(prefix): + return [] + + for key, value in cursor: + if not key.startswith(prefix): + break + + try: + batch_data = self._deserialize_batch(value) + + if batch_data['network'] == network and batch_data['end_block'] >= from_block: + batch_id = BatchIdentifier( + network=batch_data['network'], + start_block=batch_data['start_block'], + end_block=batch_data['end_block'], + end_hash=batch_data.get('end_hash'), + start_parent_hash=batch_data.get('start_parent_hash'), + ) + invalidated_batch_ids.append(batch_id) + keys_to_delete.append(key) + + except (json.JSONDecodeError, KeyError) as e: + self.logger.warning(f'Failed to parse batch data during invalidation: {e}') + continue + + if keys_to_delete: + with self.env.begin(write=True) as txn: + for key in keys_to_delete: + txn.delete(key, db=self.batches_db) + + meta_key = self._make_metadata_key(connection_name, table_name, network) + + remaining_batches = [] + cursor = txn.cursor(db=self.batches_db) + if cursor.set_range(prefix): + for key, value in cursor: + if not key.startswith(prefix): + break + try: + batch_data = self._deserialize_batch(value) + if batch_data['network'] == network: + remaining_batches.append(batch_data) + except (json.JSONDecodeError, KeyError) as e: + self.logger.warning(f'Failed to parse batch data during metadata recalculation: {e}') + continue + + if remaining_batches: + remaining_batches.sort(key=lambda b: b['end_block']) + max_batch = remaining_batches[-1] + meta_value = self._serialize_metadata( + max_batch['end_block'], max_batch.get('end_hash'), max_batch.get('start_parent_hash') + ) + txn.put(meta_key, meta_value, db=self.metadata_db) + else: + txn.delete(meta_key, db=self.metadata_db) + + self.logger.info( + f'Invalidated {len(invalidated_batch_ids)} batches from block {from_block} on {network} in {table_name}' + ) + + return invalidated_batch_ids + + def cleanup_before_block(self, connection_name: str, table_name: str, network: str, before_block: int) -> None: + """ + Clean up old batch records before a specific block. + + Removes batches where end_block < before_block. Requires full scan + to find matching batches for the given network. 
+ + Args: + connection_name: Connection identifier + table_name: Name of the table + network: Network name + before_block: Block number to clean up before (exclusive) + """ + prefix = f'{connection_name}|{table_name}|'.encode('utf-8') + keys_to_delete = [] + + with self.env.begin(db=self.batches_db) as txn: + cursor = txn.cursor() + + if not cursor.set_range(prefix): + return + + for key, value in cursor: + if not key.startswith(prefix): + break + + try: + batch_data = self._deserialize_batch(value) + + if batch_data['network'] == network and batch_data['end_block'] < before_block: + keys_to_delete.append(key) + + except (json.JSONDecodeError, KeyError) as e: + self.logger.warning(f'Failed to parse batch data during cleanup: {e}') + continue + + if keys_to_delete: + with self.env.begin(write=True, db=self.batches_db) as txn: + for key in keys_to_delete: + txn.delete(key) + + self.logger.info( + f'Cleaned up {len(keys_to_delete)} old batches before block {before_block} on {network} in {table_name}' + ) + + def close(self) -> None: + """Close the LMDB environment.""" + if self.env: + self.env.close() + self.logger.info('Closed LMDB state store') + + def __enter__(self) -> 'LMDBStreamStateStore': + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Context manager exit.""" + self.close() diff --git a/tests/unit/test_lmdb_state.py b/tests/unit/test_lmdb_state.py new file mode 100644 index 0000000..0834c15 --- /dev/null +++ b/tests/unit/test_lmdb_state.py @@ -0,0 +1,377 @@ +""" +Unit tests for LMDB-based stream state store. + +Tests the LMDBStreamStateStore implementation that provides durable, +crash-recoverable state tracking using LMDB key-value database. +""" + +import tempfile +from pathlib import Path + +import pytest + +from amp.streaming.lmdb_state import LMDBStreamStateStore +from amp.streaming.state import BatchIdentifier + + +@pytest.fixture +def temp_lmdb_dir(): + with tempfile.TemporaryDirectory() as tmpdir: + yield tmpdir + + +@pytest.fixture +def lmdb_store(temp_lmdb_dir): + store = LMDBStreamStateStore(connection_name='test_conn', data_dir=temp_lmdb_dir, sync=True) + yield store + store.close() + + +class TestLMDBStreamStateStore: + def test_initialization(self, temp_lmdb_dir): + store = LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) + + assert store.connection_name == 'test' + assert Path(temp_lmdb_dir).exists() + + store.close() + + def test_mark_and_check_processed(self, lmdb_store): + batch_id = BatchIdentifier('ethereum', 100, 200, '0xabc') + + assert lmdb_store.is_processed('conn1', 'table1', [batch_id]) is False + + lmdb_store.mark_processed('conn1', 'table1', [batch_id]) + + assert lmdb_store.is_processed('conn1', 'table1', [batch_id]) is True + + def test_is_processed_empty_list(self, lmdb_store): + result = lmdb_store.is_processed('conn1', 'table1', []) + assert result is True + + def test_multiple_batches_all_must_be_processed(self, lmdb_store): + batch_id1 = BatchIdentifier('ethereum', 100, 200, '0xabc') + batch_id2 = BatchIdentifier('ethereum', 200, 300, '0xdef') + batch_id3 = BatchIdentifier('ethereum', 300, 400, '0x123') + batch_id4 = BatchIdentifier('ethereum', 400, 500, '0x456') + batch_id5 = BatchIdentifier('ethereum', 500, 600, '0x789') + + lmdb_store.mark_processed('conn1', 'table1', [batch_id1, batch_id3, batch_id5]) + + assert lmdb_store.is_processed('conn1', 'table1', [batch_id1]) is True + assert lmdb_store.is_processed('conn1', 'table1', [batch_id1, batch_id2]) is False + 
assert lmdb_store.is_processed('conn1', 'table1', [batch_id1, batch_id3, batch_id5]) is True + + lmdb_store.mark_processed('conn1', 'table1', [batch_id2, batch_id4]) + + assert ( + lmdb_store.is_processed('conn1', 'table1', [batch_id1, batch_id2, batch_id3, batch_id4, batch_id5]) is True + ) + + def test_separate_networks(self, lmdb_store): + eth_batch = BatchIdentifier('ethereum', 100, 200, '0xabc') + poly_batch = BatchIdentifier('polygon', 100, 200, '0xdef') + + lmdb_store.mark_processed('conn1', 'table1', [eth_batch]) + + assert lmdb_store.is_processed('conn1', 'table1', [eth_batch]) is True + assert lmdb_store.is_processed('conn1', 'table1', [poly_batch]) is False + + def test_separate_connections_and_tables(self, lmdb_store): + batch_id = BatchIdentifier('ethereum', 100, 200, '0xabc') + + lmdb_store.mark_processed('conn1', 'table1', [batch_id]) + + assert lmdb_store.is_processed('conn2', 'table1', [batch_id]) is False + assert lmdb_store.is_processed('conn1', 'table2', [batch_id]) is False + + def test_get_resume_position_empty(self, lmdb_store): + watermark = lmdb_store.get_resume_position('conn1', 'table1') + + assert watermark is None + + def test_get_resume_position_single_network(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc', '0xparent1') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef', '0xparent2') + batch3 = BatchIdentifier('ethereum', 300, 400, '0x123', '0xparent3') + + lmdb_store.mark_processed('conn1', 'table1', [batch1]) + lmdb_store.mark_processed('conn1', 'table1', [batch2]) + lmdb_store.mark_processed('conn1', 'table1', [batch3]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + + assert watermark is not None + assert len(watermark.ranges) == 1 + assert watermark.ranges[0].network == 'ethereum' + assert watermark.ranges[0].end == 400 + assert watermark.ranges[0].hash == '0x123' + assert watermark.ranges[0].prev_hash == '0xparent3' + + def test_get_resume_position_multiple_networks(self, lmdb_store): + eth_batch = BatchIdentifier('ethereum', 100, 200, '0xabc', '0xeth_parent') + poly_batch = BatchIdentifier('polygon', 500, 600, '0xdef', '0xpoly_parent') + arb_batch = BatchIdentifier('arbitrum', 1000, 1100, '0x123', '0xarb_parent') + + lmdb_store.mark_processed('conn1', 'table1', [eth_batch]) + lmdb_store.mark_processed('conn1', 'table1', [poly_batch]) + lmdb_store.mark_processed('conn1', 'table1', [arb_batch]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + + assert watermark is not None + assert len(watermark.ranges) == 3 + + networks = {r.network: r.end for r in watermark.ranges} + assert networks['ethereum'] == 200 + assert networks['polygon'] == 600 + assert networks['arbitrum'] == 1100 + + def test_metadata_updates_with_max_block(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc') + batch2 = BatchIdentifier('ethereum', 300, 400, '0xdef') + batch3 = BatchIdentifier('ethereum', 200, 250, '0x123') + batch4 = BatchIdentifier('ethereum', 50, 75, '0x456') + batch5 = BatchIdentifier('ethereum', 350, 380, '0x789') + + lmdb_store.mark_processed('conn1', 'table1', [batch1]) + lmdb_store.mark_processed('conn1', 'table1', [batch2]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark.ranges[0].end == 400 + + lmdb_store.mark_processed('conn1', 'table1', [batch3, batch4, batch5]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark.ranges[0].end == 400 + + def test_invalidate_from_block(self, lmdb_store): + batch1 = 
BatchIdentifier('ethereum', 100, 200, '0xabc') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef') + batch3 = BatchIdentifier('ethereum', 300, 400, '0x123') + batch4 = BatchIdentifier('ethereum', 400, 500, '0x456') + batch5 = BatchIdentifier('ethereum', 500, 600, '0x789') + batch6 = BatchIdentifier('ethereum', 50, 100, '0xaaa') + + lmdb_store.mark_processed('conn1', 'table1', [batch1, batch2, batch3, batch4, batch5, batch6]) + + invalidated = lmdb_store.invalidate_from_block('conn1', 'table1', 'ethereum', 250) + + assert len(invalidated) == 4 + invalidated_ids = {b.unique_id for b in invalidated} + assert batch2.unique_id in invalidated_ids + assert batch3.unique_id in invalidated_ids + assert batch4.unique_id in invalidated_ids + assert batch5.unique_id in invalidated_ids + + assert lmdb_store.is_processed('conn1', 'table1', [batch1]) is True + assert lmdb_store.is_processed('conn1', 'table1', [batch6]) is True + assert lmdb_store.is_processed('conn1', 'table1', [batch2]) is False + assert lmdb_store.is_processed('conn1', 'table1', [batch3]) is False + assert lmdb_store.is_processed('conn1', 'table1', [batch4]) is False + assert lmdb_store.is_processed('conn1', 'table1', [batch5]) is False + + def test_invalidate_updates_metadata(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc', '0xparent1') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef', '0xparent2') + batch3 = BatchIdentifier('ethereum', 300, 400, '0x123', '0xparent3') + + lmdb_store.mark_processed('conn1', 'table1', [batch1, batch2, batch3]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark.ranges[0].end == 400 + + lmdb_store.invalidate_from_block('conn1', 'table1', 'ethereum', 250) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark.ranges[0].end == 200 + assert watermark.ranges[0].hash == '0xabc' + + def test_invalidate_all_batches_clears_metadata(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc') + + lmdb_store.mark_processed('conn1', 'table1', [batch1]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark is not None + + lmdb_store.invalidate_from_block('conn1', 'table1', 'ethereum', 50) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark is None + + def test_invalidate_only_affects_specified_network(self, lmdb_store): + eth_batch = BatchIdentifier('ethereum', 100, 200, '0xabc') + poly_batch = BatchIdentifier('polygon', 100, 200, '0xdef') + + lmdb_store.mark_processed('conn1', 'table1', [eth_batch, poly_batch]) + + invalidated = lmdb_store.invalidate_from_block('conn1', 'table1', 'ethereum', 150) + + assert len(invalidated) == 1 + assert invalidated[0].network == 'ethereum' + + assert lmdb_store.is_processed('conn1', 'table1', [poly_batch]) is True + + def test_cleanup_before_block(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 50, 100, '0xold1') + batch2 = BatchIdentifier('ethereum', 100, 200, '0xold2') + batch3 = BatchIdentifier('ethereum', 150, 220, '0xold3') + batch4 = BatchIdentifier('ethereum', 200, 300, '0xkeep1') + batch5 = BatchIdentifier('ethereum', 300, 400, '0xkeep2') + batch6 = BatchIdentifier('ethereum', 400, 500, '0xkeep3') + + lmdb_store.mark_processed('conn1', 'table1', [batch1, batch2, batch3, batch4, batch5, batch6]) + + lmdb_store.cleanup_before_block('conn1', 'table1', 'ethereum', 250) + + assert lmdb_store.is_processed('conn1', 'table1', [batch1]) is False + assert lmdb_store.is_processed('conn1', 
'table1', [batch2]) is False + assert lmdb_store.is_processed('conn1', 'table1', [batch3]) is False + assert lmdb_store.is_processed('conn1', 'table1', [batch4]) is True + assert lmdb_store.is_processed('conn1', 'table1', [batch5]) is True + assert lmdb_store.is_processed('conn1', 'table1', [batch6]) is True + + def test_cleanup_only_affects_specified_network(self, lmdb_store): + eth_batch = BatchIdentifier('ethereum', 100, 200, '0xabc') + poly_batch = BatchIdentifier('polygon', 100, 200, '0xdef') + + lmdb_store.mark_processed('conn1', 'table1', [eth_batch, poly_batch]) + + lmdb_store.cleanup_before_block('conn1', 'table1', 'ethereum', 250) + + assert lmdb_store.is_processed('conn1', 'table1', [eth_batch]) is False + assert lmdb_store.is_processed('conn1', 'table1', [poly_batch]) is True + + def test_context_manager(self, temp_lmdb_dir): + batch_id = BatchIdentifier('ethereum', 100, 200, '0xabc') + + with LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) as store: + store.mark_processed('conn1', 'table1', [batch_id]) + assert store.is_processed('conn1', 'table1', [batch_id]) is True + + def test_persistence_across_close_reopen(self, temp_lmdb_dir): + batch_id = BatchIdentifier('ethereum', 100, 200, '0xabc') + + store1 = LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) + store1.mark_processed('conn1', 'table1', [batch_id]) + store1.close() + + store2 = LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) + assert store2.is_processed('conn1', 'table1', [batch_id]) is True + store2.close() + + def test_detect_gaps_raises_not_implemented(self, lmdb_store): + with pytest.raises(NotImplementedError, match='Gap detection not implemented'): + lmdb_store.get_resume_position('conn1', 'table1', detect_gaps=True) + + def test_resume_position_with_many_out_of_order_batches(self, lmdb_store): + batches = [ + BatchIdentifier('ethereum', 100, 150, '0x1', '0xp1'), + BatchIdentifier('ethereum', 500, 600, '0x2', '0xp2'), + BatchIdentifier('ethereum', 200, 300, '0x3', '0xp3'), + BatchIdentifier('ethereum', 50, 100, '0x4', '0xp4'), + BatchIdentifier('ethereum', 300, 400, '0x5', '0xp5'), + BatchIdentifier('ethereum', 150, 200, '0x6', '0xp6'), + BatchIdentifier('ethereum', 400, 500, '0x7', '0xp7'), + BatchIdentifier('ethereum', 600, 700, '0x8', '0xp8'), + BatchIdentifier('ethereum', 700, 800, '0x9', '0xp9'), + BatchIdentifier('ethereum', 250, 280, '0xa', '0xpa'), + ] + + for batch in batches: + lmdb_store.mark_processed('conn1', 'table1', [batch]) + + watermark = lmdb_store.get_resume_position('conn1', 'table1') + assert watermark.ranges[0].end == 800 + assert watermark.ranges[0].hash == '0x9' + assert watermark.ranges[0].prev_hash == '0xp9' + + +class TestIntegrationScenarios: + def test_streaming_with_resume(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef') + + lmdb_store.mark_processed('conn1', 'transfers', [batch1]) + lmdb_store.mark_processed('conn1', 'transfers', [batch2]) + + watermark = lmdb_store.get_resume_position('conn1', 'transfers') + assert watermark.ranges[0].end == 300 + + batch3 = BatchIdentifier('ethereum', 300, 400, '0x123') + batch4 = BatchIdentifier('ethereum', 400, 500, '0x456') + + assert lmdb_store.is_processed('conn1', 'transfers', [batch2]) is True + + lmdb_store.mark_processed('conn1', 'transfers', [batch3]) + lmdb_store.mark_processed('conn1', 'transfers', [batch4]) + + watermark = lmdb_store.get_resume_position('conn1', 'transfers') + assert 
watermark.ranges[0].end == 500 + + def test_reorg_scenario(self, lmdb_store): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef') + batch3 = BatchIdentifier('ethereum', 300, 400, '0x123') + + lmdb_store.mark_processed('conn1', 'blocks', [batch1, batch2, batch3]) + + invalidated = lmdb_store.invalidate_from_block('conn1', 'blocks', 'ethereum', 250) + + assert len(invalidated) == 2 + + watermark = lmdb_store.get_resume_position('conn1', 'blocks') + assert watermark.ranges[0].end == 200 + + batch2_new = BatchIdentifier('ethereum', 200, 300, '0xNEWHASH1') + batch3_new = BatchIdentifier('ethereum', 300, 400, '0xNEWHASH2') + + lmdb_store.mark_processed('conn1', 'blocks', [batch2_new, batch3_new]) + + assert lmdb_store.is_processed('conn1', 'blocks', [batch2_new]) is True + assert lmdb_store.is_processed('conn1', 'blocks', [batch2]) is False + + def test_multi_network_streaming(self, lmdb_store): + eth_batch1 = BatchIdentifier('ethereum', 100, 200, '0xeth1') + eth_batch2 = BatchIdentifier('ethereum', 200, 300, '0xeth2') + poly_batch1 = BatchIdentifier('polygon', 500, 600, '0xpoly1') + arb_batch1 = BatchIdentifier('arbitrum', 1000, 1100, '0xarb1') + + lmdb_store.mark_processed('conn1', 'transfers', [eth_batch1, eth_batch2]) + lmdb_store.mark_processed('conn1', 'transfers', [poly_batch1]) + lmdb_store.mark_processed('conn1', 'transfers', [arb_batch1]) + + watermark = lmdb_store.get_resume_position('conn1', 'transfers') + + assert len(watermark.ranges) == 3 + networks = {r.network: r.end for r in watermark.ranges} + assert networks['ethereum'] == 300 + assert networks['polygon'] == 600 + assert networks['arbitrum'] == 1100 + + invalidated = lmdb_store.invalidate_from_block('conn1', 'transfers', 'ethereum', 250) + assert len(invalidated) == 1 + + assert lmdb_store.is_processed('conn1', 'transfers', [poly_batch1]) is True + assert lmdb_store.is_processed('conn1', 'transfers', [arb_batch1]) is True + + def test_crash_recovery_with_persistence(self, temp_lmdb_dir): + batch1 = BatchIdentifier('ethereum', 100, 200, '0xabc', '0xparent1') + batch2 = BatchIdentifier('ethereum', 200, 300, '0xdef', '0xparent2') + + store1 = LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) + store1.mark_processed('conn1', 'transfers', [batch1, batch2]) + watermark1 = store1.get_resume_position('conn1', 'transfers') + store1.close() + + store2 = LMDBStreamStateStore(connection_name='test', data_dir=temp_lmdb_dir) + watermark2 = store2.get_resume_position('conn1', 'transfers') + + assert watermark2 is not None + assert watermark2.ranges[0].end == watermark1.ranges[0].end + assert watermark2.ranges[0].hash == '0xdef' + + assert store2.is_processed('conn1', 'transfers', [batch1]) is True + assert store2.is_processed('conn1', 'transfers', [batch2]) is True + store2.close()
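
Usage note (editor's addition, not part of the patch): the unit tests above exercise the store in isolation; the sketch below shows how it could sit inside a streaming consumer loop: resume from the per-network watermark, skip batches that were already committed, and durably record new ones. The fetch_batches() generator, the shape of its batch objects, and write_rows() are hypothetical stand-ins for a real loader and sink; the store API is the one added by this patch.

    # Hypothetical wiring of LMDBStreamStateStore into a consumer loop.
    from amp.streaming import BatchIdentifier, LMDBStreamStateStore

    with LMDBStreamStateStore(connection_name='analytics', data_dir='.amp_state') as store:
        # Highest processed block per network for this connection/table, or None on first run.
        watermark = store.get_resume_position('analytics', 'transfers')
        resume_from = {r.network: r.end for r in watermark.ranges} if watermark else {}

        for batch in fetch_batches(resume_from):  # hypothetical loader
            batch_id = BatchIdentifier(
                batch.network, batch.start_block, batch.end_block,
                batch.end_hash, batch.start_parent_hash,
            )
            # Idempotency: skip work that a previous run already recorded.
            if store.is_processed('analytics', 'transfers', [batch_id]):
                continue
            write_rows(batch)  # hypothetical sink write
            # Durable commit; also advances the per-network watermark metadata.
            store.mark_processed('analytics', 'transfers', [batch_id])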
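
Reorg handling sketch (editor's addition, not part of the patch): invalidate_from_block() drops every recorded batch whose end_block is at or above the fork point for one network and rewinds that network's watermark, after which the affected range can simply be re-ingested; cleanup_before_block() prunes batch records old enough to be considered final, which keeps the full-scan paths small. detect_reorg_point() and reingest() are hypothetical helpers, and FINALITY_DEPTH is an assumed chain-specific constant.

    store = LMDBStreamStateStore(connection_name='analytics', data_dir='.amp_state')

    fork_block = detect_reorg_point('ethereum')  # hypothetical: first block whose stored hash no longer matches the chain
    if fork_block is not None:
        # Deletes batches with end_block >= fork_block and recomputes the metadata watermark.
        dropped = store.invalidate_from_block('analytics', 'blocks', 'ethereum', fork_block)
        reingest('ethereum', dropped)  # hypothetical: reload those ranges, then mark_processed again

    # Periodic pruning: batches ending before (watermark - FINALITY_DEPTH) can no longer be reorged away.
    FINALITY_DEPTH = 128  # assumption, tune per chain
    watermark = store.get_resume_position('analytics', 'blocks')
    if watermark:
        for r in watermark.ranges:
            store.cleanup_before_block('analytics', 'blocks', r.network, r.end - FINALITY_DEPTH)

    store.close()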
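
Layout inspection sketch (editor's addition, not part of the patch): the key/value layout documented in the class docstring can be checked directly with the lmdb package, which is occasionally useful when debugging state on disk. The '.amp_state' path is the store's default data_dir; adjust it if a different directory was used.

    import json
    import lmdb

    # Open the environment read-only; the two named sub-databases must already exist.
    env = lmdb.open('.amp_state', max_dbs=2, readonly=True)
    batches_db = env.open_db(b'batches', create=False)
    metadata_db = env.open_db(b'metadata', create=False)

    with env.begin(db=batches_db) as txn:
        for key, value in txn.cursor():
            # Key: {connection_name}|{table_name}|{batch_id}, value: JSON batch record.
            conn, table, batch_id = key.decode('utf-8').split('|', 2)
            print('batch', conn, table, batch_id, json.loads(value.decode('utf-8')))

    with env.begin(db=metadata_db) as txn:
        for key, value in txn.cursor():
            # Key: {connection_name}|{table_name}|{network}, value: max processed block for that network.
            print('meta', key.decode('utf-8'), json.loads(value.decode('utf-8')))

    env.close()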