From 953ab0aeefd08a9e013a30d08358115d94b9683d Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 14 Apr 2025 10:39:19 +0200 Subject: [PATCH 1/4] Allow negative indices in `pick_nth_selected` --- penzai/core/selectors.py | 28 ++++++++++++++++++++++++ tests/core/selectors_test.py | 41 +++++++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/penzai/core/selectors.py b/penzai/core/selectors.py index d89f8e0..c9f75e8 100644 --- a/penzai/core/selectors.py +++ b/penzai/core/selectors.py @@ -39,6 +39,32 @@ T = typing.TypeVar("T") +def shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]: + """Adds `shift` to negative indices and leaves non-negative indices unchanged + + Can be used to handle negative indices. For example, if we expect indices in + `r = range(6)` and we get `[0, 3, -2]` as input, we can use + + ```py + shift_negative_indices([0, 3, -2], len(r)) + ``` + + to get `(0, 3, 4)` + + Args: + indices: The integers to shift + shift: The offset to add to negative indices. Usually, this is the largest + index + 1, i.e. the length of the range of indices + """ + maybe_shifted_indices = [] + for index in indices: + if index < 0: + maybe_shifted_indices.append(index + shift) + else: + maybe_shifted_indices.append(index) + return tuple(maybe_shifted_indices) + + @struct.pytree_dataclass class SelectionHole(struct.Struct): """A hole in a structure, taking the place of a selected subtree. @@ -1356,6 +1382,8 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection: else: indices = n + indices = shift_negative_indices(indices, len(self.selected_by_path)) + with _wrap_selection_errors(self): keep = _InProgressSelectionBoundary new_selected_by_path = collections.OrderedDict({ diff --git a/tests/core/selectors_test.py b/tests/core/selectors_test.py index 66eb1a2..f55578e 100644 --- a/tests/core/selectors_test.py +++ b/tests/core/selectors_test.py @@ -20,12 +20,31 @@ import collections import dataclasses -from typing import Any +from typing import Any, Iterable from absl.testing import absltest import jax from penzai import pz import penzai.core.selectors +import pytest + + +@pytest.mark.parametrize( + "input_indices, shift, expected_output", + [ + ((,), 1, (,)), + ([0, 3, -2], len(range(6)), (0, 3, 4)), + ] +) +def test_shift_negative_indices( + input_indices: Iterable[int], + shift: int, + expected_output: tuple[int, ...], +): + assert ( + penzai.core.selectors.shift_negative_indices(input_indices, shift=shift) + == expected_output + ) @dataclasses.dataclass @@ -565,6 +584,26 @@ def test_pick_nth_selected(self): ), [0, 1, 2, 3, SELECTED_PART(value=4), 5, 6, 7, 8, 9], ) + # Test negative indices for `pick_nth_selected` + self.assertEqual( + ( + pz.select(list(range(10))) + .at_instances_of(int) + .pick_nth_selected(-2) + .apply(SELECTED_PART) + ), + [0, 1, 2, 3, 4, 5, 6, 7, SELECTED_PART(value=8), 9], + ) + # Don't select anything if index is out of range + self.assertEqual( + ( + pz.select([0, 1, 2]) + .at_instances_of(int) + .pick_nth_selected(5) + .apply(SELECTED_PART) + ), + [0, 1, 2], + ) def test_invert__example_1(self): predicate = lambda node: isinstance(node, CustomLeaf) and node.tag <= 10 From 62712fa90aa8db6998ce6dc0b6f8daeec6895903 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Mon, 14 Apr 2025 11:05:54 +0200 Subject: [PATCH 2/4] Fix test case --- penzai/core/selectors.py | 18 +++++++++++++++--- tests/core/selectors_test.py | 4 ++-- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/penzai/core/selectors.py b/penzai/core/selectors.py index c9f75e8..cc3be19 100644 --- a/penzai/core/selectors.py +++ b/penzai/core/selectors.py @@ -39,7 +39,7 @@ T = typing.TypeVar("T") -def shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]: +def _shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]: """Adds `shift` to negative indices and leaves non-negative indices unchanged Can be used to handle negative indices. For example, if we expect indices in @@ -49,12 +49,24 @@ def shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ... shift_negative_indices([0, 3, -2], len(r)) ``` - to get `(0, 3, 4)` + to get `(0, 3, 4)`. The same can be achieved in more generality with + + ```py + pz.select((0, 3, -2)) \ + .at_instances_of(int) \ + .where(lambda i: i < 0) \ + .apply(lambda i: i + shift) + ``` + + which is why this method is private to this module Args: indices: The integers to shift shift: The offset to add to negative indices. Usually, this is the largest index + 1, i.e. the length of the range of indices + + Returns: + The indices as a tuple, with negative indices increased by `shift` """ maybe_shifted_indices = [] for index in indices: @@ -1382,7 +1394,7 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection: else: indices = n - indices = shift_negative_indices(indices, len(self.selected_by_path)) + indices = _shift_negative_indices(indices, len(self.selected_by_path)) with _wrap_selection_errors(self): keep = _InProgressSelectionBoundary diff --git a/tests/core/selectors_test.py b/tests/core/selectors_test.py index f55578e..9a44262 100644 --- a/tests/core/selectors_test.py +++ b/tests/core/selectors_test.py @@ -32,7 +32,7 @@ @pytest.mark.parametrize( "input_indices, shift, expected_output", [ - ((,), 1, (,)), + ((), 1, ()), ([0, 3, -2], len(range(6)), (0, 3, 4)), ] ) @@ -42,7 +42,7 @@ def test_shift_negative_indices( expected_output: tuple[int, ...], ): assert ( - penzai.core.selectors.shift_negative_indices(input_indices, shift=shift) + penzai.core.selectors._shift_negative_indices(input_indices, shift=shift) == expected_output ) From a3fef671e92584f5a9f6bba0a3183393d71394f9 Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 24 Apr 2025 07:29:49 +0000 Subject: [PATCH 3/4] Adress comments on PR --- penzai/core/selectors.py | 4 +++- tests/core/selectors_test.py | 33 +++++++++++++-------------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/penzai/core/selectors.py b/penzai/core/selectors.py index cc3be19..24f73ca 100644 --- a/penzai/core/selectors.py +++ b/penzai/core/selectors.py @@ -39,7 +39,9 @@ T = typing.TypeVar("T") -def _shift_negative_indices(indices: Iterable[int], shift: int) -> tuple[int, ...]: +def _shift_negative_indices( + indices: Iterable[int], shift: int +) -> tuple[int, ...]: """Adds `shift` to negative indices and leaves non-negative indices unchanged Can be used to handle negative indices. For example, if we expect indices in diff --git a/tests/core/selectors_test.py b/tests/core/selectors_test.py index 9a44262..f296ba0 100644 --- a/tests/core/selectors_test.py +++ b/tests/core/selectors_test.py @@ -20,31 +20,12 @@ import collections import dataclasses -from typing import Any, Iterable +from typing import Any from absl.testing import absltest import jax from penzai import pz import penzai.core.selectors -import pytest - - -@pytest.mark.parametrize( - "input_indices, shift, expected_output", - [ - ((), 1, ()), - ([0, 3, -2], len(range(6)), (0, 3, 4)), - ] -) -def test_shift_negative_indices( - input_indices: Iterable[int], - shift: int, - expected_output: tuple[int, ...], -): - assert ( - penzai.core.selectors._shift_negative_indices(input_indices, shift=shift) - == expected_output - ) @dataclasses.dataclass @@ -75,6 +56,18 @@ class SELECTED_PART: # pylint: disable=invalid-name class SelectorsTest(absltest.TestCase): + def test_shift_negative_indices(self): + for input_indices, shift, expected_output in [ + ((), 1, ()), + ([0, 3, -2], len(range(6)), (0, 3, 4)), + ]: + self.assertEqual( + penzai.core.selectors._shift_negative_indices( + input_indices, shift=shift + ), + expected_output, + ) + def test_select(self): self.assertEqual( pz.select(make_example_object()), From e27824ca9b2ae963c0ffa3466479dc031e6b402b Mon Sep 17 00:00:00 2001 From: "Johannes E. M. Mosig" Date: Thu, 24 Apr 2025 07:42:35 +0000 Subject: [PATCH 4/4] Fix docstring format --- penzai/core/selectors.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/penzai/core/selectors.py b/penzai/core/selectors.py index 24f73ca..3832fd9 100644 --- a/penzai/core/selectors.py +++ b/penzai/core/selectors.py @@ -45,20 +45,16 @@ def _shift_negative_indices( """Adds `shift` to negative indices and leaves non-negative indices unchanged Can be used to handle negative indices. For example, if we expect indices in - `r = range(6)` and we get `[0, 3, -2]` as input, we can use + `r = range(6)` and we get `[0, 3, -2]` as input, we can use :: - ```py - shift_negative_indices([0, 3, -2], len(r)) - ``` + shift_negative_indices([0, 3, -2], len(r)) - to get `(0, 3, 4)`. The same can be achieved in more generality with + to get `(0, 3, 4)`. The same can be achieved in more generality with :: - ```py - pz.select((0, 3, -2)) \ - .at_instances_of(int) \ - .where(lambda i: i < 0) \ - .apply(lambda i: i + shift) - ``` + pz.select((0, 3, -2)) + .at_instances_of(int) + .where(lambda i: i < 0) + .apply(lambda i: i + shift) which is why this method is private to this module