From c2e66088b62f4bfaf27294ba57b2dc5c883d113c Mon Sep 17 00:00:00 2001 From: Grain Team Date: Mon, 20 Oct 2025 17:09:20 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 821865877 --- .../python/dataset/transformations/prefetch.py | 3 +++ grain/_src/python/grain_pool.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 1f58925e4..05ff23d17 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -534,6 +534,7 @@ def __call__( worker_count: int, start_profiling_event: synchronize.Event | None = None, stop_profiling_event: synchronize.Event | None = None, + profiling_timeout: Any | None = None, stats_out_queue: queues.Queue | None = None, ) -> Iterator[tuple[T, Optional[dict[str, Any]]]]: if worker_count > 1: @@ -644,6 +645,7 @@ def __init__( ) self._start_profiling_event = mp.get_context("spawn").Event() self._stop_profiling_event = mp.get_context("spawn").Event() + self._profiling_timeout = mp.get_context("spawn").Value("i", -1) self._state: dict[str, dict[str, Any] | int] = { _WORKERS_STATE: workers_state, @@ -751,6 +753,7 @@ def _create_iterator_context(self) -> grain_pool.MultiProcessIterator[T]: self._worker_init_fn, self._start_profiling_event, self._stop_profiling_event, + self._profiling_timeout, self._stats_in_queues, ) diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 0bcf381e3..0860afc6c 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -137,6 +137,7 @@ def __call__( worker_count: int, start_profiling_event: synchronize.Event | None = None, stop_profiling_event: synchronize.Event | None = None, + profiling_timeout: Any | None = None, stats_out_queue: queues.Queue | None = None, ) -> Iterator[T]: """Returns a generator of elements.""" @@ -188,6 +189,7 @@ def _initialize_and_get_element_producer( worker_count: int, start_profiling_event: synchronize.Event, stop_profiling_event: synchronize.Event, + profiling_timeout: Any, stats_out_queue: queues.Queue, ) -> Iterator[Any]: """Unpickles the element producer from the args queue and closes the queue.""" @@ -214,6 +216,7 @@ def _initialize_and_get_element_producer( worker_count=worker_count, start_profiling_event=start_profiling_event, stop_profiling_event=stop_profiling_event, + profiling_timeout=profiling_timeout, stats_out_queue=stats_out_queue, ) # args_queue has only a single argument and thus can be safely closed. @@ -229,6 +232,7 @@ def _worker_loop( termination_event: synchronize.Event, start_profiling_event: synchronize.Event, stop_profiling_event: synchronize.Event, + profiling_timeout: Any, worker_index: int, worker_count: int, enable_profiling: bool, @@ -250,6 +254,7 @@ def _worker_loop( worker_count=worker_count, start_profiling_event=start_profiling_event, stop_profiling_event=stop_profiling_event, + profiling_timeout=profiling_timeout, stats_out_queue=stats_out_queue, ) profiling_enabled = enable_profiling and worker_index == 0 @@ -339,6 +344,7 @@ def __init__( termination_event: threading.Event | None = None, start_profiling_event: synchronize.Event | None = None, stop_profiling_event: synchronize.Event | None = None, + profiling_timeout: Any | None = None, options: MultiprocessingOptions, worker_init_fn: Callable[[int, int], None] | None = None, stats_in_queues: tuple[queues.Queue, ...] | None = None, @@ -356,6 +362,7 @@ def __init__( all workers are done processing data. GrainPool will not set this event. start_profiling_event: Event to start prism profiling. stop_profiling_event: Event to stop prism profiling. + profiling_timeout: Shared value for profiling timeout. options: Options for multiprocessing. See MultiprocessingOptions. worker_init_fn: Function to run in each worker process before the element producer. The function takes two arguments: the current worker index and @@ -409,6 +416,7 @@ def __init__( termination_event=self._workers_termination_event, start_profiling_event=start_profiling_event, stop_profiling_event=stop_profiling_event, + profiling_timeout=profiling_timeout, worker_index=worker_index, worker_count=options.num_workers, enable_profiling=options.enable_profiling, @@ -614,6 +622,7 @@ def _process_elements_in_grain_pool( termination_event: threading.Event, start_profiling_event: synchronize.Event | None, stop_profiling_event: synchronize.Event | None, + profiling_timeout: Any | None, worker_index_to_start_reading: int, worker_init_fn: Callable[[int, int], None] | None, stats_in_queues: tuple[queues.Queue, ...] | None, @@ -633,6 +642,7 @@ def read_thread_should_stop(): termination_event=termination_event, start_profiling_event=start_profiling_event, stop_profiling_event=stop_profiling_event, + profiling_timeout=profiling_timeout, options=multiprocessing_options, worker_init_fn=worker_init_fn, stats_in_queues=stats_in_queues, @@ -691,6 +701,7 @@ def __init__( worker_init_fn: Callable[[int, int], None] | None = None, start_profiling_event: synchronize.Event | None = None, stop_profiling_event: synchronize.Event | None = None, + profiling_timeout: Any | None = None, stats_in_queues: tuple[queues.Queue, ...] | None = None, ): """Initializes MultiProcessIterator. @@ -706,6 +717,7 @@ def __init__( the total worker count. start_profiling_event: Event to start prism profiling. stop_profiling_event: Event to stop prism profiling. + profiling_timeout: Shared value for profiling timeout. stats_in_queues: Queues to send execution summaries from worker processes to the main process. """ @@ -720,6 +732,7 @@ def __init__( self._stats_in_queues = stats_in_queues self._start_profiling_event = start_profiling_event self._stop_profiling_event = stop_profiling_event + self._profiling_timeout = profiling_timeout def __del__(self): if self._reader_thread: @@ -749,6 +762,7 @@ def start_prefetch(self) -> None: termination_event=self._termination_event, start_profiling_event=self._start_profiling_event, stop_profiling_event=self._stop_profiling_event, + profiling_timeout=self._profiling_timeout, worker_index_to_start_reading=self._last_worker_index + 1, worker_init_fn=self._worker_init_fn, stats_in_queues=self._stats_in_queues,