Commit 8da7a7d

Merge pull request #145 from MITLibraries/TIMX-497-filtering-current-records
TIMX 497 - filtering current records
2 parents 00b8d2a + 8d448db commit 8da7a7d

4 files changed: +162 −67 lines changed

tests/test_dataset.py

Lines changed: 67 additions & 1 deletion
@@ -1,4 +1,5 @@
-# ruff: noqa: S105, S106, SLF001, PLR2004
+# ruff: noqa: D205, S105, S106, SLF001, PD901, PLR2004
+
 import os
 from datetime import date
 from unittest.mock import MagicMock, patch
@@ -397,3 +398,68 @@ def test_dataset_all_read_methods_get_deduplication(
     transformed_records = list(local_dataset_with_runs.read_transformed_records_iter())
 
     assert len(full_df) == len(all_records) == len(transformed_records)
+
+
+def test_dataset_current_records_no_additional_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe()
+    assert df.action.value_counts().to_dict() == {"index": 99, "delete": 1}
+
+
+def test_dataset_current_records_action_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe(action="index")
+    assert df.action.value_counts().to_dict() == {"index": 99}
+
+
+def test_dataset_current_records_index_filtering_accurate_records_yielded(
+    local_dataset_with_runs,
+):
+    """This is a somewhat complex test, but it demonstrates that only 'current' records
+    are yielded when .load(current_records=True) is applied.
+
+    Given these runs from the fixture:
+    [
+        ...
+        (25, "alma", "2025-01-03", "daily", "index", "run-5"),   <---- filtered to
+        (10, "alma", "2025-01-04", "daily", "delete", "run-6"),  <---- influences current
+        ...
+    ]
+
+    Though we are filtering to run-5, which has 25 total records to index, we see only
+    15 records yielded.  Why?  While we filtered to only yield from run-5, run-6 had 10
+    deletes which made records alma:0 through alma:9 no longer "current" in run-5.
+    Because records are yielded reverse chronologically, the deletes from run-6
+    (alma:0-alma:9) "influence" what records we see as we continue backwards in time.
+    """
+    # with current_records=False, we get all 25 records from run-5
+    local_dataset_with_runs.load(current_records=False, source="alma")
+    df = local_dataset_with_runs.read_dataframe(run_id="run-5")
+    assert len(df) == 25
+
+    # with current_records=True, we only get 15 records from run-5
+    # because newer run-6 influenced what records are current for older run-5
+    local_dataset_with_runs.load(current_records=True, source="alma")
+    df = local_dataset_with_runs.read_dataframe(run_id="run-5")
+    assert len(df) == 15
+    assert list(df.timdex_record_id) == [
+        "alma:10",
+        "alma:11",
+        "alma:12",
+        "alma:13",
+        "alma:14",
+        "alma:15",
+        "alma:16",
+        "alma:17",
+        "alma:18",
+        "alma:19",
+        "alma:20",
+        "alma:21",
+        "alma:22",
+        "alma:23",
+        "alma:24",
+    ]
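
Note: as a rough illustration of the "current records" semantics these tests assert (not part of this commit), the same result can be approximated with plain pandas by keeping only the newest row per timdex_record_id. The miniature DataFrame below is invented; it assumes the timdex_record_id, action, run_date, and run_id columns used by the fixture.

import pandas as pd

# hypothetical miniature of the fixture: run-5 indexes alma:0 through alma:4,
# then run-6 deletes alma:0 and alma:1 a day later
df = pd.DataFrame(
    {
        "timdex_record_id": [f"alma:{i}" for i in range(5)] + ["alma:0", "alma:1"],
        "action": ["index"] * 5 + ["delete"] * 2,
        "run_date": ["2025-01-03"] * 5 + ["2025-01-04"] * 2,
        "run_id": ["run-5"] * 5 + ["run-6"] * 2,
    }
)

# "current" records = newest version of each timdex_record_id
current = df.sort_values("run_date", ascending=False).drop_duplicates(
    "timdex_record_id", keep="first"
)

# filtering current records to run-5 drops alma:0 and alma:1, mirroring
# read_dataframe(run_id="run-5") after load(current_records=True)
print(sorted(current[current.run_id == "run-5"].timdex_record_id))  # ['alma:2', 'alma:3', 'alma:4']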

tests/test_runs.py

Lines changed: 5 additions & 3 deletions
@@ -12,19 +12,21 @@
 @pytest.fixture
 def timdex_run_manager(dataset_with_runs_location):
     timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
-    return TIMDEXRunManager(timdex_dataset=timdex_dataset)
+    timdex_dataset.load()
+    return TIMDEXRunManager(dataset=timdex_dataset.dataset)
 
 
 def test_timdex_run_manager_init(dataset_with_runs_location):
     timdex_dataset = TIMDEXDataset(dataset_with_runs_location)
-    timdex_run_manager = TIMDEXRunManager(timdex_dataset=timdex_dataset)
+    timdex_dataset.load()
+    timdex_run_manager = TIMDEXRunManager(dataset=timdex_dataset.dataset)
     assert timdex_run_manager._runs_metadata_cache is None
 
 
 def test_timdex_run_manager_parse_single_parquet_file_success(timdex_run_manager):
     """Parse run metadata from first parquet file in fixture dataset. We know the details
     of this ETL run in advance given the deterministic fixture that generated it."""
-    parquet_filepath = timdex_run_manager.timdex_dataset.dataset.files[0]
+    parquet_filepath = timdex_run_manager.dataset.files[0]
     run_metadata = timdex_run_manager._parse_run_metadata_from_parquet_file(
         parquet_filepath
     )
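
Note: the fixture changes above follow the new TIMDEXRunManager constructor, which accepts a pyarrow dataset rather than a TIMDEXDataset and no longer calls .load() itself. A minimal sketch of the new wiring (the dataset location is hypothetical, and the import paths are assumed from the module layout shown in this diff):

from timdex_dataset_api.dataset import TIMDEXDataset
from timdex_dataset_api.run import TIMDEXRunManager

timdex_dataset = TIMDEXDataset("path/to/timdex-dataset")
timdex_dataset.load()  # callers must load first; the manager no longer does this
run_manager = TIMDEXRunManager(dataset=timdex_dataset.dataset)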

timdex_dataset_api/dataset.py

Lines changed: 85 additions & 52 deletions
@@ -120,7 +120,9 @@ def __init__(
         self.schema = TIMDEX_DATASET_SCHEMA
         self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
         self._written_files: list[ds.WrittenFile] = None  # type: ignore[assignment]
-        self._dedupe_on_read: bool = False
+
+        self._current_records: bool = False
+        self._current_records_dataset: ds.Dataset = None  # type: ignore[assignment]
 
     @property
     def row_count(self) -> int:
@@ -153,27 +155,32 @@ def load(
             - filters: kwargs typed via DatasetFilters TypedDict
                 - Filters passed directly in method call, e.g. source="alma",
                 run_date="2024-12-20", etc., but are typed according to DatasetFilters.
+            - current_records: bool
+                - if True, the TIMDEXRunManager will be used to retrieve a list of
+                parquet files associated with current runs, and some internal flags will
+                be set, ensuring that only current records are yielded by any read methods
         """
         start_time = time.perf_counter()
 
         # reset paths from original location before load
         _, self.paths = self.parse_location(self.location)
 
         # perform initial load of full dataset
-        self._load_pyarrow_dataset()
+        self.dataset = self._load_pyarrow_dataset()
 
-        # if current_records flag set, limit to parquet files associated with current runs
-        self._dedupe_on_read = current_records
+        self._current_records = current_records
         if current_records:
-            timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
 
-            # update paths, limiting by source if set
+            timdex_run_manager = TIMDEXRunManager(dataset=self.dataset)
             self.paths = timdex_run_manager.get_current_parquet_files(
                 source=filters.get("source")
             )
 
-            # reload pyarrow dataset
-            self._load_pyarrow_dataset()
+            # reload pyarrow dataset, now limited to an explicit list of parquet files;
+            # also save an instance of the dataset before any additional filtering
+            dataset = self._load_pyarrow_dataset()
+            self.dataset = dataset
+            self._current_records_dataset = dataset
 
         # filter dataset
         self.dataset = self._get_filtered_dataset(**filters)
@@ -183,9 +190,9 @@ def load(
             f"{round(time.perf_counter()-start_time, 2)}s"
         )
 
-    def _load_pyarrow_dataset(self) -> None:
+    def _load_pyarrow_dataset(self) -> ds.Dataset:
         """Load the pyarrow dataset per local filesystem and paths attributes."""
-        self.dataset = ds.dataset(
+        return ds.dataset(
             self.paths,
             schema=self.schema,
             format="parquet",
@@ -449,19 +456,14 @@ def read_batches_iter(
         """Yield pyarrow.RecordBatches from the dataset.
 
         While batch_size will limit the max rows per batch, filtering may result in some
-        batches have less than this limit.
+        batches having less than this limit.
+
+        If the flag self._current_records is set, this method leans on
+        self._yield_current_record_deduped_batches() to apply deduplication of records to
+        ensure only current versions of the record are ever yielded.
 
         Args:
             - columns: list[str], list of columns to return from the dataset
-            - batch_size: int, max number of rows to yield per batch
-            - batch_read_ahead: int, the number of batches to read ahead in a file. This
-                might not work for all file formats. Increasing this number will increase
-                RAM usage but could also improve IO utilization. Pyarrow default is 16,
-                but this library defaults to 0 to prioritize memory footprint.
-            - fragment_read_ahead: int, the number of files to read ahead. Increasing this
-                number will increase RAM usage but could also improve IO utilization.
-                Pyarrow default is 4, but this library defaults to 0 to prioritize memory
-                footprint.
             - filters: pairs of column:value to filter the dataset
         """
         if not self.dataset:
@@ -477,47 +479,78 @@ def read_batches_iter(
             fragment_readahead=self.config.fragment_read_ahead,
         )
 
-        if self._dedupe_on_read:
-            yield from self._yield_deduped_batches(batches)
+        if self._current_records:
+            yield from self._yield_current_record_batches(batches)
         else:
             for batch in batches:
                 if len(batch) > 0:
                     yield batch
 
-    def _yield_deduped_batches(
-        self, batches: Iterator[pa.RecordBatch]
+    def _yield_current_record_batches(
+        self,
+        batches: Iterator[pa.RecordBatch],
     ) -> Iterator[pa.RecordBatch]:
-        """Method to yield record deduped batches.
+        """Method to yield only the most recent version of each record.
+
+        When multiple versions of a record (same timdex_record_id) exist in the dataset,
+        this method ensures only the most recent version is returned. If filtering is
+        applied that removes this most recent version of a record, that timdex_record_id
+        will not be yielded at all.
+
+        This is achieved by iterating over TWO record batch iterators in parallel:
+
+        1. "batches" - the RecordBatch iterator passed to this method, which contains
+        the actual records and columns we are interested in, and may have filtering
+        applied
+
+        2. "unfiltered_batches" - a lightweight RecordBatch iterator that only
+        contains the 'timdex_record_id' column from a pre-filtering dataset saved
+        during .load()
+
+        These two iterators are guaranteed to have the same number of total batches based
+        on how pyarrow.Dataset.to_batches() reads from parquet files. Even if dataset
+        filtering is applied, this does not affect the batch count; you may just end up
+        with smaller or empty batches.
 
-        Extending the normal behavior of yielding batches untouched, this method keeps
-        track of seen timdex_record_id's, yielding them only once. For this method to
-        yield the most current version of a record -- most common usage -- it is required
-        that the batches are pre-ordered so the most recent record version is encountered
-        first.
+        As we move through the record batches we use the unfiltered batches to keep a set
+        of seen timdex_record_ids. Even if a timdex_record_id is not in the record batch
+        -- likely due to filtering -- we will mark that timdex_record_id as "seen" and not
+        yield it from any future batches.
+
+        Args:
+            - batches: batches of records to actually yield from
+            - current_record_id_batches: batches of timdex_record_id's that inform when
+            to yield or skip a record for a batch
         """
+        unfiltered_batches = self._current_records_dataset.to_batches(
+            columns=["timdex_record_id"],
+            batch_size=self.config.read_batch_size,
+            batch_readahead=self.config.batch_read_ahead,
+            fragment_readahead=self.config.fragment_read_ahead,
+        )
+
         seen_records = set()
-        for batch in batches:
-            if len(batch) > 0:
-                # init list of batch indices for records unseen
-                unseen_batch_indices = []
-
-                # get list of timdex ids from batch
-                timdex_ids = batch.column("timdex_record_id").to_pylist()
-
-                # check each record id and track unseen ones
-                for i, record_id in enumerate(timdex_ids):
-                    if record_id not in seen_records:
-                        unseen_batch_indices.append(i)
-                        seen_records.add(record_id)
-
-                # if all records from batch were seen, continue
-                if not unseen_batch_indices:
-                    continue
-
-                # else, yield unseen records from batch
-                deduped_batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
-                if len(deduped_batch) > 0:
-                    yield deduped_batch
+        for unfiltered_batch, batch in zip(unfiltered_batches, batches, strict=True):
+            # init list of indices from the batch for records we have never yielded
+            unseen_batch_indices = []
+
+            # check each record id and track unseen ones
+            for i, record_id in enumerate(batch.column("timdex_record_id").to_pylist()):
+                if record_id not in seen_records:
+                    unseen_batch_indices.append(i)
+
+            # even if not a record to yield, update our list of seen records from all
+            # records in the unfiltered batch
+            seen_records.update(unfiltered_batch.column("timdex_record_id").to_pylist())
+
+            # if no unseen records from this batch, skip yielding entirely
+            if not unseen_batch_indices:
+                continue
+
+            # create a new RecordBatch using the unseen indices of the batch
+            _batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
+            if len(_batch) > 0:
+                yield _batch
 
     def read_dataframes_iter(
         self,
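
Note: to make the two-iterator deduplication described in _yield_current_record_batches concrete, here is a small standalone sketch (not from this commit) of the same pattern against hand-built pyarrow RecordBatches. The batch contents are invented; the point is that ids observed in the unfiltered stream are marked as seen even when filtering dropped them from the main stream, so an older version of a record appearing in a later batch is never yielded.

import pyarrow as pa

# two aligned batch streams: "unfiltered" carries every timdex_record_id,
# "filtered" is what a filtered read returned for the same batch positions
unfiltered = [
    pa.RecordBatch.from_pydict({"timdex_record_id": ["alma:1", "alma:2"]}),
    pa.RecordBatch.from_pydict({"timdex_record_id": ["alma:1", "alma:3"]}),
]
filtered = [
    pa.RecordBatch.from_pydict({"timdex_record_id": ["alma:2"]}),
    pa.RecordBatch.from_pydict({"timdex_record_id": ["alma:1", "alma:3"]}),
]

seen: set[str] = set()
for unfiltered_batch, batch in zip(unfiltered, filtered, strict=True):
    # indices in this batch for ids never seen in any earlier batch
    unseen = [
        i
        for i, record_id in enumerate(batch.column("timdex_record_id").to_pylist())
        if record_id not in seen
    ]
    # mark every id from the unfiltered batch as seen, even ids filtered out above
    seen.update(unfiltered_batch.column("timdex_record_id").to_pylist())
    if unseen:
        print(batch.take(pa.array(unseen)).column("timdex_record_id").to_pylist())

# prints ['alma:2'] then ['alma:3']: the older alma:1 in the second batch is skipped
# because alma:1 already appeared in the first unfiltered batch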

timdex_dataset_api/run.py

Lines changed: 5 additions & 11 deletions
@@ -3,25 +3,19 @@
 import concurrent.futures
 import logging
 import time
-from typing import TYPE_CHECKING
 
 import pandas as pd
+import pyarrow.dataset as ds
 import pyarrow.parquet as pq
 
-if TYPE_CHECKING:
-    from timdex_dataset_api.dataset import TIMDEXDataset
-
 logger = logging.getLogger(__name__)
 
 
 class TIMDEXRunManager:
     """Manages and provides access to ETL run metadata from the TIMDEX parquet dataset."""
 
-    def __init__(self, timdex_dataset: "TIMDEXDataset"):
-        self.timdex_dataset: TIMDEXDataset = timdex_dataset
-        if self.timdex_dataset.dataset is None:
-            self.timdex_dataset.load()
-
+    def __init__(self, dataset: ds.Dataset):
+        self.dataset = dataset
         self._runs_metadata_cache: pd.DataFrame | None = None
 
     def clear_cache(self) -> None:
@@ -143,7 +137,7 @@ def _get_parquet_files_run_metadata(self, max_workers: int = 250) -> pd.DataFram
         """
         with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
             futures = []
-            for parquet_filepath in self.timdex_dataset.dataset.files:  # type: ignore[attr-defined]
+            for parquet_filepath in self.dataset.files:  # type: ignore[attr-defined]
                 future = executor.submit(
                     self._parse_run_metadata_from_parquet_file,
                     parquet_filepath,
@@ -181,7 +175,7 @@ def _parse_run_metadata_from_parquet_file(self, parquet_filepath: str) -> dict:
         """
         parquet_file = pq.ParquetFile(
             parquet_filepath,
-            filesystem=self.timdex_dataset.filesystem,
+            filesystem=self.dataset.filesystem,  # type: ignore[attr-defined]
         )
 
         file_meta = parquet_file.metadata.to_dict()
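
Note: since TIMDEXRunManager now takes a pyarrow dataset directly, it can also be driven without a TIMDEXDataset. A rough sketch (the dataset path is hypothetical, the import path is assumed from the module layout shown here, and get_current_parquet_files is the method load() calls above):

import pyarrow.dataset as ds

from timdex_dataset_api.run import TIMDEXRunManager

# any filesystem-backed ds.Dataset works, since run metadata is read from
# each parquet file's footer via concurrent ParquetFile reads
dataset = ds.dataset("path/to/timdex-dataset", format="parquet")
run_manager = TIMDEXRunManager(dataset=dataset)

# parquet files belonging to "current" runs, optionally limited to one source
current_files = run_manager.get_current_parquet_files(source="alma")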
