From 7892de2cf66a0196d76012ba4e9ef73d57a9fd9a Mon Sep 17 00:00:00 2001
From: Grain Team
Date: Tue, 6 Jan 2026 13:43:51 -0800
Subject: [PATCH] Apply traceback filtering to MapDataset `__getitem__`.

PiperOrigin-RevId: 852919724
---
 grain/_src/core/BUILD                  |  1 +
 grain/_src/core/traceback_util.py      | 63 ++++++-----------
 grain/_src/core/traceback_util_test.py | 95 ++++++++++++++++++++++++--
 grain/_src/python/dataset/BUILD        |  1 +
 grain/_src/python/dataset/dataset.py   | 64 +++++++++++++++--
 5 files changed, 167 insertions(+), 57 deletions(-)

diff --git a/grain/_src/core/BUILD b/grain/_src/core/BUILD
index fc848206b..8b7e38d86 100644
--- a/grain/_src/core/BUILD
+++ b/grain/_src/core/BUILD
@@ -238,6 +238,7 @@ py_test(
     deps = [
         ":config",
         ":traceback_util",
+        "//grain",
         "@abseil-py//absl/testing:absltest",
     ],
 )
diff --git a/grain/_src/core/traceback_util.py b/grain/_src/core/traceback_util.py
index bebdbe771..f2950e04f 100644
--- a/grain/_src/core/traceback_util.py
+++ b/grain/_src/core/traceback_util.py
@@ -20,7 +20,7 @@
 
 from collections.abc import Callable
 import functools
-import os
+import pathlib
 import traceback
 import types
 from typing import Any, TypeVar, cast
@@ -29,16 +29,6 @@
 C = TypeVar("C", bound=Callable[..., Any])
 
 
-_exclude_paths: list[str] = []
-
-
-def register_exclusion(path: str):
-  _exclude_paths.append(path)
-
-
-register_exclusion(__file__)
-
-
 _grain_message_append = (
     "The stack trace below excludes Grain-internal frames.\n"
     "The preceding is the original exception that occurred, unmodified.\n"
...
 )
 
 
-def _path_starts_with(path: str, path_prefix: str) -> bool:
-  """Checks if a given path starts with a specified path prefix.
-
-  This function compares two paths after converting them to absolute paths. It
-  handles cases where paths might be on different drives or might not exist.
-
-  Args:
-    path: The path to check.
-    path_prefix: The prefix to check against.
-
-  Returns:
-    True if `path` starts with `path_prefix`, False otherwise.
-  """
-  path = os.path.abspath(path)
-  path_prefix = os.path.abspath(path_prefix)
-  try:
-    common = os.path.commonpath([path, path_prefix])
-  except ValueError:
-    # path and path_prefix are both absolute, the only case will raise a
-    # ValueError is different drives.
-    # https://docs.python.org/3/library/os.path.html#os.path.commonpath
-    return False
-  try:
-    return common == path_prefix or os.path.samefile(common, path_prefix)
-  except OSError:
-    # One of the paths may not exist.
-    return False
-
-
 def include_frame(f: types.FrameType) -> bool:
   return include_filename(f.f_code.co_filename)
 
 
 def include_filename(filename: str) -> bool:
-  return not any(_path_starts_with(filename, path) for path in _exclude_paths)
+  # We want to exclude all files in `grain/_src` and its subdirectories.
+  # Pathlib ensures path separator differences are accounted for on
+  # different platforms.
+  try:
+    parts = pathlib.Path(filename).parts
+  except Exception:  # pylint: disable=broad-except
+    return True
+  for i in range(len(parts) - 1):
+    if parts[i] == "grain" and parts[i + 1] == "_src":
+      return False
+  return True
 
 
 def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
@@ -184,8 +155,8 @@ def _ipython_supports_tracebackhide() -> bool:
 
 
 def _filtering_mode() -> str:
-  mode = config.py_traceback_filtering
-  if mode is None or mode == "auto":
+  mode = config.get_or_default("py_traceback_filtering")
+  if mode == "auto":
     if (_running_under_ipython() and _ipython_supports_tracebackhide()):
       mode = "tracebackhide"
     else:
@@ -234,6 +205,11 @@ def run_with_traceback_filter(fun: C) -> C:
     A wrapped version of `fun` that filters tracebacks.
   """
 
+  # Short circuit if the function is already filtered to avoid wrapping
+  # the same function multiple times.
+  if getattr(fun, "is_traceback_filtered", False):
+    return fun
+
   @functools.wraps(fun)
   def reraise_with_filtered_traceback(*args, **kwargs):
     __tracebackhide__ = True  # pylint: disable=invalid-name,unused-variable
@@ -271,4 +247,5 @@ def reraise_with_filtered_traceback(*args, **kwargs):
       raise
     finally:
       del mode, tb
+  reraise_with_filtered_traceback.is_traceback_filtered = True
   return cast(C, reraise_with_filtered_traceback)
diff --git a/grain/_src/core/traceback_util_test.py b/grain/_src/core/traceback_util_test.py
index 91cce90f1..641aac9c0 100644
--- a/grain/_src/core/traceback_util_test.py
+++ b/grain/_src/core/traceback_util_test.py
@@ -13,17 +13,16 @@
 # limitations under the License.
 
 """Tests for traceback filtering utilities."""
 
+import pathlib
 import traceback
 from typing import Callable
+from unittest import mock
 
+from absl.testing import absltest
+import grain
 from grain._src.core import traceback_util
 from grain._src.core.config import config  # pylint: disable=g-importing-member
 
-from absl.testing import absltest
-
-
-traceback_util.register_exclusion(__file__)
-
 
 def _assert_exception_with_short_traceback(
@@ -56,8 +55,7 @@ def _assert_exception_with_short_traceback(
         f"Expected {expected_error_type} to be raised, but got"
         f" {type(e)} instead."
     )
-    print(f"traceback: {tb}")
-    self.assertLess(len(tb), 15)
+    self.assertLess(len(tb), 15, f"Traceback is too long: \n{tb}")
 
 
 @traceback_util.run_with_traceback_filter
@@ -122,6 +120,89 @@ def test_traceback_mode_tracebackhide_decorator_applies_tracebackhide(self):
     # Verify the frames are not actually removed.
     self.assertGreater(frame_count, 150)
 
+  def test_wrapper_is_not_applied_twice(self):
+    def f():
+      pass
+
+    f_wrapped = traceback_util.run_with_traceback_filter(f)
+    f_wrapped_again = traceback_util.run_with_traceback_filter(f_wrapped)
+    self.assertIs(f_wrapped, f_wrapped_again)
+
+  def test_include_filename(self):
+    self.assertFalse(
+        traceback_util.include_filename("path/to/grain/_src/foo.py")
+    )
+
+    # Verify Windows path handling (simulated on non-Windows if needed)
+    with mock.patch.object(pathlib, "Path", pathlib.PureWindowsPath):
+      self.assertFalse(
+          traceback_util.include_filename(r"C:\path\to\grain\_src\foo.py")
+      )
+
+    # Check that "grain" and "_src" must be adjacent
+    self.assertTrue(
+        traceback_util.include_filename("path/to/grain/something/_src/foo.py")
+    )
+
+
+class AddOneTransform(grain.transforms.Map):
+
+  def map(self, x: int) -> int:
+    return x + 1
+
+
+class RaiseErrorTransform(grain.transforms.Map):
+
+  def map(self, x: int) -> int:
+    raise ValueError("Boom!")
+
+
+class TracebackFilterTest(absltest.TestCase):
+
+  def test_datasource_multiple_transforms_filters_traceback_on_iterator(self):
+    range_ds = grain.sources.RangeDataSource(0, 10, 1)
+    sampler = grain.samplers.IndexSampler(num_records=10, seed=42)
+    ops = [RaiseErrorTransform()]
+    for _ in range(100):
+      ops.append(AddOneTransform())
+    data_loader = grain.DataLoader(
+        data_source=range_ds, sampler=sampler, operations=ops
+    )
+    _assert_exception_with_short_traceback(
+        self, lambda: next(iter(data_loader)), ValueError
+    )
+
+  def test_dataset_multiple_transforms_filters_traceback_on_iterator(
+      self,
+  ):
+    range_ds = grain.MapDataset.range(0, 10)
+    range_ds = range_ds.map(RaiseErrorTransform())
+    for _ in range(100):
+      range_ds = range_ds.map(AddOneTransform())
+    _assert_exception_with_short_traceback(
+        self, lambda: next(iter(range_ds)), ValueError
+    )
+
+  def test_dataset_multiple_transforms_filters_traceback_on_getitem(self):
+    range_ds = grain.MapDataset.range(0, 10)
+    range_ds = range_ds.map(RaiseErrorTransform())
+    for _ in range(100):
+      range_ds = range_ds.map(AddOneTransform())
+    _assert_exception_with_short_traceback(
+        self, lambda: range_ds[0], ValueError
+    )
+
+  def test_dataset_multiple_transforms_filters_traceback_on_batch(
+      self,
+  ):
+    range_ds = grain.MapDataset.range(0, 10)
+    range_ds = range_ds.map(RaiseErrorTransform())
+    for _ in range(100):
+      range_ds = range_ds.map(AddOneTransform())
+    _assert_exception_with_short_traceback(
+        self, lambda: range_ds.batch(2)[0], ValueError
+    )
+
 
 if __name__ == "__main__":
   absltest.main()
diff --git a/grain/_src/python/dataset/BUILD b/grain/_src/python/dataset/BUILD
index 3d6b75cba..796045ba8 100644
--- a/grain/_src/python/dataset/BUILD
+++ b/grain/_src/python/dataset/BUILD
@@ -49,6 +49,7 @@ py_library(
         "//grain/_src/core:exceptions",
         "//grain/_src/core:monitoring",
         "//grain/_src/core:sharding",
+        "//grain/_src/core:traceback_util",
         "//grain/_src/core:transforms",
         "//grain/_src/core:tree_lib",
         "//grain/_src/python:checkpointing",
diff --git a/grain/_src/python/dataset/dataset.py b/grain/_src/python/dataset/dataset.py
index 8508f2b96..e0793cf7a 100644
--- a/grain/_src/python/dataset/dataset.py
+++ b/grain/_src/python/dataset/dataset.py
@@ -50,11 +50,13 @@
 from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
 import functools
 import json
-from typing import Any, Generic, TypeVar, Union, cast, overload
+from typing import Any, Generic, Protocol, TypeVar, Union, cast, overload, runtime_checkable
 import warnings
 
 from etils import epath
+from grain._src.core import config as grain_config  # pylint: disable=g-importing-member
 from grain._src.core import monitoring as grain_monitoring
+from grain._src.core import traceback_util
 from grain._src.core import transforms
 from grain._src.python import checkpointing
 from grain._src.python import options as grain_options
@@ -79,6 +81,11 @@
 S = TypeVar("S")
 
 
+@runtime_checkable
+class ParentDataset(Protocol):
+  has_child: bool
+
+
 class _Dataset:
   """Node of a dataset tree structure that represents data transformation.
 
@@ -358,6 +365,10 @@ def __init__(self, parents: MapDataset | Sequence[MapDataset] = ()):
     super().__init__(parents)
     self._parents = cast(Sequence[MapDataset], self._parents)
     _api_usage_counter.Increment("MapDataset")
+    self.has_child = False
+    for parent in self._parents:
+      if isinstance(parent, ParentDataset):
+        parent.has_child = True
 
   @property
   def parents(self) -> Sequence[MapDataset]:
@@ -854,6 +865,10 @@ def to_iter_dataset(
 
   def __iter__(self) -> DatasetIterator[T]:
     return self.to_iter_dataset().__iter__()
 
+  def __init_subclass__(cls, **kwargs):
+    super().__init_subclass__(**kwargs)
+    cls.__getitem__ = map_dataset_injector(cls.__getitem__)
+
   # pytype: disable=attribute-error
   # pylint: disable=protected-access
@@ -911,6 +926,26 @@ def _stats(self) -> dataset_stats.Stats:
   # pylint: enable=protected-access
 
 
+def map_dataset_injector(delegate_func: Callable[..., Any]):
+  """Injects arbitrary logic into the passed-in delegate function at the end of the pipeline."""
+
+  @functools.wraps(delegate_func)
+  def wrapper(self, index):
+    # Filter traceback if the dataset is a leaf node in the pipeline and
+    # traceback filtering is enabled.
+    if (
+        isinstance(self, ParentDataset)
+        and not self.has_child
+        and traceback_filter_mode() != "off"
+    ):
+      return traceback_util.run_with_traceback_filter(delegate_func)(
+          self, index
+      )
+    return delegate_func(self, index)
+
+  return wrapper
+
+
 def output_dataset_injector(create_iter_func: Callable[..., DatasetIterator]):
   """Injects an _OutputIterDataset to the end of the dataset pipeline.
@@ -924,12 +959,12 @@ def output_dataset_injector(create_iter_func: Callable[..., DatasetIterator]):
   @functools.wraps(create_iter_func)
   def wrapper(self, *args, **kwargs):
     if (
-        isinstance(self, _OutputIterDataset)
-        or not hasattr(self, "has_child")
-        or self.has_child
+        not isinstance(self, _OutputIterDataset)
+        and isinstance(self, ParentDataset)
+        and not self.has_child
     ):
-      return create_iter_func(self, *args, **kwargs)
-    return _OutputIterDataset(self).__iter__(*args, **kwargs)
+      return _OutputIterDataset(self).__iter__(*args, **kwargs)
+    return create_iter_func(self, *args, **kwargs)
 
   return wrapper
 
@@ -1001,7 +1036,7 @@ def __init__(
 
     self.has_child = False
     for parent in self._parents:
-      if isinstance(parent, IterDataset) and hasattr(parent, "has_child"):
+      if isinstance(parent, ParentDataset):
         parent.has_child = True
     if not isinstance(self, _OutputIterDataset):
       _api_usage_counter.Increment("IterDataset")
@@ -1699,6 +1734,11 @@ def _element_spec(self) -> Any:
     return get_element_spec(self._parent)
 
 
+def traceback_filter_mode() -> str:
+  """Returns the traceback filter mode."""
+  return grain_config.config.get_or_default("py_traceback_filtering")
+
+
 def is_thread_prefetch_injection_enabled() -> bool:
   """Returns whether thread prefetch injection experiment is enabled."""
   return False
@@ -1721,6 +1761,16 @@ def __iter__(self) -> DatasetIterator[T]:
     ):
       if not prefetch.is_prefetch_iterator(iterator):
        iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)
+
+    filter_mode = traceback_filter_mode()
+    if filter_mode != "off":
+      # pytype's ParameterizedClass (i.e. the type of
+      # DatasetIterator[T]) currently doesn't properly support setting
+      # attributes, so we cast to a non-parameterized type as a workaround.
+      cast(DatasetIterator, iterator).__class__.__next__ = (
+          traceback_util.run_with_traceback_filter(iterator.__class__.__next__)
+      )
+
     return iterator
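
For reference, a minimal usage sketch of the user-visible behavior this patch targets, built only from the public API already exercised by the new tests (grain.MapDataset, grain.transforms.Map). The `Boom` transform name is illustrative, and the sketch assumes `py_traceback_filtering` resolves to any mode other than "off":

    import grain


    class Boom(grain.transforms.Map):

      def map(self, x: int) -> int:
        raise ValueError("Boom!")


    ds = grain.MapDataset.range(0, 10).map(Boom())
    # __getitem__ on the last dataset in the pipeline is now wrapped by
    # run_with_traceback_filter, so the ValueError below propagates with
    # grain/_src frames hidden or removed rather than the full internal stack.
    ds[0]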