diff --git a/grain/_src/python/dataset/transformations/process_prefetch.py b/grain/_src/python/dataset/transformations/process_prefetch.py index d5552dfc..4b271ff1 100644 --- a/grain/_src/python/dataset/transformations/process_prefetch.py +++ b/grain/_src/python/dataset/transformations/process_prefetch.py @@ -616,7 +616,11 @@ def multiprocess_prefetch( shards.append(worker_ds) ds = interleave.InterleaveIterDataset( - shards, cycle_length=num_workers, iter_buffer_size=buffer_size + shards, + cycle_length=num_workers, + num_make_iter_threads=num_workers, + make_iter_buffer_size=num_workers, + iter_buffer_size=buffer_size, ) # Apply options from parent dataset because interleave dataset does not # propagate options.