32 changes: 24 additions & 8 deletions src/amp/loaders/base.py
@@ -484,6 +484,7 @@ def load_stream_continuous(
table_name,
connection_name,
response.metadata.ranges,
ranges_complete=response.metadata.ranges_complete,
)
else:
# Non-transactional loading (separate check, load, mark)
@@ -494,6 +495,7 @@ def load_stream_continuous(
table_name,
connection_name,
response.metadata.ranges,
ranges_complete=response.metadata.ranges_complete,
**filtered_kwargs,
)

@@ -611,6 +613,7 @@ def _process_batch_transactional(
table_name: str,
connection_name: str,
ranges: List[BlockRange],
ranges_complete: bool = False,
) -> LoadResult:
"""
Process a data batch using transactional exactly-once semantics.
@@ -622,6 +625,7 @@ def _process_batch_transactional(
table_name: Target table name
connection_name: Connection identifier
ranges: Block ranges for this batch
ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

Returns:
LoadResult with operation outcome
@@ -630,13 +634,17 @@ def _process_batch_transactional(
try:
# Delegate to loader-specific transactional implementation
# Loaders that support transactions implement load_batch_transactional()
rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
rows_loaded_batch = self.load_batch_transactional(
batch_data, table_name, connection_name, ranges, ranges_complete
)
duration = time.time() - start_time

# Mark batches as processed in state store after successful transaction
if ranges:
# Mark batches as processed ONLY when the microbatch is complete;
# multiple RecordBatches can share the same microbatch ID
if ranges and ranges_complete:
batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
self.state_store.mark_processed(connection_name, table_name, batch_ids)
self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

return LoadResult(
rows_loaded=rows_loaded_batch,
@@ -648,6 +656,7 @@ def _process_batch_transactional(
metadata={
'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate',
'ranges': [r.to_dict() for r in ranges],
'ranges_complete': ranges_complete,
},
)

@@ -670,6 +679,7 @@ def _process_batch_non_transactional(
table_name: str,
connection_name: str,
ranges: Optional[List[BlockRange]],
ranges_complete: bool = False,
**kwargs,
) -> Optional[LoadResult]:
"""
@@ -682,21 +692,25 @@ def _process_batch_non_transactional(
table_name: Target table name
connection_name: Connection identifier
ranges: Block ranges for this batch (if available)
ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
**kwargs: Additional options passed to load_batch

Returns:
LoadResult, or None if batch was skipped as duplicate
"""
# Check if batch already processed (idempotency / exactly-once)
if ranges and self.state_enabled:
# For streaming: only check when ranges_complete=True (end of microbatch)
# Multiple RecordBatches can share the same microbatch ID, so we must wait
# until the entire microbatch is delivered before checking/marking as processed
if ranges and self.state_enabled and ranges_complete:
try:
batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids)

if is_duplicate:
# Skip this batch - already processed
self.logger.info(
f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}'
)
return LoadResult(
rows_loaded=0,
@@ -711,14 +725,16 @@ def _process_batch_non_transactional(
# BlockRange missing hash - log and continue without idempotency check
self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.')

# Load batch
# Load batch (always load, even if it is part of a larger microbatch)
result = self.load_batch(batch_data, table_name, **kwargs)

if result.success and ranges and self.state_enabled:
# Mark batch as processed (for exactly-once semantics)
# Mark batch as processed ONLY when the microbatch is complete
# This ensures we don't skip subsequent RecordBatches within the same microbatch
if result.success and ranges and self.state_enabled and ranges_complete:
try:
batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
self.state_store.mark_processed(connection_name, table_name, batch_ids)
self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
except Exception as e:
self.logger.error(f'Failed to mark batches as processed: {e}')
# Continue anyway - state store provides resume capability
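Reviewer note: a minimal, self-contained sketch of the invariant this hunk establishes — every RecordBatch in a microbatch is loaded, but the state store is consulted and updated only on the chunk carrying `ranges_complete=True`. The `InMemoryStateStore` class and tuple batch ID below are simplified stand-ins, not the repo's real `StateStore`/`BatchIdentifier`.

```python
# Illustrative sketch only: the store and batch-id tuple are simplified
# stand-ins for this repo's StateStore / BatchIdentifier.
from dataclasses import dataclass


@dataclass(frozen=True)
class BlockRange:
    network: str
    start: int
    end: int
    hash: str


class InMemoryStateStore:
    """Toy state store keyed by (connection, table, batch_id)."""

    def __init__(self) -> None:
        self._seen = set()

    def is_processed(self, connection, table, batch_ids):
        return all((connection, table, b) in self._seen for b in batch_ids)

    def mark_processed(self, connection, table, batch_ids):
        for b in batch_ids:
            self._seen.add((connection, table, b))


store = InMemoryStateStore()
rng = BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
batch_id = (rng.network, rng.start, rng.end, rng.hash)  # stand-in for BatchIdentifier.from_block_range()

# One microbatch delivered as three RecordBatches; only the last sets ranges_complete=True.
for i, ranges_complete in enumerate([False, False, True], start=1):
    # The duplicate check is gated on ranges_complete, so all three chunks load on first delivery.
    if ranges_complete and store.is_processed('conn', 'table', [batch_id]):
        print(f'chunk {i}: skipped as duplicate')
        continue
    print(f'chunk {i}: loaded')
    if ranges_complete:
        store.mark_processed('conn', 'table', [batch_id])

# A replay of the completed microbatch is now skipped at its ranges_complete chunk.
assert store.is_processed('conn', 'table', [batch_id])
```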
15 changes: 10 additions & 5 deletions src/amp/loaders/implementations/postgresql_loader.py
@@ -119,6 +119,7 @@ def load_batch_transactional(
table_name: str,
connection_name: str,
ranges: List[BlockRange],
ranges_complete: bool = False,
) -> int:
"""
Load a batch with transactional exactly-once semantics using in-memory state.
@@ -135,6 +136,7 @@ def load_batch_transactional(
table_name: Target table name
connection_name: Connection identifier for tracking
ranges: Block ranges covered by this batch
ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

Returns:
Number of rows loaded (0 if duplicate)
@@ -149,24 +151,27 @@ def load_batch_transactional(
self.logger.warning(f'Cannot create batch identifiers: {e}. Loading without duplicate check.')
batch_ids = []

# Check if already processed (using in-memory state)
if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids):
# Check if already processed ONLY when the microbatch is complete
# Multiple RecordBatches can share the same microbatch ID (BlockRange)
if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
self.logger.info(
f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), '
f'skipping (state check)'
)
return 0

# Load data
# Load data (always load, even if it is part of a larger microbatch)
conn = self.pool.getconn()
try:
with conn.cursor() as cur:
self._copy_arrow_data(cur, batch, table_name)
conn.commit()

# Mark as processed after successful load
if batch_ids:
# Mark as processed ONLY when the microbatch is complete
# This ensures we don't skip subsequent RecordBatches within the same microbatch
if batch_ids and ranges_complete:
self.state_store.mark_processed(connection_name, table_name, batch_ids)
self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

self.logger.debug(
f'Batch load committed: {batch.num_rows} rows, '
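A hedged end-to-end usage sketch of the updated flow, mirroring the integration test below. The connection settings are assumed (same shape as the `postgresql_test_config` fixture), and `my_table`/`conn` are placeholder names; note it reuses the test's private `_create_table_from_schema` helper to set up the table.

```python
import pyarrow as pa

from src.amp.loaders.implementations.postgresql_loader import PostgreSQLLoader
from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch

config = {
    # ...host/credentials omitted; same shape as postgresql_test_config...
    'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
}

# All three chunks belong to one microbatch, so they share a single BlockRange.
shared_ranges = [BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]


def chunk(ids, values, complete):
    """Wrap one RecordBatch of the microbatch; only the final chunk sets ranges_complete."""
    return ResponseBatch.data_batch(
        data=pa.RecordBatch.from_pydict({'id': ids, 'value': values}),
        metadata=BatchMetadata(ranges=shared_ranges, ranges_complete=complete),
    )


stream = iter([
    chunk([1, 2], [100, 200], complete=False),
    chunk([3, 4], [300, 400], complete=False),
    chunk([5, 6], [500, 600], complete=True),  # state store is checked/updated only here
])

with PostgreSQLLoader(config) as loader:
    # The test below creates the target table from the Arrow schema first; do the same.
    schema = pa.schema([('id', pa.int64()), ('value', pa.int64())])
    loader._create_table_from_schema(schema, 'my_table')
    for result in loader.load_stream_continuous(stream, 'my_table', connection_name='conn'):
        print(result.rows_loaded, result.success)  # expect 2 rows / True per chunk
```

Replaying the completed microbatch (the test does this with a single `ranges_complete=True` chunk) then yields `rows_loaded == 0` with `operation == 'skip_duplicate'`; chunks sent with `ranges_complete=False` are deliberately not deduplicated individually, since exactly-once holds at microbatch granularity, not per RecordBatch.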
123 changes: 123 additions & 0 deletions tests/integration/test_postgresql_loader.py
@@ -692,3 +692,126 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_table_name, cleanup_tables):

finally:
loader.pool.putconn(conn)

def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables):
"""
Test that multiple RecordBatches within the same microbatch are all loaded,
and deduplication only happens at microbatch boundaries when ranges_complete=True.

This test verifies the fix for the critical bug where we were marking batches
as processed after every RecordBatch instead of waiting for ranges_complete=True.
"""
from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch

cleanup_tables.append(test_table_name)

# Enable state management to test deduplication
config_with_state = {
**postgresql_test_config,
'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
}
loader = PostgreSQLLoader(config_with_state)

with loader:
# Create table first from the schema
batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
loader._create_table_from_schema(batch1_data.schema, test_table_name)

# Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
# This happens when the server sends large microbatches in smaller chunks

# First RecordBatch of the microbatch (ranges_complete=False)
response1 = ResponseBatch.data_batch(
data=batch1_data,
metadata=BatchMetadata(
ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
ranges_complete=False, # Not the last batch in this microbatch
),
)

# Second RecordBatch of the microbatch (ranges_complete=False)
batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
response2 = ResponseBatch.data_batch(
data=batch2_data,
metadata=BatchMetadata(
ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange!
ranges_complete=False, # Still not the last batch
),
)

# Third RecordBatch of the microbatch (ranges_complete=True)
batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
response3 = ResponseBatch.data_batch(
data=batch3_data,
metadata=BatchMetadata(
ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')], # Same BlockRange!
ranges_complete=True, # Last batch in this microbatch - safe to mark as processed
),
)

# Process the microbatch stream
stream = [response1, response2, response3]
results = list(
loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
)

# CRITICAL: All 3 RecordBatches should be loaded successfully
# Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
assert all(r.success for r in results), 'All batches should succeed'
assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'

# Verify total rows in table (all batches loaded)
conn = loader.pool.getconn()
try:
with conn.cursor() as cur:
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
total_count = cur.fetchone()[0]
assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'

# Verify the actual IDs are present
cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
all_ids = [row[0] for row in cur.fetchall()]
assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'

finally:
loader.pool.putconn(conn)

# Now test that re-sending the complete microbatch is properly deduplicated
# This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
duplicate_response = ResponseBatch.data_batch(
data=duplicate_batch,
metadata=BatchMetadata(
ranges=[
BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
], # Same range as before!
ranges_complete=True, # Complete microbatch
),
)

# Process duplicate microbatch
duplicate_results = list(
loader.load_stream_continuous(
iter([duplicate_response]), test_table_name, connection_name='test_connection'
)
)

# The duplicate microbatch should be skipped (already processed)
assert len(duplicate_results) == 1
assert duplicate_results[0].success is True
assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate'

# Verify row count unchanged (duplicate was skipped)
conn = loader.pool.getconn()
try:
with conn.cursor() as cur:
cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
final_count = cur.fetchone()[0]
assert final_count == 6, 'Row count should not increase after duplicate microbatch'

finally:
loader.pool.putconn(conn)