diff --git a/penzai/core/selectors.py b/penzai/core/selectors.py index d89f8e0..3832fd9 100644 --- a/penzai/core/selectors.py +++ b/penzai/core/selectors.py @@ -39,6 +39,42 @@ T = typing.TypeVar("T") +def _shift_negative_indices( + indices: Iterable[int], shift: int +) -> tuple[int, ...]: + """Adds `shift` to negative indices and leaves non-negative indices unchanged + + Can be used to handle negative indices. For example, if we expect indices in + `r = range(6)` and we get `[0, 3, -2]` as input, we can use :: + + shift_negative_indices([0, 3, -2], len(r)) + + to get `(0, 3, 4)`. The same can be achieved in more generality with :: + + pz.select((0, 3, -2)) + .at_instances_of(int) + .where(lambda i: i < 0) + .apply(lambda i: i + shift) + + which is why this method is private to this module + + Args: + indices: The integers to shift + shift: The offset to add to negative indices. Usually, this is the largest + index + 1, i.e. the length of the range of indices + + Returns: + The indices as a tuple, with negative indices increased by `shift` + """ + maybe_shifted_indices = [] + for index in indices: + if index < 0: + maybe_shifted_indices.append(index + shift) + else: + maybe_shifted_indices.append(index) + return tuple(maybe_shifted_indices) + + @struct.pytree_dataclass class SelectionHole(struct.Struct): """A hole in a structure, taking the place of a selected subtree. @@ -1356,6 +1392,8 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection: else: indices = n + indices = _shift_negative_indices(indices, len(self.selected_by_path)) + with _wrap_selection_errors(self): keep = _InProgressSelectionBoundary new_selected_by_path = collections.OrderedDict({ diff --git a/tests/core/selectors_test.py b/tests/core/selectors_test.py index 66eb1a2..f296ba0 100644 --- a/tests/core/selectors_test.py +++ b/tests/core/selectors_test.py @@ -56,6 +56,18 @@ class SELECTED_PART: # pylint: disable=invalid-name class SelectorsTest(absltest.TestCase): + def test_shift_negative_indices(self): + for input_indices, shift, expected_output in [ + ((), 1, ()), + ([0, 3, -2], len(range(6)), (0, 3, 4)), + ]: + self.assertEqual( + penzai.core.selectors._shift_negative_indices( + input_indices, shift=shift + ), + expected_output, + ) + def test_select(self): self.assertEqual( pz.select(make_example_object()), @@ -565,6 +577,26 @@ def test_pick_nth_selected(self): ), [0, 1, 2, 3, SELECTED_PART(value=4), 5, 6, 7, 8, 9], ) + # Test negative indices for `pick_nth_selected` + self.assertEqual( + ( + pz.select(list(range(10))) + .at_instances_of(int) + .pick_nth_selected(-2) + .apply(SELECTED_PART) + ), + [0, 1, 2, 3, 4, 5, 6, 7, SELECTED_PART(value=8), 9], + ) + # Don't select anything if index is out of range + self.assertEqual( + ( + pz.select([0, 1, 2]) + .at_instances_of(int) + .pick_nth_selected(5) + .apply(SELECTED_PART) + ), + [0, 1, 2], + ) def test_invert__example_1(self): predicate = lambda node: isinstance(node, CustomLeaf) and node.tag <= 10