Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions grain/_src/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ py_test(
deps = [
":config",
":traceback_util",
"//grain",
"@abseil-py//absl/testing:absltest",
],
)
63 changes: 20 additions & 43 deletions grain/_src/core/traceback_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from collections.abc import Callable
import functools
import os
import pathlib
import traceback
import types
from typing import Any, TypeVar, cast
Expand All @@ -29,58 +29,29 @@

C = TypeVar("C", bound=Callable[..., Any])

_exclude_paths: list[str] = []


def register_exclusion(path: str):
_exclude_paths.append(path)


register_exclusion(__file__)


_grain_message_append = (
"The stack trace below excludes Grain-internal frames.\n"
"The preceding is the original exception that occurred, unmodified.\n"
"\n--------------------"
)


def _path_starts_with(path: str, path_prefix: str) -> bool:
"""Checks if a given path starts with a specified path prefix.

This function compares two paths after converting them to absolute paths. It
handles cases where paths might be on different drives or might not exist.

Args:
path: The path to check.
path_prefix: The prefix to check against.

Returns:
True if `path` starts with `path_prefix`, False otherwise.
"""
path = os.path.abspath(path)
path_prefix = os.path.abspath(path_prefix)
try:
common = os.path.commonpath([path, path_prefix])
except ValueError:
# path and path_prefix are both absolute, the only case will raise a
# ValueError is different drives.
# https://docs.python.org/3/library/os.path.html#os.path.commonpath
return False
try:
return common == path_prefix or os.path.samefile(common, path_prefix)
except OSError:
# One of the paths may not exist.
return False


def include_frame(f: types.FrameType) -> bool:
return include_filename(f.f_code.co_filename)


def include_filename(filename: str) -> bool:
return not any(_path_starts_with(filename, path) for path in _exclude_paths)
# We want to exclude all files in `grain/_src` and its subdirectories.
# Pathlib ensures path separator differences are accounted for on
# different platforms.
try:
parts = pathlib.Path(filename).parts
except Exception: # pylint: disable=broad-except
return True
for i in range(len(parts) - 1):
if parts[i] == "grain" and parts[i + 1] == "_src":
return False
return True


def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
Expand Down Expand Up @@ -184,8 +155,8 @@ def _ipython_supports_tracebackhide() -> bool:


