Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions penzai/core/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,42 @@
T = typing.TypeVar("T")


def _shift_negative_indices(
indices: Iterable[int], shift: int
) -> tuple[int, ...]:
"""Adds `shift` to negative indices and leaves non-negative indices unchanged

Can be used to handle negative indices. For example, if we expect indices in
`r = range(6)` and we get `[0, 3, -2]` as input, we can use ::

shift_negative_indices([0, 3, -2], len(r))

to get `(0, 3, 4)`. The same can be achieved in more generality with ::

pz.select((0, 3, -2))
.at_instances_of(int)
.where(lambda i: i < 0)
.apply(lambda i: i + shift)

which is why this method is private to this module

Args:
indices: The integers to shift
shift: The offset to add to negative indices. Usually, this is the largest
index + 1, i.e. the length of the range of indices

Returns:
The indices as a tuple, with negative indices increased by `shift`
"""
maybe_shifted_indices = []
for index in indices:
if index < 0:
maybe_shifted_indices.append(index + shift)
else:
maybe_shifted_indices.append(index)
return tuple(maybe_shifted_indices)


@struct.pytree_dataclass
class SelectionHole(struct.Struct):
"""A hole in a structure, taking the place of a selected subtree.
Expand Down Expand Up @@ -1356,6 +1392,8 @@ def pick_nth_selected(self, n: int | Sequence[int]) -> Selection:
else:
indices = n

indices = _shift_negative_indices(indices, len(self.selected_by_path))

with _wrap_selection_errors(self):
keep = _InProgressSelectionBoundary
new_selected_by_path = collections.OrderedDict({
Expand Down
32 changes: 32 additions & 0 deletions tests/core/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ class SELECTED_PART: # pylint: disable=invalid-name

class SelectorsTest(absltest.TestCase):

def test_shift_negative_indices(self):
for input_indices, shift, expected_output in [
((), 1, ()),
([0, 3, -2], len(range(6)), (0, 3, 4)),
]:
self.assertEqual(
penzai.core.selectors._shift_negative_indices(
input_indices, shift=shift
),
expected_output,
)

def test_select(self):
self.assertEqual(
pz.select(make_example_object()),
Expand Down Expand Up @@ -565,6 +577,26 @@ def test_pick_nth_selected(self):
),
[0, 1, 2, 3, SELECTED_PART(value=4), 5, 6, 7, 8, 9],
)
# Test negative indices for `pick_nth_selected`
self.assertEqual(
(
pz.select(list(range(10)))
.at_instances_of(int)
.pick_nth_selected(-2)
.apply(SELECTED_PART)
),
[0, 1, 2, 3, 4, 5, 6, 7, SELECTED_PART(value=8), 9],
)
# Don't select anything if index is out of range
self.assertEqual(
(
pz.select([0, 1, 2])
.at_instances_of(int)
.pick_nth_selected(5)
.apply(SELECTED_PART)
),
[0, 1, 2],
)

def test_invert__example_1(self):
predicate = lambda node: isinstance(node, CustomLeaf) and node.tag <= 10
Expand Down