diff --git a/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py b/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py index c6bc640..35e8fdb 100644 --- a/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py @@ -354,14 +354,17 @@ def canonicalize_subscripts(self) -> EquivariantPolynomial: self.inputs, self.outputs, self.polynomial.canonicalize_subscripts() ) - def squeeze_modes(self) -> EquivariantPolynomial: + def squeeze_modes(self, modes: str | None = None) -> EquivariantPolynomial: """Squeeze modes that are always 1 in all operations. + Args: + modes (str, optional): The modes to squeeze. If None, squeeze all modes that are always 1. + Returns: :class:`cue.EquivariantPolynomial `: Polynomial with squeezed modes. """ return EquivariantPolynomial( - self.inputs, self.outputs, self.polynomial.squeeze_modes() + self.inputs, self.outputs, self.polynomial.squeeze_modes(modes) ) def split_mode(self, mode: str, size: int) -> EquivariantPolynomial: diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 6f23fab..d8e624e 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -683,15 +683,27 @@ def canonicalize_subscripts(self) -> SegmentedPolynomial: [(ope, stp.canonicalize_subscripts()) for ope, stp in self.operations], ) - def squeeze_modes(self) -> SegmentedPolynomial: + def squeeze_modes(self, modes: str | None = None) -> SegmentedPolynomial: """Squeeze modes that are always 1 in all operations. + Args: + modes (str, optional): The modes to squeeze. If None, squeeze all modes that are always 1. + Returns: :class:`cue.SegmentedPolynomial `: Polynomial with squeezed modes. + + Note: + When ``modes`` is None, all operands with 1-dimensions are squeezed, including unused + operands. When ``modes`` is specified, unused operands (not linked to any operation) are + not squeezed because they don't have mode information to determine which dimensions + correspond to which modes. """ - ops = [(ope, stp.squeeze_modes()) for ope, stp in self.operations] - inputs = tuple(op.squeeze() for op in self.inputs) - outputs = tuple(op.squeeze() for op in self.outputs) + ops = [(ope, stp.squeeze_modes(modes)) for ope, stp in self.operations] + if modes is not None: + inputs, outputs = self.inputs, self.outputs + else: + inputs = tuple(op.squeeze() for op in self.inputs) + outputs = tuple(op.squeeze() for op in self.outputs) return SegmentedPolynomial._from_default_operands(inputs, outputs, ops) def split_mode(self, mode: str, size: int) -> SegmentedPolynomial: diff --git a/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py index 177185f..5315c9c 100644 --- a/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py @@ -773,3 +773,90 @@ def test_consolidate_squeezes_unused_operands(): assert consolidated.inputs[0].segment_shape == (8,) assert consolidated.inputs[2].segment_shape == (8,) assert consolidated.outputs[0].segment_shape == (8,) + + +def test_squeeze_modes_unused_operand_shape_one(): + """Test squeeze_modes behavior with unused operand that has shape (1,) in all segments.""" + # Create an STP with 2 operands: 1 input + 1 output + # Mode "v" has dimension 1 (squeezable), mode "u" has dimension 8 + stp = cue.SegmentedTensorProduct.from_subscripts("uv,uv+") + stp.add_segment(0, (8, 1)) + stp.add_segment(0, (8, 1)) + stp.add_segment(1, (8, 1)) + stp.add_segment(1, (8, 1)) + stp.add_path(0, 0, c=1.0) + stp.add_path(1, 1, c=2.0) + + # Create an unused operand with shape (1,) in all segments + unused_operand = cue.SegmentedOperand(ndim=1, segments=[(1,), (1,), (1,)]) + + # Create polynomial: input 0 is used, input 1 is unused + # Operation uses input 0 and output 2, so input 1 is unused + poly = cue.SegmentedPolynomial( + [stp.operands[0], unused_operand], + [stp.operands[1]], + [(cue.Operation((0, 2)), stp)], + ) + + # Before squeeze: unused operand has shape (1,) + assert poly.inputs[1].segment_shape == (1,) + assert poly.inputs[1].ndim == 1 + + # squeeze_modes() without argument: squeezes ALL 1-dimensions including unused operands + squeezed_all = poly.squeeze_modes() + assert squeezed_all.inputs[1].segment_shape == () + assert squeezed_all.inputs[1].ndim == 0 + + # squeeze_modes("v") with a specific mode: unused operand is NOT squeezed + # because we don't know which dimension corresponds to which mode + squeezed_v = poly.squeeze_modes("v") + assert squeezed_v.inputs[1].segment_shape == (1,) + assert squeezed_v.inputs[1].ndim == 1 + + +def test_squeeze_modes_selective(): + """Test that squeeze_modes with modes argument only squeezes specified modes.""" + # Create an STP with 3 operands where two modes (u and v) are squeezable (dim=1) + stp = cue.SegmentedTensorProduct.from_subscripts("uvi,uvi,uvi+") + + # All segments have u=1, v=1, i=8 + stp.add_segment(0, (1, 1, 8)) + stp.add_segment(0, (1, 1, 8)) + stp.add_segment(1, (1, 1, 8)) + stp.add_segment(1, (1, 1, 8)) + stp.add_segment(2, (1, 1, 8)) + stp.add_segment(2, (1, 1, 8)) + + stp.add_path(0, 0, 0, c=1.0) + stp.add_path(1, 1, 1, c=2.0) + + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + # Before squeeze: all operands have shape (1, 1, 8) + assert poly.inputs[0].segment_shape == (1, 1, 8) + assert poly.inputs[1].segment_shape == (1, 1, 8) + assert poly.outputs[0].segment_shape == (1, 1, 8) + + # squeeze_modes() without argument: squeezes both u and v + squeezed_all = poly.squeeze_modes() + assert squeezed_all.inputs[0].segment_shape == (8,) + assert squeezed_all.inputs[1].segment_shape == (8,) + assert squeezed_all.outputs[0].segment_shape == (8,) + + # squeeze_modes("v") only squeezes v, leaves u intact + squeezed_v = poly.squeeze_modes("v") + assert squeezed_v.inputs[0].segment_shape == (1, 8) + assert squeezed_v.inputs[1].segment_shape == (1, 8) + assert squeezed_v.outputs[0].segment_shape == (1, 8) + + # squeeze_modes("u") only squeezes u, leaves v intact + squeezed_u = poly.squeeze_modes("u") + assert squeezed_u.inputs[0].segment_shape == (1, 8) + assert squeezed_u.inputs[1].segment_shape == (1, 8) + assert squeezed_u.outputs[0].segment_shape == (1, 8) + + # squeeze_modes("uv") squeezes both u and v + squeezed_uv = poly.squeeze_modes("uv") + assert squeezed_uv.inputs[0].segment_shape == (8,) + assert squeezed_uv.inputs[1].segment_shape == (8,) + assert squeezed_uv.outputs[0].segment_shape == (8,)