@@ -120,7 +120,9 @@ def __init__(
         self.schema = TIMDEX_DATASET_SCHEMA
         self.partition_columns = TIMDEX_DATASET_PARTITION_COLUMNS
         self._written_files: list[ds.WrittenFile] = None  # type: ignore[assignment]
-        self._dedupe_on_read: bool = False
+
+        self._current_records: bool = False
+        self._current_records_dataset: ds.Dataset = None  # type: ignore[assignment]
 
     @property
     def row_count(self) -> int:
@@ -153,27 +155,32 @@ def load(
         - filters: kwargs typed via DatasetFilters TypedDict
             - Filters passed directly in method call, e.g. source="alma",
             run_date="2024-12-20", etc., but are typed according to DatasetFilters.
+        - current_records: bool
+            - if True, the TIMDEXRunManager will be used to retrieve a list of
+            parquet files associated with current runs and internal flags will be
+            set, ensuring that only current records are yielded by any read methods
         """
         start_time = time.perf_counter()
 
         # reset paths from original location before load
         _, self.paths = self.parse_location(self.location)
 
         # perform initial load of full dataset
-        self._load_pyarrow_dataset()
+        self.dataset = self._load_pyarrow_dataset()
 
-        # if current_records flag set, limit to parquet files associated with current runs
-        self._dedupe_on_read = current_records
+        self._current_records = current_records
         if current_records:
-            timdex_run_manager = TIMDEXRunManager(timdex_dataset=self)
 
-            # update paths, limiting by source if set
+            timdex_run_manager = TIMDEXRunManager(dataset=self.dataset)
             self.paths = timdex_run_manager.get_current_parquet_files(
                 source=filters.get("source")
             )
 
-            # reload pyarrow dataset
-            self._load_pyarrow_dataset()
+            # reload pyarrow dataset, filtered now to an explicit list of parquet files
+            # also save an instance of the dataset before any additional filtering
+            dataset = self._load_pyarrow_dataset()
+            self.dataset = dataset
+            self._current_records_dataset = dataset
 
         # filter dataset
         self.dataset = self._get_filtered_dataset(**filters)
@@ -183,9 +190,9 @@ def load(
183190 f"{ round (time .perf_counter ()- start_time , 2 )} s"
184191 )
185192
186- def _load_pyarrow_dataset (self ) -> None :
193+ def _load_pyarrow_dataset (self ) -> ds . Dataset :
187194 """Load the pyarrow dataset per local filesystem and paths attributes."""
188- self . dataset = ds .dataset (
195+ return ds .dataset (
189196 self .paths ,
190197 schema = self .schema ,
191198 format = "parquet" ,
@@ -449,19 +456,14 @@ def read_batches_iter(
449456 """Yield pyarrow.RecordBatches from the dataset.
450457
451458 While batch_size will limit the max rows per batch, filtering may result in some
452- batches have less than this limit.
459+ batches having less than this limit.
460+
+        If the flag self._current_records is set, this method leans on
+        self._yield_current_record_batches() to deduplicate records and ensure that
+        only the current version of each record is ever yielded.
 
         Args:
             - columns: list[str], list of columns to return from the dataset
-            - batch_size: int, max number of rows to yield per batch
-            - batch_read_ahead: int, the number of batches to read ahead in a file. This
-                might not work for all file formats. Increasing this number will increase
-                RAM usage but could also improve IO utilization. Pyarrow default is 16,
-                but this library defaults to 0 to prioritize memory footprint.
-            - fragment_read_ahead: int, The number of files to read ahead. Increasing this
-                number will increase RAM usage but could also improve IO utilization.
-                Pyarrow default is 4, but this library defaults to 0 to prioritize memory
-                footprint.
             - filters: pairs of column:value to filter the dataset
         """
         if not self.dataset:
@@ -477,47 +479,78 @@ def read_batches_iter(
             fragment_readahead=self.config.fragment_read_ahead,
         )
 
-        if self._dedupe_on_read:
-            yield from self._yield_deduped_batches(batches)
+        if self._current_records:
+            yield from self._yield_current_record_batches(batches)
         else:
             for batch in batches:
                 if len(batch) > 0:
                     yield batch
 
-    def _yield_deduped_batches(
-        self, batches: Iterator[pa.RecordBatch]
+    def _yield_current_record_batches(
+        self,
+        batches: Iterator[pa.RecordBatch],
     ) -> Iterator[pa.RecordBatch]:
-        """Method to yield record deduped batches.
493+ """Method to yield only the most recent version of each record.
494+
495+ When multiple versions of a record (same timdex_record_id) exist in the dataset,
496+ this method ensures only the most recent version is returned. If filtering is
497+ applied that removes this most recent version of a record, that timdex_record_id
498+ will not be yielded at all.
499+
500+ This is achieved by iterating over TWO record batch iterators in parallel:
501+
502+ 1. "batches" - the RecordBatch iterator passed to this method which
503+ contains the actual records and columns we are interested in, and may contain
504+ filtering
505+
506+ 2. "unfiltered_batches" - a lightweight RecordBatch iterator that only
507+ contains the 'timdex_record_id' column from a pre-filtering dataset saved
508+ during .load()
509+
510+ These two iterators are guaranteed to have the same number of total batches based
511+ on how pyarrow.Dataset.to_batches() reads from parquet files. Even if dataset
512+ filtering is applied, this does not affect the batch count; you may just end up
513+ with smaller or empty batches.
491514
492- Extending the normal behavior of yielding batches untouched, this method keeps
493- track of seen timdex_record_id's, yielding them only once. For this method to
494- yield the most current version of a record -- most common usage -- it is required
495- that the batches are pre-ordered so the most recent record version is encountered
496- first.
+        As we move through the record batches we use the unfiltered batches to keep
+        a set of seen timdex_record_ids. Even if a timdex_record_id is not in the
+        record batch -- likely due to filtering -- we will mark that timdex_record_id
+        as "seen" and not yield it from any future batches.
+
+        Args:
+            - batches: batches of records to actually yield from; the unfiltered
+            batches of timdex_record_id's that inform when to yield or skip a record
+            are built inside this method from the dataset saved during .load()
         """
+        unfiltered_batches = self._current_records_dataset.to_batches(
+            columns=["timdex_record_id"],
+            batch_size=self.config.read_batch_size,
+            batch_readahead=self.config.batch_read_ahead,
+            fragment_readahead=self.config.fragment_read_ahead,
+        )
+
         seen_records = set()
-        for batch in batches:
-            if len(batch) > 0:
-                # init list of batch indices for records unseen
-                unseen_batch_indices = []
-
-                # get list of timdex ids from batch
-                timdex_ids = batch.column("timdex_record_id").to_pylist()
-
-                # check each record id and track unseen ones
-                for i, record_id in enumerate(timdex_ids):
-                    if record_id not in seen_records:
-                        unseen_batch_indices.append(i)
-                        seen_records.add(record_id)
-
-                # if all records from batch were seen, continue
-                if not unseen_batch_indices:
-                    continue
-
-                # else, yield unseen records from batch
-                deduped_batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
-                if len(deduped_batch) > 0:
-                    yield deduped_batch
+        for unfiltered_batch, batch in zip(unfiltered_batches, batches, strict=True):
+            # init list of indices from the batch for records we have never yielded
+            unseen_batch_indices = []
+
+            # check each record id and track unseen ones
+            for i, record_id in enumerate(batch.column("timdex_record_id").to_pylist()):
+                if record_id not in seen_records:
+                    unseen_batch_indices.append(i)
+
+            # even if not a record to yield, update our list of seen records from all
+            # records in the unfiltered batch
+            seen_records.update(unfiltered_batch.column("timdex_record_id").to_pylist())
+
+            # if no unseen records from this batch, skip yielding entirely
+            if not unseen_batch_indices:
+                continue
+
+            # create a new RecordBatch using the unseen indices of the batch
+            _batch = batch.take(pa.array(unseen_batch_indices))  # type: ignore[arg-type]
+            if len(_batch) > 0:
+                yield _batch
 
     def read_dataframes_iter(
         self,
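
For orientation, a minimal usage sketch of the current-records read path shown in this diff. The import path, class name TIMDEXDataset, constructor argument, and location string are assumptions not confirmed here; source="alma" mirrors the filter example from the load() docstring.

from timdex_dataset_api import TIMDEXDataset  # assumed package/class name

td = TIMDEXDataset("s3://example-bucket/dataset")  # hypothetical dataset location

# load() with current_records=True restricts paths to parquet files from current
# runs (via TIMDEXRunManager) and saves an unfiltered copy of that dataset
td.load(current_records=True, source="alma")

# read_batches_iter() then yields each timdex_record_id at most once, consulting
# the unfiltered dataset to mark ids as "seen" even when filtering skips them
for batch in td.read_batches_iter(columns=["timdex_record_id", "source"]):
    print(batch.num_rows)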