Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -354,14 +354,17 @@ def canonicalize_subscripts(self) -> EquivariantPolynomial:
self.inputs, self.outputs, self.polynomial.canonicalize_subscripts()
)

def squeeze_modes(self) -> EquivariantPolynomial:
def squeeze_modes(self, modes: str | None = None) -> EquivariantPolynomial:
"""Squeeze modes that are always 1 in all operations.

Args:
modes (str, optional): The modes to squeeze. If None, squeeze all modes that are always 1.

Returns:
:class:`cue.EquivariantPolynomial <cuequivariance.EquivariantPolynomial>`: Polynomial with squeezed modes.
"""
return EquivariantPolynomial(
self.inputs, self.outputs, self.polynomial.squeeze_modes()
self.inputs, self.outputs, self.polynomial.squeeze_modes(modes)
)

def split_mode(self, mode: str, size: int) -> EquivariantPolynomial:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -683,15 +683,27 @@ def canonicalize_subscripts(self) -> SegmentedPolynomial:
[(ope, stp.canonicalize_subscripts()) for ope, stp in self.operations],
)

def squeeze_modes(self) -> SegmentedPolynomial:
def squeeze_modes(self, modes: str | None = None) -> SegmentedPolynomial:
"""Squeeze modes that are always 1 in all operations.

Args:
modes (str, optional): The modes to squeeze. If None, squeeze all modes that are always 1.

Returns:
:class:`cue.SegmentedPolynomial <cuequivariance.SegmentedPolynomial>`: Polynomial with squeezed modes.

Note:
When ``modes`` is None, all operands with 1-dimensions are squeezed, including unused
operands. When ``modes`` is specified, unused operands (not linked to any operation) are
not squeezed because they don't have mode information to determine which dimensions
correspond to which modes.
"""
ops = [(ope, stp.squeeze_modes()) for ope, stp in self.operations]
inputs = tuple(op.squeeze() for op in self.inputs)
outputs = tuple(op.squeeze() for op in self.outputs)
ops = [(ope, stp.squeeze_modes(modes)) for ope, stp in self.operations]
if modes is not None:
inputs, outputs = self.inputs, self.outputs
else:
inputs = tuple(op.squeeze() for op in self.inputs)
outputs = tuple(op.squeeze() for op in self.outputs)
return SegmentedPolynomial._from_default_operands(inputs, outputs, ops)

def split_mode(self, mode: str, size: int) -> SegmentedPolynomial:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,90 @@ def test_consolidate_squeezes_unused_operands():
assert consolidated.inputs[0].segment_shape == (8,)
assert consolidated.inputs[2].segment_shape == (8,)
assert consolidated.outputs[0].segment_shape == (8,)


def test_squeeze_modes_unused_operand_shape_one():
"""Test squeeze_modes behavior with unused operand that has shape (1,) in all segments."""
# Create an STP with 2 operands: 1 input + 1 output
# Mode "v" has dimension 1 (squeezable), mode "u" has dimension 8
stp = cue.SegmentedTensorProduct.from_subscripts("uv,uv+")
stp.add_segment(0, (8, 1))
stp.add_segment(0, (8, 1))
stp.add_segment(1, (8, 1))
stp.add_segment(1, (8, 1))
stp.add_path(0, 0, c=1.0)
stp.add_path(1, 1, c=2.0)

# Create an unused operand with shape (1,) in all segments
unused_operand = cue.SegmentedOperand(ndim=1, segments=[(1,), (1,), (1,)])

# Create polynomial: input 0 is used, input 1 is unused
# Operation uses input 0 and output 2, so input 1 is unused
poly = cue.SegmentedPolynomial(
[stp.operands[0], unused_operand],
[stp.operands[1]],
[(cue.Operation((0, 2)), stp)],
)

# Before squeeze: unused operand has shape (1,)
assert poly.inputs[1].segment_shape == (1,)
assert poly.inputs[1].ndim == 1

# squeeze_modes() without argument: squeezes ALL 1-dimensions including unused operands
squeezed_all = poly.squeeze_modes()
assert squeezed_all.inputs[1].segment_shape == ()
assert squeezed_all.inputs[1].ndim == 0

# squeeze_modes("v") with a specific mode: unused operand is NOT squeezed
# because we don't know which dimension corresponds to which mode
squeezed_v = poly.squeeze_modes("v")
assert squeezed_v.inputs[1].segment_shape == (1,)
assert squeezed_v.inputs[1].ndim == 1


def test_squeeze_modes_selective():
"""Test that squeeze_modes with modes argument only squeezes specified modes."""
# Create an STP with 3 operands where two modes (u and v) are squeezable (dim=1)
stp = cue.SegmentedTensorProduct.from_subscripts("uvi,uvi,uvi+")

# All segments have u=1, v=1, i=8
stp.add_segment(0, (1, 1, 8))
stp.add_segment(0, (1, 1, 8))
stp.add_segment(1, (1, 1, 8))
stp.add_segment(1, (1, 1, 8))
stp.add_segment(2, (1, 1, 8))
stp.add_segment(2, (1, 1, 8))

stp.add_path(0, 0, 0, c=1.0)
stp.add_path(1, 1, 1, c=2.0)

poly = cue.SegmentedPolynomial.eval_last_operand(stp)

# Before squeeze: all operands have shape (1, 1, 8)
assert poly.inputs[0].segment_shape == (1, 1, 8)
assert poly.inputs[1].segment_shape == (1, 1, 8)
assert poly.outputs[0].segment_shape == (1, 1, 8)

# squeeze_modes() without argument: squeezes both u and v
squeezed_all = poly.squeeze_modes()
assert squeezed_all.inputs[0].segment_shape == (8,)
assert squeezed_all.inputs[1].segment_shape == (8,)
assert squeezed_all.outputs[0].segment_shape == (8,)

# squeeze_modes("v") only squeezes v, leaves u intact
squeezed_v = poly.squeeze_modes("v")
assert squeezed_v.inputs[0].segment_shape == (1, 8)
assert squeezed_v.inputs[1].segment_shape == (1, 8)
assert squeezed_v.outputs[0].segment_shape == (1, 8)

# squeeze_modes("u") only squeezes u, leaves v intact
squeezed_u = poly.squeeze_modes("u")
assert squeezed_u.inputs[0].segment_shape == (1, 8)
assert squeezed_u.inputs[1].segment_shape == (1, 8)
assert squeezed_u.outputs[0].segment_shape == (1, 8)

# squeeze_modes("uv") squeezes both u and v
squeezed_uv = poly.squeeze_modes("uv")
assert squeezed_uv.inputs[0].segment_shape == (8,)
assert squeezed_uv.inputs[1].segment_shape == (8,)
assert squeezed_uv.outputs[0].segment_shape == (8,)