diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py index 765e83ec2..dc09814fb 100644 --- a/grain/_src/python/dataset/transformations/batch.py +++ b/grain/_src/python/dataset/transformations/batch.py @@ -78,6 +78,7 @@ def __call__(self, values: Sequence[T]) -> T: def _batch_fn(*xs: Sequence[T]) -> T: # If the thread pool is not available or the elements are not NumPy # arrays, fall back to the standard serial `np.stack` operation. + # TODO: Support parallel batch when elements are not NumPy if (self._parallel_batch_executor is None) or not isinstance( xs[0], np.ndarray ):