Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions grain/_src/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ py_library(
":options",
":record",
":shared_memory_array",
":variable_size_queue",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:parallel",
Expand All @@ -245,6 +246,7 @@ py_test(
":grain_pool",
":options",
":record",
":variable_size_queue",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"@abseil-py//absl/flags",
Expand Down Expand Up @@ -370,3 +372,19 @@ py_library(
"@pypi//etils:pkg",
],
)

# Queue whose maximum size can be changed at runtime (exposes `set_max_size`,
# per its use in prefetch.py). No in-repo deps, so both grain_pool and the
# dataset transformations can depend on it without cycles.
py_library(
    name = "variable_size_queue",
    srcs = ["variable_size_queue.py"],
    srcs_version = "PY3",
)

# Unit tests for variable_size_queue; needs only the library under test and
# absltest since the queue has no other dependencies.
py_test(
    name = "variable_size_queue_test",
    srcs = ["variable_size_queue_test.py"],
    srcs_version = "PY3",
    deps = [
        ":variable_size_queue",
        "@abseil-py//absl/testing:absltest",
    ],
)
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ py_library(
"//grain/_src/python:grain_pool",
"//grain/_src/python:options",
"//grain/_src/python:shared_memory_array",
"//grain/_src/python:variable_size_queue",
"//grain/proto:execution_summary_py_pb2",
"@abseil-py//absl/logging",
"@pypi//cloudpickle:pkg",
Expand Down
19 changes: 17 additions & 2 deletions grain/_src/python/dataset/transformations/prefetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from grain._src.python import grain_pool
from grain._src.python import options as grain_options
from grain._src.python import shared_memory_array
from grain._src.python import variable_size_queue
from grain._src.python.dataset import base
from grain._src.python.dataset import dataset
from grain._src.python.dataset import stats as dataset_stats
Expand Down Expand Up @@ -760,6 +761,14 @@ def __str__(self) -> str:
f"multiprocessing_options={self._multiprocessing_options})"
)

def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
  """Adjusts the per-worker element buffer size of a live iterator.

  Delegates to the underlying raw iterator, which must already exist (i.e.
  iteration must have started).

  Args:
    per_worker_buffer_size: New maximum number of buffered elements per
      worker. Must be positive.

  Raises:
    ValueError: If `per_worker_buffer_size` is not positive, or if the
      iterator has not been initialized yet.
  """
  # Validate up front, mirroring ThreadPrefetchDatasetIterator's
  # `set_prefetch_buffer_size`, so an invalid size fails loudly here rather
  # than propagating a bogus value into the worker machinery.
  if per_worker_buffer_size <= 0:
    raise ValueError("`per_worker_buffer_size` must be positive.")
  if self._raw_iterator is None:
    raise ValueError(
        "Cannot change per worker buffer size before the iterator has been"
        " initialized."
    )
  self._raw_iterator.set_per_worker_buffer_size(per_worker_buffer_size)


class ThreadPrefetchIterDataset(dataset.IterDataset[T]):
"""Iterable dataset that uses a synchronized queue for prefetching.
Expand Down Expand Up @@ -858,8 +867,8 @@ def __init__(
self._closed = False
self._prefetch_thread: threading.Thread | None = None
self._prefetch_should_stop: threading.Event = threading.Event()
self._buffer: queue.Queue[tuple[T, StateT, Exception | None]] = queue.Queue(
maxsize=self._prefetch_buffer_size
self._buffer: variable_size_queue.VariableSizeQueue = (
variable_size_queue.VariableSizeQueue(self._prefetch_buffer_size)
)

# pytype: disable=attribute-error
Expand Down Expand Up @@ -953,6 +962,12 @@ def _stop_prefetch(self):
# exit.
self._clear_buffer()

def set_prefetch_buffer_size(self, prefetch_buffer_size: int):
  """Updates the maximum number of elements the prefetch buffer may hold.

  Args:
    prefetch_buffer_size: New buffer capacity; must be positive.

  Raises:
    ValueError: If `prefetch_buffer_size` is not positive.
  """
  if not prefetch_buffer_size > 0:
    raise ValueError("`prefetch_buffer_size` must be positive.")
  # Record the new limit and push it down to the queue so the prefetch
  # thread observes the changed capacity.
  self._prefetch_buffer_size = prefetch_buffer_size
  self._buffer.set_max_size(prefetch_buffer_size)

def get_state(self) -> StateT:
return self._step_zero_state if self._state is None else self._state

Expand Down
132 changes: 132 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,9 +917,97 @@ def map_fn(x):
],
)

def test_set_per_worker_buffer_size_increase(self):
  """Growing the per-worker buffer size on a live multiprocess iterator works."""
  ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
  mp_options = options.MultiprocessingOptions(
      num_workers=1, per_worker_buffer_size=1
  )
  ds = prefetch.MultiprocessPrefetchIterDataset(
      ds,
      mp_options,
  )
  it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
  # Pull one element first: the raw iterator only exists after iteration
  # starts, and set_per_worker_buffer_size requires it.
  self.assertEqual(next(it), 1)
  # NOTE(review): fixed sleep presumably lets the background worker settle
  # before internals are inspected — timing-based and potentially flaky.
  time.sleep(1)
  self.assertEqual(
      it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1  # pytype: disable=attribute-error
  )
  it.set_per_worker_buffer_size(2)
  self.assertEqual(
      it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2  # pytype: disable=attribute-error
  )
  # Iteration continues with the correct remaining elements after the resize.
  self.assertEqual(next(it), 2)
  self.assertEqual(list(it), list(range(3, 11)))

def test_set_per_worker_buffer_size_decrease(self):
  """Shrinking the per-worker buffer size on a live multiprocess iterator works."""
  ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
  mp_options = options.MultiprocessingOptions(
      num_workers=1, per_worker_buffer_size=2
  )
  ds = prefetch.MultiprocessPrefetchIterDataset(
      ds,
      mp_options,
  )
  it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
  # Pull one element first so the raw iterator exists before resizing.
  self.assertEqual(next(it), 1)
  # NOTE(review): fixed sleep presumably lets the background worker settle
  # before internals are inspected — timing-based and potentially flaky.
  time.sleep(1)
  self.assertEqual(
      it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2  # pytype: disable=attribute-error
  )
  it.set_per_worker_buffer_size(1)
  self.assertEqual(
      it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1  # pytype: disable=attribute-error
  )
  # Iteration continues with the correct remaining elements after the resize.
  self.assertEqual(next(it), 2)
  self.assertEqual(list(it), list(range(3, 11)))

def test_set_per_worker_buffer_size_to_trigger_error(self):
  """A larger buffer lets the worker run ahead until it hits the raising element."""

  def f(x):
    # Elements 0..4 pass through; 5 raises, ending prefetching.
    if x >= 5:
      raise ValueError(f'x={x} is too large')
    return x

  ds = (
      dataset.MapDataset.range(10)
      .map(f)
      .to_iter_dataset(
          read_options=options.ReadOptions(prefetch_buffer_size=0)
      )
  )
  mp_options = options.MultiprocessingOptions(
      num_workers=1, per_worker_buffer_size=1
  )
  it = prefetch.MultiprocessPrefetchIterDataset(ds, mp_options).__iter__()
  it = cast(prefetch._MultiprocessPrefetchDatasetIterator, it)
  # Start iteration so the raw iterator exists, then widen the buffer so the
  # worker can prefetch far past the consumer's position.
  self.assertEqual(next(it), 0)
  it.set_per_worker_buffer_size(10)
  next(it)
  # NOTE(review): fixed sleep gives the worker time to fill the queue up to
  # the error — timing-based and potentially flaky.
  time.sleep(3)
  q = it._raw_iterator._reader_queue  # pytype: disable=attribute-error
  # Prefetching will end once an error is put into the reader queue. The
  # elements 2, 3, 4 will be in the queue along with the error for 5.
  self.assertEqual(q.qsize(), 4)


class ThreadPrefetchIterDatasetTest(parameterized.TestCase):

def _wait_for_buffer_size(
    self,
    it: prefetch.ThreadPrefetchDatasetIterator,
    size: int,
    timeout_s: float = 5,
):
  """Polls until the iterator's internal buffer holds exactly `size` items.

  Args:
    it: Iterator whose `_buffer` queue is inspected.
    size: Queue size to wait for.
    timeout_s: Maximum number of seconds to poll before giving up.

  Raises:
    TimeoutError: If the buffer does not reach `size` within `timeout_s`.
  """
  # Use a monotonic clock for the deadline: wall-clock time (time.time) can
  # jump under NTP/clock adjustments and corrupt the timeout computation.
  start_time = time.monotonic()
  while it._buffer.qsize() != size:
    if time.monotonic() - start_time > timeout_s:
      raise TimeoutError(
          f'Buffer size {it._buffer.qsize()} did not reach {size} within'
          f' {timeout_s} seconds.'
      )
    time.sleep(0.01)
  # Re-check after the loop so a race that changes the buffer right at loop
  # exit is reported as a failure rather than silently passing.
  self.assertEqual(it._buffer.qsize(), size)

def setUp(self):
super().setUp()
self.ds = (
Expand Down Expand Up @@ -1148,6 +1236,50 @@ def test_no_mem_leak_with_double_prefetch(self, close: bool):
if close:
it.close() # pytype: disable=attribute-error

def test_set_prefetch_buffer_size_increase(self):
  """Growing the buffer lets the prefetch thread hold more elements."""
  ds = prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=1)
  it = ds.__iter__()
  self.assertIsInstance(it, prefetch.ThreadPrefetchDatasetIterator)
  it = cast(prefetch.ThreadPrefetchDatasetIterator, it)

  self.assertEqual(it._prefetch_buffer_size, 1)
  # Consume one element; with capacity 1 the thread then buffers exactly one.
  self.assertEqual(next(it), 1)
  self._wait_for_buffer_size(it, 1)

  it.set_prefetch_buffer_size(2)
  self.assertEqual(it._prefetch_buffer_size, 2)
  # After the resize the thread fills the buffer up to the new capacity of 2.
  self.assertEqual(next(it), 3)
  self._wait_for_buffer_size(it, 2)
  self.assertEqual(next(it), 5)
  self._wait_for_buffer_size(it, 2)

def test_set_prefetch_buffer_size_decrease(self):
  """Shrinking the buffer drains already-buffered elements before refilling."""
  ds = prefetch.ThreadPrefetchIterDataset(self.ds, prefetch_buffer_size=2)
  it = ds.__iter__()
  self.assertIsInstance(it, prefetch.ThreadPrefetchDatasetIterator)
  it = cast(prefetch.ThreadPrefetchDatasetIterator, it)

  self.assertEqual(it._prefetch_buffer_size, 2)
  self.assertEqual(next(it), 1)
  self._wait_for_buffer_size(it, 2)  # 3, 5
  it.set_prefetch_buffer_size(1)
  self.assertEqual(it._prefetch_buffer_size, 1)
  # Elements buffered before the shrink remain; the queue is over the new
  # limit until the consumer drains it.
  self.assertEqual(it._buffer.qsize(), 2)  # 3, 5
  self.assertEqual(next(it), 3)  # reads 3, buffer has 5
  self.assertEqual(it._buffer.qsize(), 1)  # 5
  self.assertEqual(next(it), 5)  # reads 5, buffer has 0
  self._wait_for_buffer_size(it, 1)  # prefetch gets 7

def test_set_prefetch_buffer_size_to_0_raises_error(self):
  """A non-positive prefetch buffer size is rejected with a ValueError."""
  prefetch_ds = prefetch.ThreadPrefetchIterDataset(
      self.ds, prefetch_buffer_size=1
  )
  iterator = prefetch_ds.__iter__()
  self.assertIsInstance(iterator, prefetch.ThreadPrefetchDatasetIterator)
  iterator = cast(prefetch.ThreadPrefetchDatasetIterator, iterator)
  with self.assertRaisesRegex(
      ValueError, '`prefetch_buffer_size` must be positive.'
  ):
    iterator.set_prefetch_buffer_size(0)


if __name__ == '__main__':
absltest.main()
Loading
Loading