def _filtering_mode() -> str:
mode = config.py_traceback_filtering
if mode is None or mode == "auto":
mode = config.get_or_default("py_traceback_filtering")
if mode == "auto":
if (_running_under_ipython() and _ipython_supports_tracebackhide()):
mode = "tracebackhide"
else:
Expand Down Expand Up @@ -234,6 +205,11 @@ def run_with_traceback_filter(fun: C) -> C:
A wrapped version of `fun` that filters tracebacks.
"""

# Short circuit if the function is already filtered to avoid wrapping
# the same function multiple times.
if getattr(fun, "is_traceback_filtered", False):
return fun

@functools.wraps(fun)
def reraise_with_filtered_traceback(*args, **kwargs):
__tracebackhide__ = True # pylint: disable=invalid-name,unused-variable
Expand Down Expand Up @@ -271,4 +247,5 @@ def reraise_with_filtered_traceback(*args, **kwargs):
raise
finally:
del mode, tb
reraise_with_filtered_traceback.is_traceback_filtered = True
return cast(C, reraise_with_filtered_traceback)
95 changes: 88 additions & 7 deletions grain/_src/core/traceback_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@
# limitations under the License.
"""Tests for traceback filtering utilities."""

import pathlib
import traceback
from typing import Callable
from unittest import mock

from absl.testing import absltest
import grain
from grain._src.core import traceback_util
from grain._src.core.config import config # pylint: disable=g-importing-member

from absl.testing import absltest


traceback_util.register_exclusion(__file__)


def _assert_exception_with_short_traceback(
self: absltest.TestCase,
Expand Down Expand Up @@ -56,8 +55,7 @@ def _assert_exception_with_short_traceback(
f"Expected {expected_error_type} to be raised, but got"
f" {type(e)} instead."
)
print(f"traceback: {tb}")
self.assertLess(len(tb), 15)
self.assertLess(len(tb), 15, f"Traceback is too long: \n{tb}")


@traceback_util.run_with_traceback_filter
Expand Down Expand Up @@ -122,6 +120,89 @@ def test_traceback_mode_tracebackhide_decorator_applies_tracebackhide(self):
# Verify the frames are not actually removed.
self.assertGreater(frame_count, 150)

def test_wrapper_is_not_applied_twice(self):
def f():
pass

f_wrapped = traceback_util.run_with_traceback_filter(f)
f_wrapped_again = traceback_util.run_with_traceback_filter(f_wrapped)
self.assertIs(f_wrapped, f_wrapped_again)

def test_include_filename(self):
self.assertFalse(
traceback_util.include_filename("path/to/grain/_src/foo.py")
)

# Verify Windows path handling (simulated on non-Windows if needed)
with mock.patch.object(pathlib, "Path", pathlib.PureWindowsPath):
self.assertFalse(
traceback_util.include_filename(r"C:\path\to\grain\_src\foo.py")
)

# Check that "grain" and "_src" must be adjacent
self.assertTrue(
traceback_util.include_filename("path/to/grain/something/_src/foo.py")
)


class AddOneTransform(grain.transforms.Map):

def map(self, x: int) -> int:
return x + 1


class RaiseErrorTransform(grain.transforms.Map):

def map(self, x: int) -> int:
raise ValueError("Boom!")


class TracebackFilterTest(absltest.TestCase):

def test_datasource_multiple_transforms_filters_traceback_on_iterator(self):
range_ds = grain.sources.RangeDataSource(0, 10, 1)
sampler = grain.samplers.IndexSampler(num_records=10, seed=42)
ops = [RaiseErrorTransform()]
for _ in range(100):
ops.append(AddOneTransform())
data_loader = grain.DataLoader(
data_source=range_ds, sampler=sampler, operations=ops
)
_assert_exception_with_short_traceback(
self, lambda: next(iter(data_loader)), ValueError
)

def test_dataset_multiple_transforms_filters_traceback_on_iterator_on_iterator(
self,
):
range_ds = grain.MapDataset.range(0, 10)
range_ds = range_ds.map(RaiseErrorTransform())
for _ in range(100):
range_ds = range_ds.map(AddOneTransform())
_assert_exception_with_short_traceback(
self, lambda: next(iter(range_ds)), ValueError
)

def test_dataset_multiple_transforms_filters_traceback_on_getitem(self):
range_ds = grain.MapDataset.range(0, 10)
range_ds = range_ds.map(RaiseErrorTransform())
for _ in range(100):
range_ds = range_ds.map(AddOneTransform())
_assert_exception_with_short_traceback(
self, lambda: range_ds[0], ValueError
)

def test_dataset_multiple_transforms_filters_traceback_on_batch(
self,
):
range_ds = grain.MapDataset.range(0, 10)
range_ds = range_ds.map(RaiseErrorTransform())
for _ in range(100):
range_ds = range_ds.map(AddOneTransform())
_assert_exception_with_short_traceback(
self, lambda: range_ds.batch(2)[0], ValueError
)


if __name__ == "__main__":
absltest.main()
1 change: 1 addition & 0 deletions grain/_src/python/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ py_library(
"//grain/_src/core:exceptions",
"//grain/_src/core:monitoring",
"//grain/_src/core:sharding",
"//grain/_src/core:traceback_util",
"//grain/_src/core:transforms",
"//grain/_src/core:tree_lib",
"//grain/_src/python:checkpointing",
Expand Down
64 changes: 57 additions & 7 deletions grain/_src/python/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
from collections.abc import Awaitable, Callable, Iterable, Iterator, Mapping, Sequence
import functools
import json
from typing import Any, Generic, TypeVar, Union, cast, overload
from typing import Any, Generic, Protocol, TypeVar, Union, cast, overload, runtime_checkable
import warnings

from etils import epath
from grain._src.core import config as grain_config # pylint: disable=g-importing-member.
from grain._src.core import monitoring as grain_monitoring
from grain._src.core import traceback_util
from grain._src.core import transforms
from grain._src.python import checkpointing
from grain._src.python import options as grain_options
Expand All @@ -79,6 +81,11 @@
S = TypeVar("S")


@runtime_checkable
class ParentDataset(Protocol):
has_child: bool


class _Dataset:
"""Node of a dataset tree structure that represents data transformation.

