From efd84d66430314be556efbd11cd746dc95800520 Mon Sep 17 00:00:00 2001
From: Till Hoffmann
Date: Tue, 6 Jan 2026 10:51:23 -0500
Subject: [PATCH] Propagate close() from DatasetIterator to data sources

This change ensures that when `DatasetIterator.close()` is called, the
call propagates all the way down to the underlying data source, so that
resources such as file handles, database connections, and thread-local
state are released.

Changes:
- Add `close()` to `MapDataset` that propagates to parent datasets
- Add `close()` to `SourceMapDataset` that calls the data source's
  `close()` method if available
- Update `PrefetchDatasetIterator.close()` to call `_map_parent.close()`
- Update `GetElementProducerFn.__call__` to call `it.close()` in a
  finally block for multiprocessing cleanup
- Add tests for close propagation
---
 grain/_src/python/dataset/dataset.py          |  8 ++
 .../dataset/transformations/prefetch.py       | 42 +++++++----
 .../dataset/transformations/prefetch_test.py  | 89 +++++++++++++++++++
 .../python/dataset/transformations/source.py  |  5 ++
 4 files changed, 128 insertions(+), 16 deletions(-)

diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index ba04c95bf..a76d51185 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -907,6 +907,14 @@ def _stats(self) -> dataset_stats.Stats:
     """Returns the Stats object for recording statistics about this dataset."""
     return self._initialize_stats(base.ExecutionTrackingMode.DISABLED)
 
+  def close(self) -> None:
+    """Closes the dataset, releasing any resources.
+
+    Recursively closes all parent datasets. Safe to call multiple times.
+    """
+    for parent in self._parents:
+      parent.close()
+
 
 # pytype: enable=attribute-error
 # pylint: enable=protected-access
diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py
index e7a6392fa..6e2bc90fa 100644
--- a/grain/_src/python/dataset/transformations/prefetch.py
+++ b/grain/_src/python/dataset/transformations/prefetch.py
@@ -319,13 +319,20 @@ def close(self) -> None:
     if self._closed:
       return
     self._closed = True
-    # Shutdown the thread pool executor if it exists.
     if hasattr(self, "_executor"):
+      # Submit close to each worker thread for thread-local resource cleanup.
+      cleanup_futures = [
+          self._executor.submit(self._map_parent.close)
+          for _ in range(self._num_threads)
+      ]
+      futures.wait(cleanup_futures, timeout=5.0)
       self._executor.shutdown(wait=False)
     # Cancel all pending futures in the buffer.
     while self._buffer:
       future = self._buffer.popleft()
       future.cancel()
+    # Also close from the main thread for any main-thread resources.
+    self._map_parent.close()
 
 
 def _iterator_with_context(
@@ -590,21 +597,24 @@ def __call__(
     for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]):
      _ = next(it)
     last_recorded_state_time = time.time()
-    for element in it:
-      now = time.time()
-      element = _copy_struct_to_shm(element, min_size=min_shm_size)
-      # If the node is prefetch, we already record the bytes produced in it's
-      # __next__ method.
-      if not it._stats._config.is_prefetch:
-        it._stats.record_bytes_produced(element)
-      if (
-          self._always_report_worker_state
-          or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
-      ):
-        last_recorded_state_time = now
-        yield (element, it.get_state())  # pytype: disable=attribute-error
-      else:
-        yield (element, None)
+    try:
+      for element in it:
+        now = time.time()
+        element = _copy_struct_to_shm(element, min_size=min_shm_size)
+        # If the node is prefetch, we already record the bytes produced in its
+        # __next__ method.
+        if not it._stats._config.is_prefetch:
+          it._stats.record_bytes_produced(element)
+        if (
+            self._always_report_worker_state
+            or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
+        ):
+          last_recorded_state_time = now
+          yield (element, it.get_state())  # pytype: disable=attribute-error
+        else:
+          yield (element, None)
+    finally:
+      it.close()
 
   def serialize(self) -> bytes:
     """Overrides the default implementation to generate better error messages."""
diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py
index 3e45a8ad6..0f48620e3 100644
--- a/grain/_src/python/dataset/transformations/prefetch_test.py
+++ b/grain/_src/python/dataset/transformations/prefetch_test.py
@@ -1426,5 +1426,94 @@ def test_mp_context_is_set_correctly(self):
     self.assertEqual(context.process_count, num_workers)
 
 
+class _CloseTrackingDataSource:
+  """A data source that tracks close() calls and which threads called it."""
+
+  def __init__(self, data):
+    self._data = data
+    self.close_threads = []
+    self._lock = threading.Lock()
+
+  def __len__(self):
+    return len(self._data)
+
+  def __getitem__(self, index):
+    return self._data[index]
+
+  def close(self):
+    with self._lock:
+      self.close_threads.append(threading.current_thread().ident)
+
+
+class DataSourceCloseTest(parameterized.TestCase):
+  """Tests for data source close() propagation."""
+
+  def test_close_propagates_to_data_source(self):
+    source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
+    ds = dataset.MapDataset.source(source)
+    num_threads = 2
+    read_options = options.ReadOptions(
+        num_threads=num_threads, prefetch_buffer_size=2
+    )
+    iter_ds = ds.to_iter_dataset(read_options=read_options)
+    it = iter(iter_ds)
+
+    _ = [next(it) for _ in range(3)]
+    self.assertEmpty(source.close_threads)
+
+    it.close()
+    # Called from each worker thread + main thread.
+    self.assertLen(source.close_threads, num_threads + 1)
+
+  def test_close_called_from_each_worker_thread(self):
+    source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
+    ds = dataset.MapDataset.source(source)
+    num_threads = 4
+    read_options = options.ReadOptions(
+        num_threads=num_threads, prefetch_buffer_size=4
+    )
+    iter_ds = ds.to_iter_dataset(read_options=read_options)
+    it = iter(iter_ds)
+
+    _ = next(it)
+    it.close()
+
+    # Called from each worker thread + main thread.
+    self.assertLen(source.close_threads, num_threads + 1)
+
+  def test_close_without_prefetch(self):
+    source = _CloseTrackingDataSource([1, 2, 3])
+    ds = dataset.MapDataset.source(source)
+    read_options = options.ReadOptions(prefetch_buffer_size=0)
+    iter_ds = ds.to_iter_dataset(read_options=read_options)
+    it = iter(iter_ds)
+
+    _ = next(it)
+    it.close()
+    # No executor, so close() is called once from the main thread.
+    self.assertLen(source.close_threads, 1)
+
+  def test_iterator_close_is_idempotent(self):
+    source = _CloseTrackingDataSource([1, 2, 3])
+    ds = dataset.MapDataset.source(source)
+    read_options = options.ReadOptions(prefetch_buffer_size=0)
+    iter_ds = ds.to_iter_dataset(read_options=read_options)
+    it = iter(iter_ds)
+
+    _ = next(it)
+    it.close()
+    it.close()
+    # Iterator close is idempotent.
+    self.assertLen(source.close_threads, 1)
+
+  def test_map_dataset_close_propagates(self):
+    source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
+    ds = dataset.MapDataset.source(source).map(lambda x: x * 2).batch(2)
+
+    self.assertEmpty(source.close_threads)
+    ds.close()
+    self.assertLen(source.close_threads, 1)
+
+
 if __name__ == '__main__':
   absltest.main()
diff --git a/grain/_src/python/dataset/transformations/source.py b/grain/_src/python/dataset/transformations/source.py
index f7da76bd3..9e42c37a8 100644
--- a/grain/_src/python/dataset/transformations/source.py
+++ b/grain/_src/python/dataset/transformations/source.py
@@ -138,6 +138,11 @@ def paths(self) -> str | Sequence[str]:
     else:
       return []
 
+  def close(self) -> None:
+    """Closes the underlying data source if it has a close method."""
+    if hasattr(self._source, "close"):
+      self._source.close()
+
 
 def log_lineage_for_sources(
     root: Union[dataset.MapDataset, dataset.IterDataset],
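
--
Usage sketch, not part of the patch: with this change, a user-supplied
data source can release its resources when iteration stops. The sketch
uses only the public API exercised by the tests above (`MapDataset.source`,
`to_iter_dataset`, `ReadOptions`, via the `import grain.python as grain`
alias); the `TrackingSource` class is a hypothetical illustration, not a
Grain API.

    import grain.python as grain

    class TrackingSource:
      """Toy random-access source whose close() releases a pretend handle."""

      def __init__(self, data):
        self._data = data
        self.closed = False

      def __len__(self):
        return len(self._data)

      def __getitem__(self, index):
        return self._data[index]

      def close(self):
        # With this patch applied, DatasetIterator.close() reaches this method.
        self.closed = True

    source = TrackingSource(list(range(10)))
    ds = grain.MapDataset.source(source).map(lambda x: x * 2)
    it = iter(ds.to_iter_dataset(read_options=grain.ReadOptions(num_threads=2)))
    print(next(it))  # 0
    it.close()  # Shuts down prefetch threads, then closes the source.
    assert source.closed

Design note: in PrefetchDatasetIterator.close(), one close() task is
submitted per worker thread before the executor shuts down, which assumes
each idle worker picks up one of those tasks within the 5-second wait; the
final self._map_parent.close() then covers resources owned by the main
thread, which is why the tests expect num_threads + 1 calls.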