From 84942dc261da927a658032c1ab531c83eafc4763 Mon Sep 17 00:00:00 2001 From: Ihor Indyk Date: Thu, 8 Jan 2026 13:21:40 -0800 Subject: [PATCH] Add support for creating a `MapTransform` from callable. PiperOrigin-RevId: 853861917 --- CHANGELOG.md | 2 ++ grain/_src/core/transforms.py | 24 +++++++++++++++++++++++- grain/_src/core/transforms_test.py | 27 +++++++++++++++++++++++++-- 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index be43667b4..e260687bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ changes. Best viewed [here](https://google-grain.readthedocs.io/en/latest/change errors. * Adds experimental support for `get_next_index` and `set_next_index` to fetch and advance a `grain.DatasetIterator` to the given produced element index. + * `grain.transforms.Map` can now be created from a callable with + `grain.transforms.Map.from_callable`. * Breaking changes: diff --git a/grain/_src/core/transforms.py b/grain/_src/core/transforms.py index 3d8773193..683710569 100644 --- a/grain/_src/core/transforms.py +++ b/grain/_src/core/transforms.py @@ -32,7 +32,16 @@ import numpy as np -class MapTransform(abc.ABC): +# We use a metaclass to avoid having `from_callable` on a `MapTransform` +# instance and instead only have it on the class. +class _MapMeta(abc.ABCMeta): + + def from_callable(cls, fn: Callable[..., Any]) -> MapTransform: + """Returns a MapTransform that applies the given callable.""" + return _MapFromCallable(fn) + + +class MapTransform(metaclass=_MapMeta): """Abstract base class for all 1:1 transformations of elements. Implementations should be threadsafe since they are often executed in @@ -44,6 +53,19 @@ def map(self, element): """Maps a single element.""" +class _MapFromCallable(MapTransform): + """A MapTransform that applies the given callable.""" + + def __init__(self, fn: Callable[..., Any]): + self._fn = fn + + def map(self, element): + return self._fn(element) + + def __repr__(self): + return f"MapFromCallable<{get_pretty_transform_name(self._fn)}>" + + class RandomMapTransform(abc.ABC): """Abstract base class for all random 1:1 transformations of elements. diff --git a/grain/_src/core/transforms_test.py b/grain/_src/core/transforms_test.py index a35f1e5b2..c8604d412 100644 --- a/grain/_src/core/transforms_test.py +++ b/grain/_src/core/transforms_test.py @@ -43,6 +43,10 @@ def __repr__(self): return "CustomRepr" +def add_one(x): + return x + 1 + + @absltest.skipIf( platform.system() == "Windows", "Skipped due to windows paths." ) @@ -70,8 +74,16 @@ class GetPrettyTransformNameTest(parameterized.TestCase): expected_substring="CustomStr", ), dict( - transform=_TestMapWithRepr(), - expected_substring="CustomRepr", + transform=transforms.MapTransform.from_callable(add_one), + expected_substring=( + "MapFromCallable @ .../_src/core/transforms_test.py:" + ), ), ) def test_get_pretty_transform_name(self, transform, expected_substring): @@ -80,5 +92,16 @@ def test_get_pretty_transform_name(self, transform, expected_substring): ) +class MapFromCallableTest(parameterized.TestCase): + + def test_local_function(self): + map_transform = transforms.MapTransform.from_callable(add_one) + self.assertEqual(map_transform.map(1), 2) + + def test_lambda(self): + map_transform = transforms.MapTransform.from_callable(lambda x: x + 1) + self.assertEqual(map_transform.map(1), 2) + + if __name__ == "__main__": absltest.main()