8 changes: 8 additions & 0 deletions grain/_src/python/dataset/dataset.py
@@ -907,6 +907,14 @@ def _stats(self) -> dataset_stats.Stats:
"""Returns the Stats object for recording statistics about this dataset."""
return self._initialize_stats(base.ExecutionTrackingMode.DISABLED)

def close(self) -> None:
"""Closes the dataset and releases any resources by recursively closing all
parent datasets in the pipeline. This method is safe to call multiple
times.
"""
for parent in self._parents:
parent.close()

# pytype: enable=attribute-error
# pylint: enable=protected-access

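For context, a minimal sketch of how this recursive close() is expected to behave from user code. It mirrors the test setup later in this PR; ClosableSource is an illustrative stand-in, not a Grain class:

# Sketch only: ClosableSource is a toy source, not part of Grain.
from grain._src.python.dataset import dataset


class ClosableSource:
  """In-memory source that owns a resource to release on close()."""

  def __init__(self, data):
    self._data = data
    self.closed = False

  def __len__(self):
    return len(self._data)

  def __getitem__(self, index):
    return self._data[index]

  def close(self):
    self.closed = True


source = ClosableSource([1, 2, 3])
ds = dataset.MapDataset.source(source).map(lambda x: x + 1)
ds.close()  # Walks _parents recursively until the source dataset is reached.
assert source.closed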
42 changes: 26 additions & 16 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -319,13 +319,20 @@ def close(self) -> None:
if self._closed:
return
self._closed = True
# Shut down the thread pool executor if it exists.
if hasattr(self, "_executor"):
# Submit close to each worker thread for thread-local resource cleanup.
cleanup_futures = [
self._executor.submit(self._map_parent.close)
for _ in range(self._num_threads)
]
futures.wait(cleanup_futures, timeout=5.0)
self._executor.shutdown(wait=False)
# Cancel all pending futures in the buffer.
while self._buffer:
future = self._buffer.popleft()
future.cancel()
# Also close from main thread for any main-thread resources.
self._map_parent.close()


def _iterator_with_context(
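The executor-based cleanup above relies on a general pattern: submitting the same callable once per pool thread is a best-effort way to run thread-local cleanup on every worker. A standalone sketch of that pattern, outside the Grain classes (names here are illustrative):

# Sketch of the per-worker cleanup pattern; best-effort, since N submitted
# tasks are not strictly guaranteed to land on N distinct threads.
import io
import threading
from concurrent import futures

_local = threading.local()


def work(x):
  # Lazily create a thread-local resource on whichever thread runs this.
  if not hasattr(_local, "buf"):
    _local.buf = io.BytesIO()
  return x * 2


def cleanup():
  buf = getattr(_local, "buf", None)
  if buf is not None:
    buf.close()


num_threads = 4
executor = futures.ThreadPoolExecutor(max_workers=num_threads)
results = list(executor.map(work, range(16)))
cleanup_futures = [executor.submit(cleanup) for _ in range(num_threads)]
futures.wait(cleanup_futures, timeout=5.0)
executor.shutdown(wait=False)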
@@ -590,21 +597,24 @@ def __call__(
for _ in range(self._state[_ITERATIONS_TO_SKIP][str(worker_index)]):
_ = next(it)
last_recorded_state_time = time.time()
for element in it:
now = time.time()
element = _copy_struct_to_shm(element, min_size=min_shm_size)
      # If the node is prefetch, we already record the bytes produced in its
      # __next__ method.
if not it._stats._config.is_prefetch:
it._stats.record_bytes_produced(element)
if (
self._always_report_worker_state
or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
):
last_recorded_state_time = now
yield (element, it.get_state()) # pytype: disable=attribute-error
else:
yield (element, None)
try:
for element in it:
now = time.time()
element = _copy_struct_to_shm(element, min_size=min_shm_size)
        # If the node is prefetch, we already record the bytes produced in its
        # __next__ method.
if not it._stats._config.is_prefetch:
it._stats.record_bytes_produced(element)
if (
self._always_report_worker_state
or now - last_recorded_state_time >= _RECORD_STATE_INTERVAL_S
):
last_recorded_state_time = now
yield (element, it.get_state()) # pytype: disable=attribute-error
else:
yield (element, None)
finally:
it.close()

def serialize(self) -> bytes:
"""Overrides the default implementation to generate better error messages."""
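The try/finally wrapper added above is the standard way to guarantee cleanup inside a generator: the finally clause runs whether the loop is exhausted, raises, or is abandoned early (which Python signals with GeneratorExit). A minimal standalone illustration:

# Minimal illustration of generator cleanup via try/finally.
def read_elements(source):
  it = iter(source)
  try:
    for element in it:
      yield element
  finally:
    # Reached on exhaustion, on an exception, and on gen.close().
    print("cleanup ran")


gen = read_elements([1, 2, 3])
next(gen)
gen.close()  # Raises GeneratorExit at the yield; "cleanup ran" is printed.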
89 changes: 89 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
@@ -1426,5 +1426,94 @@ def test_mp_context_is_set_correctly(self):
self.assertEqual(context.process_count, num_workers)


class _CloseTrackingDataSource:
"""A data source that tracks close() calls and which threads called it."""

def __init__(self, data):
self._data = data
self.close_threads = []
self._lock = threading.Lock()

def __len__(self):
return len(self._data)

def __getitem__(self, index):
return self._data[index]

def close(self):
with self._lock:
self.close_threads.append(threading.current_thread().ident)


class DataSourceCloseTest(parameterized.TestCase):
"""Tests for data source close() propagation."""

def test_close_propagates_to_data_source(self):
source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
ds = dataset.MapDataset.source(source)
num_threads = 2
read_options = options.ReadOptions(
num_threads=num_threads, prefetch_buffer_size=2
)
iter_ds = ds.to_iter_dataset(read_options=read_options)
it = iter(iter_ds)

_ = [next(it) for _ in range(3)]
self.assertEmpty(source.close_threads)

it.close()
# Called from each worker thread + main thread.
self.assertLen(source.close_threads, num_threads + 1)

def test_close_called_from_each_worker_thread(self):
source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
ds = dataset.MapDataset.source(source)
num_threads = 4
read_options = options.ReadOptions(
num_threads=num_threads, prefetch_buffer_size=4
)
iter_ds = ds.to_iter_dataset(read_options=read_options)
it = iter(iter_ds)

_ = next(it)
it.close()

# Called from each worker thread + main thread.
self.assertLen(source.close_threads, num_threads + 1)

def test_close_without_prefetch(self):
source = _CloseTrackingDataSource([1, 2, 3])
ds = dataset.MapDataset.source(source)
read_options = options.ReadOptions(prefetch_buffer_size=0)
iter_ds = ds.to_iter_dataset(read_options=read_options)
it = iter(iter_ds)

_ = next(it)
it.close()
# No executor, so close() is called once from the main thread.
self.assertLen(source.close_threads, 1)

def test_iterator_close_is_idempotent(self):
source = _CloseTrackingDataSource([1, 2, 3])
ds = dataset.MapDataset.source(source)
read_options = options.ReadOptions(prefetch_buffer_size=0)
iter_ds = ds.to_iter_dataset(read_options=read_options)
it = iter(iter_ds)

_ = next(it)
it.close()
it.close()
# Iterator close is idempotent.
self.assertLen(source.close_threads, 1)

def test_map_dataset_close_propagates(self):
source = _CloseTrackingDataSource([1, 2, 3, 4, 5])
ds = dataset.MapDataset.source(source).map(lambda x: x * 2).batch(2)

self.assertEmpty(source.close_threads)
ds.close()
self.assertLen(source.close_threads, 1)


if __name__ == '__main__':
absltest.main()
5 changes: 5 additions & 0 deletions grain/_src/python/dataset/transformations/source.py
@@ -138,6 +138,11 @@ def paths(self) -> str | Sequence[str]:
else:
return []

def close(self) -> None:
"""Closes the underlying data source if it has a close method."""
if hasattr(self._source, "close"):
self._source.close()


def log_lineage_for_sources(
root: Union[dataset.MapDataset, dataset.IterDataset],
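Because support is duck-typed via hasattr, any random-access source can opt in by defining close(), and sources without one are simply skipped. A sketch of a file-backed source that would benefit (the class and its record layout are hypothetical):

# Hypothetical file-backed source; the read logic is illustrative only.
class FileBackedSource:

  def __init__(self, path, record_size=8):
    self._file = open(path, "rb")  # Kept open for random access.
    self._record_size = record_size

  def __len__(self):
    self._file.seek(0, 2)  # Seek to end to measure the file.
    return self._file.tell() // self._record_size

  def __getitem__(self, index):
    self._file.seek(index * self._record_size)
    return self._file.read(self._record_size)

  def close(self):
    # Invoked through the dataset's close() chain defined above.
    self._file.close()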