Expand Down Expand Up @@ -358,6 +365,10 @@ def __init__(self, parents: MapDataset | Sequence[MapDataset] = ()):
super().__init__(parents)
self._parents = cast(Sequence[MapDataset], self._parents)
_api_usage_counter.Increment("MapDataset")
self.has_child = False
for parent in self._parents:
if isinstance(parent, ParentDataset):
parent.has_child = True

@property
def parents(self) -> Sequence[MapDataset]:
Expand Down Expand Up @@ -854,6 +865,10 @@ def to_iter_dataset(
def __iter__(self) -> DatasetIterator[T]:
return self.to_iter_dataset().__iter__()

def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
cls.__getitem__ = map_dataset_injector(cls.__getitem__)

# pytype: disable=attribute-error
# pylint: disable=protected-access

Expand Down Expand Up @@ -911,6 +926,26 @@ def _stats(self) -> dataset_stats.Stats:
# pylint: enable=protected-access


def map_dataset_injector(delegate_func: Callable[..., Any]):
"""Injects arbitrary logic to the passed-in delegate function at the end of the pipeline."""

@functools.wraps(delegate_func)
def wrapper(self, index):
# Filter traceback if the dataset is a leaf node in the pipeline and
# traceback filtering is enabled.
if (
isinstance(self, ParentDataset)
and not self.has_child
and traceback_filter_mode() != "off"
):
return traceback_util.run_with_traceback_filter(delegate_func)(
self, index
)
return delegate_func(self, index)

return wrapper


def output_dataset_injector(create_iter_func: Callable[..., DatasetIterator]):
"""Injects an _OutputIterDataset to the end of the dataset pipeline.

Expand All @@ -924,12 +959,12 @@ def output_dataset_injector(create_iter_func: Callable[..., DatasetIterator]):
@functools.wraps(create_iter_func)
def wrapper(self, *args, **kwargs):
if (
isinstance(self, _OutputIterDataset)
or not hasattr(self, "has_child")
or self.has_child
not isinstance(self, _OutputIterDataset)
and isinstance(self, ParentDataset)
and not self.has_child
):
return create_iter_func(self, *args, **kwargs)
return _OutputIterDataset(self).__iter__(*args, **kwargs)
return _OutputIterDataset(self).__iter__(*args, **kwargs)
return create_iter_func(self, *args, **kwargs)

return wrapper

Expand Down Expand Up @@ -1001,7 +1036,7 @@ def __init__(

self.has_child = False
for parent in self._parents:
if isinstance(parent, IterDataset) and hasattr(parent, "has_child"):
if isinstance(parent, ParentDataset):
parent.has_child = True
if not isinstance(self, _OutputIterDataset):
_api_usage_counter.Increment("IterDataset")
Expand Down Expand Up @@ -1699,6 +1734,11 @@ def _element_spec(self) -> Any:
return get_element_spec(self._parent)


def traceback_filter_mode() -> str:
"""Returns the traceback filter mode."""
return grain_config.config.get_or_default("py_traceback_filtering")


def is_thread_prefetch_injection_enabled() -> bool:
"""Returns whether thread prefetch injection experiment is enabled."""
return False
Expand All @@ -1721,6 +1761,16 @@ def __iter__(self) -> DatasetIterator[T]:
):
if not prefetch.is_prefetch_iterator(iterator):
iterator = prefetch.ThreadPrefetchDatasetIterator(iterator, 1)

filter_mode = traceback_filter_mode()
if filter_mode != "off":
# pytype's ParameterizedClass (i.e. the type of
# DatasetIterator[T]) currently doesn't properly support setting
# attributes, so we cast to a non-parameterized type as a workaround.
cast(DatasetIterator, iterator).__class__.__next__ = (
traceback_util.run_with_traceback_filter(iterator.__class__.__next__)
)

return iterator


Expand Down