diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml deleted file mode 100644 index e7abad9f..00000000 --- a/.github/workflows/coverage.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: Coverage on PR - -on: - pull_request: - branches: [ "main", "release" ] - - -jobs: - cuequivariance: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade uv - python -m uv pip install pytest pytest-cov - python -m uv pip install ./cuequivariance - - name: Run coverage - run: | - pytest --cov=cuequivariance cuequivariance > coverage.txt - coverage_pr=$(cat coverage.txt | grep TOTAL | awk '{print $4}' | sed 's/%//') - echo "Coverage on PR branch: $coverage_pr" - - git fetch origin ${{ github.event.pull_request.base.ref }}:${{ github.event.pull_request.base.ref }} - git checkout ${{ github.event.pull_request.base.ref }} - pytest --cov=cuequivariance cuequivariance > coverage.txt - coverage_target=$(cat coverage.txt | grep TOTAL | awk '{print $4}' | sed 's/%//') - echo "Coverage on target branch: $coverage_target" - - if [ $coverage_pr -lt $coverage_target ]; then - echo "Coverage on PR branch is lower than on target branch" - echo "Coverage on PR branch: $coverage_pr" - echo "Coverage on target branch: $coverage_target" - exit 1 - fi diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 48aaca35..196060f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,18 +27,22 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest - python -m uv pip install ./cuequivariance + python -m uv pip install -e ./cuequivariance - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance cuequivariance-jax: - - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12"] + include: + - runner: "ubuntu-latest" + python-version: "3.10" + - runner: "self-hosted" + python-version: "3.12" + + runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v4 @@ -50,9 +54,12 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade uv - python -m uv pip install pytest - python -m uv pip install ./cuequivariance - python -m uv pip install ./cuequivariance_jax + python -m uv pip install pytest "jax[cuda12]" + python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch + python -m uv pip install -e ./cuequivariance + python -m uv pip install --no-deps -e ./cuequivariance_jax + python -c "import cuequivariance; print('cue', cuequivariance.__version__)" + python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax @@ -80,8 +87,11 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn - python -m uv pip install ./cuequivariance --force-reinstall - python -m uv pip install ./cuequivariance_torch --force-reinstall + python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch + python -m uv pip install -e ./cuequivariance + python -m uv pip install --no-deps -e ./cuequivariance_torch + python -c "import cuequivariance; print('cue', cuequivariance.__version__)" + python -c "import cuequivariance_torch; print('cuet', 
cuequivariance_torch.__version__)" - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 265ce6c5..c7f8ceaf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.3 + rev: v0.9.9 hooks: - id: ruff args: ["--extend-select", "I", "--fix"] diff --git a/CHANGELOG.md b/CHANGELOG.md index b2bed14d..62f07f0f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ ## Latest Changes +### Breaking Changes +- Renamed `SegmentedTensorProduct.flop_cost` to `flops` +- Renamed `SegmentedTensorProduct.memory_cost` to `memory` +- Removed `IrrepsArray` in favor of `RepArray` +- Changed the folder structure of cuequivariance and cuequivariance-jax. The main subfolders are now `segmented_polynomials` and `group_theory` +- Deprecated `cue.EquivariantTensorProduct` in favor of `cue.EquivariantPolynomial` +- The descriptors now return `cue.EquivariantPolynomial` instead of `cue.EquivariantTensorProduct` + +### Added +- Classes `cue.SegmentedOperand` and `cue.SegmentedPolynomial` +- Class `cue.EquivariantPolynomial` that contains a `cue.SegmentedPolynomial` and the `cue.Rep` of its inputs and outputs + + ## 0.3.0 (2025-03-05) The main changes are: diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index ed380e9c..db540153 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -18,7 +18,14 @@ importlib.resources.files(__package__).joinpath("VERSION").read_text().strip() ) -from cuequivariance.representation import ( +from .segmented_polynomials import ( + Operation, + SegmentedOperand, + SegmentedTensorProduct, + SegmentedPolynomial, +) + +from .group_theory import ( Rep, Irrep, clebsch_gordan, @@ -27,9 +34,6 @@ SU2, SO3, O3, -) - -from cuequivariance.irreps_array import ( get_irrep_scope, MulIrrep, Irreps, @@ -45,18 +49,20 @@ reduced_tensor_product_basis, reduced_symmetric_tensor_product_basis, reduced_antisymmetric_tensor_product_basis, + EquivariantPolynomial, + EquivariantTensorProduct, # deprecated ) -from cuequivariance.segmented_tensor_product import SegmentedTensorProduct -from cuequivariance.equivariant_tensor_product import EquivariantTensorProduct -from cuequivariance.operation import Operation - -from cuequivariance import ( - segmented_tensor_product, - descriptors, -) +from cuequivariance import segmented_polynomials as segmented_polynomials +from cuequivariance import group_theory as group_theory +from cuequivariance.group_theory import descriptors as descriptors __all__ = [ + "__version__", + "Operation", + "SegmentedOperand", + "SegmentedTensorProduct", + "SegmentedPolynomial", "Rep", "Irrep", "clebsch_gordan", @@ -80,9 +86,9 @@ "reduced_tensor_product_basis", "reduced_symmetric_tensor_product_basis", "reduced_antisymmetric_tensor_product_basis", - "SegmentedTensorProduct", + "EquivariantPolynomial", "EquivariantTensorProduct", - "Operation", - "segmented_tensor_product", + "segmented_polynomials", + "group_theory", "descriptors", ] diff --git a/cuequivariance/cuequivariance/misc/linalg.py b/cuequivariance/cuequivariance/etc/linalg.py similarity index 100% rename from cuequivariance/cuequivariance/misc/linalg.py rename to cuequivariance/cuequivariance/etc/linalg.py diff --git a/cuequivariance/cuequivariance/etc/permutations.py b/cuequivariance/cuequivariance/etc/permutations.py new file mode 100644 index
00000000..9d3fda6a --- /dev/null +++ b/cuequivariance/cuequivariance/etc/permutations.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Sequence + + +def compose_permutations(p1: tuple[int, ...], p2: tuple[int, ...]) -> tuple[int, ...]: + """Compose two permutations""" + return tuple(p1[p2[i]] for i in range(len(p1))) + + +def inverse_permutation(p: tuple[int, ...]) -> tuple[int, ...]: + """Invert a permutation""" + return tuple(p.index(i) for i in range(len(p))) + + +def generate_permutations_from( + generators: Sequence[tuple[int, ...]], +) -> set[tuple[int, ...]]: + """Generate all permutations from a list of generators""" + result = set(generators) + + while True: + n = len(result) + new_result = result.copy() + for g in result: + for h in result: + new_result.add(compose_permutations(g, h)) + if len(new_result) == n: + break + result = new_result + + return result diff --git a/cuequivariance/cuequivariance/misc/sympy_utils.py b/cuequivariance/cuequivariance/etc/sympy_utils.py similarity index 100% rename from cuequivariance/cuequivariance/misc/sympy_utils.py rename to cuequivariance/cuequivariance/etc/sympy_utils.py diff --git a/cuequivariance/cuequivariance/group_theory/__init__.py b/cuequivariance/cuequivariance/group_theory/__init__.py new file mode 100644 index 00000000..4e221b98 --- /dev/null +++ b/cuequivariance/cuequivariance/group_theory/__init__.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .representations import ( + Rep, + Irrep, + clebsch_gordan, + selection_rule_product, + selection_rule_power, + SU2, + SO3, + O3, +) + +from .irreps_array import ( + get_irrep_scope, + MulIrrep, + Irreps, + IrrepsLayout, + mul_ir, + ir_mul, + IrrepsAndLayout, + get_layout_scope, + assume, + NumpyIrrepsArray, + from_segments, + concatenate, + reduced_tensor_product_basis, + reduced_symmetric_tensor_product_basis, + reduced_antisymmetric_tensor_product_basis, +) + +from .equivariant_polynomial import EquivariantPolynomial +from .equivariant_tensor_product import EquivariantTensorProduct + + +__all__ = [ + "Rep", + "Irrep", + "clebsch_gordan", + "selection_rule_product", + "selection_rule_power", + "SU2", + "SO3", + "O3", + "get_irrep_scope", + "MulIrrep", + "Irreps", + "IrrepsLayout", + "mul_ir", + "ir_mul", + "IrrepsAndLayout", + "get_layout_scope", + "assume", + "NumpyIrrepsArray", + "from_segments", + "concatenate", + "reduced_tensor_product_basis", + "reduced_symmetric_tensor_product_basis", + "reduced_antisymmetric_tensor_product_basis", + "EquivariantPolynomial", + "EquivariantTensorProduct", +] diff --git a/cuequivariance/cuequivariance/descriptors/__init__.py b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/__init__.py rename to cuequivariance/cuequivariance/group_theory/descriptors/__init__.py diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py similarity index 83% rename from cuequivariance/cuequivariance/descriptors/irreps_tp.py rename to cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 59b07cd4..de2a7f2b 100644 --- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -16,13 +16,12 @@ from typing import Optional, Sequence import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep def fully_connected_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uvw],lhs[iu],rhs[jv],output[kw]`` @@ -39,7 +38,7 @@ def fully_connected_tensor_product( irreps3 (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the fully connected tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the fully connected tensor product. Examples: >>> cue.descriptors.fully_connected_tensor_product( @@ -47,13 +46,14 @@ def fully_connected_tensor_product( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... ) - EquivariantTensorProduct(61440x0 x 16x0+16x1+16x2 x 16x0+16x1+16x2 -> 16x0+16x1+16x2) + ╭ a=61440x0 b=16x0+16x1+16x2 c=16x0+16x1+16x2 -> D=16x0+16x1+16x2 + ╰─ [ijk]·a[uvw]·b[iu]·c[jv]➜D[kw] ─ num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. 
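To see what the new return type gives in practice, here is a minimal illustrative sketch (not part of the patch) that inspects the descriptor from the docstring example, using only the `EquivariantPolynomial` API introduced in this change; the printed weight count follows from the example above:

```python
import cuequivariance as cue

# Same setup as the docstring example above.
irreps = 16 * cue.Irreps("SO3", "0 + 1 + 2")
e = cue.descriptors.fully_connected_tensor_product(irreps, irreps, irreps)

# Operand 0 holds the 61440 scalar weights; the last operand is the output.
assert e.num_inputs == 3 and e.num_outputs == 1
print(e.operands[0].dim)      # 61440
print(e.flops(batch_size=1))  # estimated FLOP count for one batch element
```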
""" G = irreps1.irrep_class - d = stp.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") for mul, ir in irreps1: d.add_segment(1, (ir.dim, mul)) @@ -73,14 +73,14 @@ def fully_connected_tensor_product( d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -88,7 +88,7 @@ def full_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``lhs[iu],rhs[jv],output[kuv]`` @@ -102,14 +102,14 @@ def full_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the full tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. """ G = irreps1.irrep_class if irreps3_filter is not None: irreps3_filter = into_list_of_irrep(G, irreps3_filter) - d = stp.SegmentedTensorProduct.from_subscripts("iu,jv,kuv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iu,jv,kuv+ijk") for mul, ir in irreps1: d.add_segment(0, (ir.dim, mul)) @@ -136,13 +136,13 @@ def full_tensor_product( d = d.permute_segments(2, inv) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -150,7 +150,7 @@ def channelwise_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` @@ -166,14 +166,14 @@ def channelwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the channelwise tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. 
""" G = irreps1.irrep_class if irreps3_filter is not None: irreps3_filter = into_list_of_irrep(G, irreps3_filter) - d = stp.SegmentedTensorProduct.from_subscripts("uv,iu,jv,kuv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,jv,kuv+ijk") for mul, ir in irreps1: d.add_segment(1, (ir.dim, mul)) @@ -201,14 +201,14 @@ def channelwise_tensor_product( d = d.permute_segments(3, inv) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -245,7 +245,7 @@ def elementwise_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``lhs[iu],rhs[ju],output[ku]`` @@ -257,7 +257,7 @@ def elementwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the elementwise tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. """ G = irreps1.irrep_class @@ -268,7 +268,7 @@ def elementwise_tensor_product( irreps1_cut, irreps2_cut = _align_two_irreps(irreps1, irreps2, cue.ir_mul) - d = stp.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") irreps3 = [] for (mul, ir1), (_, ir2) in zip(irreps1_cut, irreps2_cut): @@ -286,19 +286,17 @@ def elementwise_tensor_product( irreps3 = cue.Irreps(G, irreps3) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) -def linear( - irreps_in: cue.Irreps, irreps_out: cue.Irreps -) -> cue.EquivariantTensorProduct: +def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uv],input[iu],output[iv]`` @@ -309,9 +307,9 @@ def linear( irreps_out (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the linear transformation. + :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. 
""" - d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_iv") + d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) for mul, ir in irreps_out: @@ -325,11 +323,11 @@ def linear( d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/rotations.py b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py similarity index 76% rename from cuequivariance/cuequivariance/descriptors/rotations.py rename to cuequivariance/cuequivariance/group_theory/descriptors/rotations.py index 2257c1b5..f6525970 100644 --- a/cuequivariance/cuequivariance/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py @@ -17,12 +17,11 @@ import numpy as np import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp def fixed_axis_angle_rotation( irreps: cue.Irreps, axis: np.ndarray, angle: float -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``input[u],output[u]`` @@ -30,7 +29,7 @@ def fixed_axis_angle_rotation( """ assert irreps.irrep_class in [cue.SO3, cue.O3] - d = stp.SegmentedTensorProduct.from_subscripts("iu_ju+ij") + d = cue.SegmentedTensorProduct.from_subscripts("iu_ju+ij") for mul, ir in irreps: # Note the transpose @@ -42,18 +41,18 @@ def fixed_axis_angle_rotation( ) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) def yxy_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``gamma[],beta[],alpha[],input[u],output[u]`` @@ -69,12 +68,14 @@ def yxy_rotation( where l is the maximum L in the input and output irreps. 
""" - cbio = xy_rotation(irreps, lmax).d # gamma, beta, input, A - aio = y_rotation(irreps, lmax).d # alpha, A, output - cbiao = stp.dot(cbio, aio, (3, 1)) # gamma, beta, input, alpha, output + # gamma, beta, input, A + cbio = xy_rotation(irreps, lmax).polynomial.operations[0][1] + aio = y_rotation(irreps, lmax).polynomial.operations[0][1] # alpha, A, output + cbiao = cue.segmented_polynomials.dot( + cbio, aio, (3, 1) + ) # gamma, beta, input, alpha, output cbaio = cbiao.move_operand(2, 3) # gamma, beta, alpha, input, output - return cue.EquivariantTensorProduct( - cbaio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[1].size), cue.ir_mul), @@ -82,35 +83,36 @@ def yxy_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(cbaio), ) def xy_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``gamma[],beta[],input[u],output[u]`` Rotation around the y-axis followed by rotation around the x-axis """ - cio = y_rotation(irreps, lmax).d # gamma, input, A - bio = x_rotation(irreps, lmax).d # beta, A, output - cibo = stp.dot(cio, bio, (2, 1)) # gamma, input, beta, output + cio = y_rotation(irreps, lmax).polynomial.operations[0][1] # gamma, input, A + bio = x_rotation(irreps, lmax).polynomial.operations[0][1] # beta, A, output + cibo = cue.segmented_polynomials.dot(cio, bio, (2, 1)) # gamma, input, beta, output cbio = cibo.move_operand(1, 2) # gamma, beta, input, output - return cue.EquivariantTensorProduct( - cbio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(cbio), ) def yx_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``phi[],theta[],input[u],output[u]`` @@ -118,22 +120,22 @@ def yx_rotation( """ cio = x_rotation(irreps, lmax).d bio = y_rotation(irreps, lmax).d - cibo = stp.dot(cio, bio, (2, 1)) + cibo = cue.segmented_polynomials.dot(cio, bio, (2, 1)) cbio = cibo.move_operand(1, 2) - return cue.EquivariantTensorProduct( - cbio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(cbio), ) def y_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``phi[],input[u],output[u]`` @@ -150,7 +152,7 @@ def y_rotation( if lmax is None: lmax = max(ir.l for _, ir in irreps) - d = stp.SegmentedTensorProduct.from_subscripts("i_ju_ku+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("i_ju_ku+ijk") phc = d.add_segment( 0, (lmax,) ) # cos(th * lmax), cos(th * (lmax - 1)), ..., cos(th) @@ -190,19 +192,19 @@ def y_rotation( d.add_path(phc, slo, slo, c=c) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ 
cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) def x_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``phi[],input[u],output[u]`` @@ -216,35 +218,39 @@ """ assert irreps.irrep_class in [cue.SO3, cue.O3] - dy = y_rotation(irreps, lmax).d - dz90 = fixed_axis_angle_rotation(irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0).d - d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1)) + dy = y_rotation(irreps, lmax).polynomial.operations[0][1] + dz90 = fixed_axis_angle_rotation( + irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0 + ).polynomial.operations[0][1] + d = cue.segmented_polynomials.dot( + cue.segmented_polynomials.dot(dy, dz90, (1, 1)), dz90, (1, 1) + ) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) -def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct: +def inversion(irreps: cue.Irreps) -> cue.EquivariantPolynomial: """ subscripts: ``input[u],output[u]`` """ - d = stp.SegmentedTensorProduct.from_subscripts("iu_ju+ji") + d = cue.SegmentedTensorProduct.from_subscripts("iu_ju+ji") for mul, ir in irreps: assert len(ir.H) == 1 H = ir.H[0] assert np.allclose(H @ H, np.eye(ir.dim), atol=1e-6) d.add_path(None, None, c=H, dims={"u": mul}) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py similarity index 65% rename from cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py rename to cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index 9f19ff8a..6f8b15b1 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -14,18 +14,18 @@ # limitations under the License. from functools import cache -import sympy as sp +import sympy import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.misc.sympy_utils import sqrtQarray_to_sympy +from cuequivariance.etc.sympy_utils import sqrtQarray_to_sympy def spherical_harmonics( ir_vec: cue.Irrep, ls: list[int], layout: cue.IrrepsLayout = cue.ir_mul -) -> cue.EquivariantTensorProduct: - """ - subscripts: ``vector[],...,vector[],Yl[]`` +) -> cue.EquivariantPolynomial: + """Polynomial descriptor for the spherical harmonics. + + Subscripts: ``vector[],...,vector[],Yl[]`` Args: ir_vec (Irrep): irrep of the input vector, for example ``cue.SO3(1)``. @@ -33,14 +33,17 @@ layout (IrrepsLayout, optional): layout of the output. Defaults to ``cue.ir_mul``. Returns: - :class:`cue.EquivariantTensorProduct `: The descriptor. + :class:`cue.EquivariantPolynomial `: The descriptor. 
- Examples: + Example: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) - EquivariantTensorProduct((1)^(0..2) -> 0+1+2) + ╭ a=1 -> B=0+1+2 + │ []➜B[] ───────── num_paths=1 + │ []·a[]➜B[] ───── num_paths=3 + ╰─ []·a[]·a[]➜B[] ─ num_paths=11 """ if len(ls) != 1: - return cue.EquivariantTensorProduct.stack( + return cue.EquivariantPolynomial.stack( [spherical_harmonics(ir_vec, [ell], layout) for ell in ls], [False, True] ) @@ -48,18 +51,26 @@ def spherical_harmonics( ir, formula = sympy_spherical_harmonics(ir_vec, ell) assert ir_vec.dim == 3 - d = stp.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) + d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) for i in range(ir.dim): - for degrees, coeff in sp.Poly(formula[i], sp.symbols("x:3")).as_dict().items(): + for degrees, coeff in ( + sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() + ): indices = poly_degrees_to_path_indices(degrees) d.add_path(*indices, i, c=coeff) - return cue.EquivariantTensorProduct( - [d], + d = d.symmetrize_operands(range(ell), force=True) + + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul), cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul), ], + cue.SegmentedPolynomial( + [cue.SegmentedOperand([()] * 3)], + [cue.SegmentedOperand([()] * ir.dim)], + [(cue.Operation([0] * ell + [1]), d)], + ), ) @@ -73,14 +84,14 @@ def poly_degrees_to_path_indices(degrees: tuple[int, ...]) -> tuple[int, ...]: @cache def sympy_spherical_harmonics( ir_vec: cue.Irrep, ell: int -) -> tuple[cue.Irrep, sp.Array]: +) -> tuple[cue.Irrep, sympy.Array]: if ell == 0: - return ir_vec.trivial(), sp.Array([1]) + return ir_vec.trivial(), sympy.Array([1]) if ell == 1: assert ir_vec.dim == 3 - x = sp.symbols("x:3") - return ir_vec, sp.sqrt(3) * sp.Array([x[0], x[1], x[2]]) + x = sympy.symbols("x:3") + return ir_vec, sympy.sqrt(3) * sympy.Array([x[0], x[1], x[2]]) l2 = ell // 2 l1 = ell - l2 @@ -88,11 +99,11 @@ def sympy_spherical_harmonics( ir2, yl2 = sympy_spherical_harmonics(ir_vec, l2) ir = sorted(cue.selection_rule_product(ir1, ir2))[-1] - def sh_var(ir: cue.Irrep, ell: int) -> list[sp.Symbol]: - return [sp.symbols(f"sh{ell}_{m}") for m in range(ir.dim)] + def sh_var(ir: cue.Irrep, ell: int) -> list[sympy.Symbol]: + return [sympy.symbols(f"sh{ell}_{m}") for m in range(ir.dim)] cg = sqrtQarray_to_sympy(ir_vec.clebsch_gordan(ir1, ir2, ir).squeeze(0)) - yl = sp.Array( + yl = sympy.Array( [ sum( sh_var(ir1, l1)[i] * sh_var(ir2, l2)[j] * cg[i, j, k] @@ -106,7 +117,7 @@ def sh_var(ir: cue.Irrep, ell: int) -> list[sp.Symbol]: y = yl.subs(zip(sh_var(ir1, l1), yl1)).subs(zip(sh_var(ir2, l2), yl2)) cst = y.subs({"x0": 0, "x1": 1, "x2": 0}) - norm = sp.sqrt(sum(cst.applyfunc(lambda x: x**2))) + norm = sympy.sqrt(sum(cst.applyfunc(lambda x: x**2))) - y = sp.sqrt(sp.Integer(ir.dim)) * y / norm - return ir, sp.simplify(y) + y = sympy.sqrt(sympy.Integer(ir.dim)) * y / norm + return ir, sympy.simplify(y) diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py similarity index 74% rename from cuequivariance/cuequivariance/descriptors/symmetric_contractions.py rename to cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 09b13749..9a94bec4 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -13,44 +13,44 @@ # See the License for the 
specific language governing permissions and # limitations under the License. import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp def symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: list[int], -) -> cue.EquivariantTensorProduct: - r""" - subscripts: ``weights[u],input[u],output[u]`` - - Construct the descriptor for a symmetric contraction. +) -> cue.EquivariantPolynomial: + """Construct the descriptor for a symmetric contraction. The symmetric contraction is a weighted sum of the input contracted with itself `degree` times. + Subscripts: ``weights[u],input[u],output[u]`` + Args: irreps_in (Irreps): The input irreps, the multiplicities are treated in parallel. irreps_out (Irreps): The output irreps. - degree (int): The degree of the symmetric contraction. + degrees (list[int]): List of degrees for the symmetric contractions. Returns: - :class:`cue.EquivariantTensorProduct `: - The descriptor of the symmetric contraction. + EquivariantPolynomial: The descriptor of the symmetric contraction. The operands are the weights, the input repeated `degree` times, and the output. - Examples: + Example: >>> cue.descriptors.symmetric_contraction( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... 16 * cue.Irreps("SO3", "0 + 1"), ... [1, 2, 3] ... ) - EquivariantTensorProduct(32x0+80x0+176x0 x (16x0+16x1+16x2)^(1..3) -> 16x0+16x1) + ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 + │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 + │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 + ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). """ degrees = list(degrees) if len(degrees) != 1: - return cue.EquivariantTensorProduct.stack( + return cue.EquivariantPolynomial.stack( [ symmetric_contraction(irreps_in, irreps_out, [degree]) for degree in degrees @@ -69,8 +69,10 @@ input_operands = range(1, degree + 1) output_operand = degree + 1 + input_operand = cue.SegmentedOperand(ndim=1, segments=[(mul,)] * irreps_in.dim) + if degree == 0: - d = stp.SegmentedTensorProduct.from_subscripts("i_i") + d = cue.SegmentedTensorProduct.from_subscripts("i_i") for _, ir in irreps_out: if not ir.is_scalar(): d.add_segment(output_operand, {"i": ir.dim}) @@ -80,7 +82,7 @@ else: abc = "abcdefgh"[:degree] - d = stp.SegmentedTensorProduct.from_subscripts( + d = cue.SegmentedTensorProduct.from_subscripts( f"w_{'_'.join(f'{a}' for a in abc)}_i+{abc}iw" ) @@ -102,11 +104,18 @@ d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) - return cue.EquivariantTensorProduct( - [d], + for i in input_operands: + assert d.operands[i] == input_operand + + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial( + [d.operands[0], input_operand], + [d.operands[-1]], + [(cue.Operation([0] + [1] * degree + [2]), d)], + ), ) diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/group_theory/descriptors/transposition.py similarity index 69% rename from cuequivariance/cuequivariance/descriptors/transposition.py rename to cuequivariance/cuequivariance/group_theory/descriptors/transposition.py index ae671164..a7ae0b0e 100644 ---
a/cuequivariance/cuequivariance/descriptors/transposition.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/transposition.py @@ -17,20 +17,17 @@ def transpose( irreps: cue.Irreps, source: cue.IrrepsLayout, target: cue.IrrepsLayout -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """Transpose the irreps layout of a tensor.""" d = cue.SegmentedTensorProduct( - operands=[ - cue.segmented_tensor_product.Operand( - subscripts="ui" if source == cue.mul_ir else "iu" - ), - cue.segmented_tensor_product.Operand( - subscripts="ui" if target == cue.mul_ir else "iu" - ), + operands_and_subscripts=[ + (cue.SegmentedOperand(ndim=2), "ui" if source == cue.mul_ir else "iu"), + (cue.SegmentedOperand(ndim=2), "ui" if target == cue.mul_ir else "iu"), ] ) for mul, ir in irreps: d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim}) - return cue.EquivariantTensorProduct( - d, [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)] + return cue.EquivariantPolynomial( + [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)], + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py b/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py new file mode 100644 index 00000000..f2310ac1 --- /dev/null +++ b/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses + +import numpy as np + +import cuequivariance as cue + + +@dataclasses.dataclass(init=False, frozen=True) +class EquivariantPolynomial: + """A polynomial representation with equivariance constraints. + + This class extends SegmentedPolynomial by incorporating information about the group representations + of each input and output tensor. It ensures that operations performed by the polynomial respect + the equivariance constraints defined by these representations, making it suitable for building + equivariant neural networks. + + Args: + operands (list[cue.Rep]): Group representations for all operands (inputs and outputs). + polynomial (cue.SegmentedPolynomial): The underlying polynomial transformation. + """ + + operands: tuple[cue.Rep, ...] 
+ polynomial: cue.SegmentedPolynomial + + def __init__(self, operands: list[cue.Rep], polynomial: cue.SegmentedPolynomial): + assert isinstance(polynomial, cue.SegmentedPolynomial) + object.__setattr__(self, "operands", tuple(operands)) + object.__setattr__(self, "polynomial", polynomial) + if len(self.operands) != self.polynomial.num_operands: + raise ValueError( + f"Number of operands {len(self.operands)} must equal the number of inputs" + f" {self.polynomial.num_inputs} plus the number of outputs {self.polynomial.num_outputs}" + ) + for rep, ope in zip(self.operands, self.polynomial.operands): + assert ope.size == rep.dim, ( + f"{ope} incompatible with {rep}. {ope.size=} != {rep.dim=}" + ) + + def __hash__(self) -> int: + return hash((self.operands, self.polynomial)) + + def __eq__(self, value) -> bool: + assert isinstance(value, EquivariantPolynomial) + return self.operands == value.operands and self.polynomial == value.polynomial + + def __lt__(self, value) -> bool: + assert isinstance(value, EquivariantPolynomial) + return ( + self.num_inputs, + self.num_outputs, + self.operands, + self.polynomial, + ) < ( + value.num_inputs, + value.num_outputs, + value.operands, + value.polynomial, + ) + + def __mul__(self, factor: float) -> EquivariantPolynomial: + return EquivariantPolynomial(self.operands, self.polynomial * factor) + + def __rmul__(self, factor: float) -> EquivariantPolynomial: + return self.__mul__(factor) + + def __repr__(self): + return self.polynomial.to_string([f"{rep}" for rep in self.operands]) + + def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: + """Evaluate the polynomial on the given inputs. + + Args: + *inputs (np.ndarray): Input tensors to evaluate the polynomial on. + + Returns: + list[np.ndarray]: Output tensors resulting from the polynomial evaluation. + """ + return self.polynomial(*inputs) + + @property + def num_operands(self) -> int: + """The total number of operands (inputs and outputs) in the polynomial.""" + return len(self.operands) + + @property + def num_inputs(self) -> int: + """The number of input tensors expected by the polynomial.""" + return self.polynomial.num_inputs + + @property + def num_outputs(self) -> int: + """The number of output tensors produced by the polynomial.""" + return self.polynomial.num_outputs + + @property + def inputs(self) -> tuple[cue.Rep, ...]: + """The group representations of the input tensors.""" + return self.operands[: self.num_inputs] + + @property + def outputs(self) -> tuple[cue.Rep, ...]: + """The group representations of the output tensors.""" + return self.operands[self.num_inputs :] + + def fuse_stps(self) -> EquivariantPolynomial: + """Fuse segmented tensor products with identical operations and operands. + + Returns: + EquivariantPolynomial: A new polynomial with fused tensor products. + """ + return EquivariantPolynomial( + self.operands, + self.polynomial.fuse_stps(), + ) + + @classmethod + def stack( + cls, polys: list[EquivariantPolynomial], stacked: list[bool] + ) -> EquivariantPolynomial: + """Stack multiple equivariant polynomials together. + + This method combines multiple polynomials by stacking their operands according to the + stacked parameter. Operands with the same index that are not stacked must be identical + across all polynomials. + + Args: + polys (list[EquivariantPolynomial]): List of polynomials to stack. + stacked (list[bool]): Boolean flags indicating which operands should be stacked. + + Returns: + EquivariantPolynomial: A new polynomial combining the stacked polynomials. 
+ + Raises: + ValueError: If operands that are not stacked differ across polynomials. + """ + assert len(polys) > 0 + num_operands = polys[0].num_operands + + assert all(pol.num_operands == num_operands for pol in polys) + assert len(stacked) == num_operands + + operands = [] + for oid in range(num_operands): + if stacked[oid]: + for pol in polys: + if not isinstance(pol.operands[oid], cue.IrrepsAndLayout): + raise ValueError( + f"Cannot stack operand {oid} of type {type(pol.operands[oid])}" + ) + operands.append(cue.concatenate([pol.operands[oid] for pol in polys])) + else: + ope = polys[0].operands[oid] + for pol in polys: + if pol.operands[oid] != ope: + raise ValueError( + f"Operand {oid} must be the same for all polynomials." + f" Found {ope} and {pol.operands[oid]}" + ) + operands.append(ope) + + poly = cls( + operands, + cue.SegmentedPolynomial.stack([pol.polynomial for pol in polys], stacked), + ) + return poly.fuse_stps() + + def squeeze_modes(self) -> EquivariantPolynomial: + """Squeeze the modes of the segmented tensor products. + + Returns: + EquivariantPolynomial: A new polynomial with squeezed modes. + """ + return EquivariantPolynomial( + self.operands, + self.polynomial.squeeze_modes(), + ) + + def flatten_coefficient_modes(self) -> EquivariantPolynomial: + """Flatten the coefficient modes of the segmented tensor products. + + Returns: + EquivariantPolynomial: A new polynomial with flattened coefficient modes. + """ + return EquivariantPolynomial( + self.operands, + self.polynomial.flatten_coefficient_modes(), + ) + + def jvp(self, has_tangent: list[bool]) -> EquivariantPolynomial: + """Compute the Jacobian-vector product of the polynomial. + + This method creates a new polynomial that, when evaluated, computes the Jacobian-vector + product of the original polynomial. This is used for forward-mode automatic differentiation. + + Args: + has_tangent (list[bool]): Boolean flags indicating which inputs have tangent vectors. + + Returns: + EquivariantPolynomial: A new polynomial representing the JVP operation. + """ + return EquivariantPolynomial( + list(self.inputs) + + [x for has, x in zip(has_tangent, self.inputs) if has] + + list(self.outputs), + self.polynomial.jvp(has_tangent), + ) + + def transpose( + self, + is_undefined_primal: list[bool], + has_cotangent: list[bool], + ) -> EquivariantPolynomial: + """Transpose the polynomial operation. + + This method creates a new polynomial that represents the transpose of the original operation. + The transpose is essential for reverse-mode automatic differentiation. + + Args: + is_undefined_primal (list[bool]): Boolean flags indicating which inputs are undefined primals. + has_cotangent (list[bool]): Boolean flags indicating which outputs have cotangents. + + Returns: + EquivariantPolynomial: A new polynomial representing the transposed operation. + + Raises: + ValueError: If the polynomial is non-linear and cannot be transposed. + """ + return EquivariantPolynomial( + # defined inputs + [ + x + for is_undefined, x in zip(is_undefined_primal, self.inputs) + if not is_undefined + ] + # cotangent outputs + + [x for has, x in zip(has_cotangent, self.outputs) if has] + # undefined inputs + + [ + x + for is_undefined, x in zip(is_undefined_primal, self.inputs) + if is_undefined + ], + self.polynomial.transpose(is_undefined_primal, has_cotangent), + ) + + def backward( + self, requires_gradient: list[bool], has_cotangent: list[bool] + ) -> EquivariantPolynomial: + """Compute the backward pass of the polynomial for gradient computation. 
+ + This method combines the JVP and transpose operations to create a new polynomial that, + when evaluated, computes gradients of outputs with respect to inputs. This is the + core operation in reverse-mode automatic differentiation. + + Args: + requires_gradient (list[bool]): Boolean flags indicating which inputs require gradients. + has_cotangent (list[bool]): Boolean flags indicating which outputs have cotangents. + + Returns: + EquivariantPolynomial: A new polynomial for gradient computation. + """ + return EquivariantPolynomial( + list(self.inputs) + + [x for has, x in zip(has_cotangent, self.outputs) if has] + + [x for req, x in zip(requires_gradient, self.inputs) if req], + self.polynomial.backward(requires_gradient, has_cotangent), + ) + + def flops(self, batch_size: int = 1) -> int: + """Compute the number of floating point operations in the polynomial. + + Args: + batch_size (int, optional): The batch size for the computation. Defaults to 1. + + Returns: + int: The estimated number of floating-point operations. + """ + return self.polynomial.flops(batch_size) + + def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the polynomial. + + Args: + batch_sizes (list[int]): The batch sizes for each operand. + + Returns: + int: The estimated memory usage in number of scalar elements. + """ + assert len(batch_sizes) == len(self.operands) + return sum(Z * rep.dim for Z, rep in zip(batch_sizes, self.operands)) diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py similarity index 90% rename from cuequivariance/cuequivariance/equivariant_tensor_product.py rename to cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py index 8ae70a53..b5f5613f 100644 --- a/cuequivariance/cuequivariance/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py @@ -16,10 +16,10 @@ import copy import dataclasses +import warnings from typing import Optional, Sequence, Union import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp @dataclasses.dataclass(init=False, frozen=True) @@ -54,16 +54,22 @@ class EquivariantTensorProduct: """ operands: tuple[cue.Rep, ...] - ds: list[stp.SegmentedTensorProduct] + ds: list[cue.SegmentedTensorProduct] def __init__( self, - d: Union[stp.SegmentedTensorProduct, Sequence[stp.SegmentedTensorProduct]], + d: Union[cue.SegmentedTensorProduct, Sequence[cue.SegmentedTensorProduct]], operands: list[cue.Rep], symmetrize: bool = True, ): + warnings.warn( + "EquivariantTensorProduct is deprecated and will be removed in a future version. 
" + "Please use EquivariantPolynomial instead.", + DeprecationWarning, + stacklevel=2, + ) operands = tuple(operands) - if isinstance(d, stp.SegmentedTensorProduct): + if isinstance(d, cue.SegmentedTensorProduct): assert len(operands) == d.num_operands for oid in range(d.num_operands): assert operands[oid].dim == d.operands[oid].size @@ -122,7 +128,7 @@ def __rmul__(self, factor: float) -> EquivariantTensorProduct: return self.__mul__(factor) @property - def d(self) -> stp.SegmentedTensorProduct: + def d(self) -> cue.SegmentedTensorProduct: assert len(self.ds) == 1 return self.ds[0] @@ -262,7 +268,7 @@ def change_layout( del layout layouts = [cue.IrrepsLayout.as_layout(layout) for layout in layouts] - def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: + def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: ii = self.map_operands(d.num_operands) assert len(ii) == d.num_operands @@ -296,7 +302,9 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: continue raise NotImplementedError return d.add_or_transpose_modes( - stp.Subscripts.from_operands(new_subscripts, d.coefficient_subscripts) + cue.segmented_polynomials.Subscripts.from_operands( + new_subscripts, d.coefficient_subscripts + ) ) return EquivariantTensorProduct( @@ -309,7 +317,7 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: def flop_cost(self, batch_size: int) -> int: """Compute the number of flops of the tensor product.""" - return sum(d.flop_cost(-1) for d in self.ds) * batch_size + return sum(d.flops(-1) for d in self.ds) * batch_size def memory_cost( self, batch_sizes: tuple[int, ...], itemsize: Union[int, tuple[int, ...]] @@ -326,35 +334,6 @@ def memory_cost( def backward(self, input: int) -> tuple[EquivariantTensorProduct, tuple[int, ...]]: """ The backward pass of the equivariant tensor product. - - Args: - input: The input with respect to which the backward pass is computed. - - Returns: - A tuple containing the ETP representing the backward pass and the permutation of the operands between the original and the backward ETP. - - Examples: - >>> e = cue.descriptors.fully_connected_tensor_product( - ... cue.Irreps("SO3", "4x0+1x1"), - ... cue.Irreps("SO3", "4x0+2x1"), - ... cue.Irreps("SO3", "4x0+3x1") - ... ) - >>> e - EquivariantTensorProduct(114x0 x 4x0+1 x 4x0+2x1 -> 4x0+3x1) - >>> e_bwd, i = e.backward(0) - >>> e_bwd - EquivariantTensorProduct(4x0+3x1 x 4x0+1 x 4x0+2x1 -> 114x0) - >>> i - (3, 1, 2, 0) - - >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) - >>> e - EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3) - >>> e_bwd, i = e.backward(0) - >>> e_bwd - EquivariantTensorProduct(0+1+2+3 x (1)^(0..2) -> 1) - >>> i - (1, 0, 0) """ assert input < self.num_inputs @@ -391,7 +370,7 @@ def backward(self, input: int) -> tuple[EquivariantTensorProduct, tuple[int, ... 
e = EquivariantTensorProduct(ds, tuple(self.operands[i] for i in oids)) return e, tuple(oids) - def stp_operand(self, oid: int) -> Optional[stp.Operand]: + def stp_operand(self, oid: int) -> Optional[cue.SegmentedOperand]: # output if oid == self.num_operands - 1: return self.ds[0].operands[-1] @@ -427,7 +406,7 @@ assert all(e.operands[oid] == ope for e in es) new_operands.append(ope) - new_ds: dict[int, stp.SegmentedTensorProduct] = {} + new_ds: dict[int, cue.SegmentedTensorProduct] = {} for eid, e in enumerate(es): for d in e.ds: d = copy.deepcopy(d) ii = e.map_operands(d.num_operands) for oid in range(d.num_operands): if stacked[ii[oid]]: for e_ in reversed(es[:eid]): - d.insert_segments(oid, 0, e_.stp_operand(ii[oid])) + d.insert_segments(oid, 0, e_.stp_operand(ii[oid]).segments) for e_ in es[eid + 1 :]: - d.insert_segments(oid, -1, e_.stp_operand(ii[oid])) + d.insert_segments(oid, -1, e_.stp_operand(ii[oid]).segments) if d.num_operands not in new_ds: new_ds[d.num_operands] = d diff --git a/cuequivariance/cuequivariance/experimental/__init__.py b/cuequivariance/cuequivariance/group_theory/experimental/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/__init__.py rename to cuequivariance/cuequivariance/group_theory/experimental/__init__.py diff --git a/cuequivariance/cuequivariance/experimental/e3nn.py b/cuequivariance/cuequivariance/group_theory/experimental/e3nn.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/e3nn.py rename to cuequivariance/cuequivariance/group_theory/experimental/e3nn.py diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/group_theory/experimental/escn.py similarity index 93% rename from cuequivariance/cuequivariance/experimental/escn.py rename to cuequivariance/cuequivariance/group_theory/experimental/escn.py index c08b0f61..1e6f43c3 100644 --- a/cuequivariance/cuequivariance/experimental/escn.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/escn.py @@ -17,7 +17,6 @@ import numpy as np import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp # The function escn_iu_ju_ku below is a 1:1 adaptation of https://github.com/e3nn/e3nn-jax/blob/a2a81ab451b9cd597d7be27b3e1faba79457475d/e3nn_jax/experimental/linear_shtp.py#L38-L165 @@ -26,7 +25,7 @@ def escn_tp( irreps_out: cue.Irreps, m_max: Optional[int] = None, l_max: Optional[int] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uv],input[u],output[v]`` @@ -39,7 +38,7 @@ l_max (int, optional): Maximum angular resolution along the principal axis. Returns: - EquivariantTensorProduct: + EquivariantPolynomial: Descriptor of the tensor product part of the eSCN convolution. 
- Operand 0: weights @@ -61,7 +60,7 @@ def pr(mul_ir: cue.MulIrrep) -> bool: irreps_out = irreps_out.filter(keep=pr) - d = stp.SegmentedTensorProduct.from_subscripts("iuv,ju,kv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv,ju,kv+ijk") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) @@ -105,13 +104,13 @@ d = d.normalize_paths_for_operand(2) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -119,7 +118,7 @@ def escn_tp_compact( irreps_in: cue.Irreps, irreps_out: cue.Irreps, m_max: Optional[int] = None, -) -> stp.SegmentedTensorProduct: +) -> cue.SegmentedPolynomial: """ subscripts: ``weights[uv],input[u],output[v]`` @@ -146,7 +145,7 @@ if G not in [cue.SO3]: raise NotImplementedError("Only SO3 is supported") - d = stp.SegmentedTensorProduct.from_subscripts("uv,u,v") + d = cue.SegmentedTensorProduct.from_subscripts("uv,u,v") l_max_in = max(ir.l for _, ir in irreps_in) for m in range(-l_max_in, l_max_in + 1): @@ -178,7 +177,7 @@ d.add_path(i, l_max_in - m, l_max_out + m, c=-1.0) d = d.normalize_paths_for_operand(2) - return d # TODO: return an EquivariantTensorProduct using SphericalSignal + # TODO: return an EquivariantPolynomial using SphericalSignal + return cue.SegmentedPolynomial.eval_last_operand(d) class SphericalSignal(cue.Rep): diff --git a/cuequivariance/cuequivariance/experimental/gatr.py b/cuequivariance/cuequivariance/group_theory/experimental/gatr.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/gatr.py rename to cuequivariance/cuequivariance/group_theory/experimental/gatr.py diff --git a/cuequivariance/cuequivariance/experimental/mace/__init__.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/mace/__init__.py rename to cuequivariance/cuequivariance/group_theory/experimental/mace/__init__.py diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py similarity index 87% rename from cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py rename to cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py index 78aee0e0..71dcb263 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py @@ -18,14 +18,12 @@ import numpy as np import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors -from cuequivariance.misc.linalg import round_to_sqrt_rational, triu_array +from cuequivariance.etc.linalg import round_to_sqrt_rational, triu_array def symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: list[int] -) -> tuple[cue.EquivariantTensorProduct, np.ndarray]: +) -> tuple[cue.EquivariantPolynomial, np.ndarray]: r""" subscripts: ``weights[u],input[u],output[u]`` @@ -42,12 +40,15 @@ w = jax.random.normal(jax.random.key(0), (p.shape[0], mul)) w = jnp.einsum("au,ab->bu", w, 
p).flatten() - cuex.equivariant_tensor_product(e, w, cuex.randn(jax.random.key(1), e.inputs[1])) + x = cuex.randn(jax.random.key(1), e.inputs[1]) + y = cuex.equivariant_polynomial(e, [w, x]) """ assert min(degrees) > 0 - e1 = cue.EquivariantTensorProduct.stack( + + # poly1 replicates the behavior of the original MACE implementation + poly1 = cue.EquivariantPolynomial.stack( [ - cue.EquivariantTensorProduct.stack( + cue.EquivariantPolynomial.stack( [ _symmetric_contraction(irreps_in, irreps_out[i : i + 1], deg) for deg in reversed(degrees) @@ -58,7 +59,7 @@ def symmetric_contraction( ], [True, False, True], ) - e2 = descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) + poly2 = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) a1, a2 = [ np.concatenate( [ @@ -67,11 +68,11 @@ def symmetric_contraction( 1, None, ) - for d in sorted(e.ds, key=lambda d: d.num_operands) + for _, d in pol.polynomial.operations ], axis=1, ) - for e in [e1, e2] + for pol in [poly1, poly2] ] # This nonzeros selection is just for lightening the inversion @@ -83,7 +84,7 @@ def symmetric_contraction( projection = round_to_sqrt_rational(projection) np.testing.assert_allclose(a1, projection @ a2, atol=1e-7) - return e2, projection + return poly2, projection def _flatten( @@ -103,7 +104,7 @@ def _flatten( def _stp_to_matrix( - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, ) -> np.ndarray: m = np.zeros([ope.num_segments for ope in d.operands]) for path in d.paths: @@ -114,7 +115,7 @@ def _stp_to_matrix( # This function is an adaptation of https://github.com/ACEsuit/mace/blob/bd412319b11c5f56c37cec6c4cfae74b2a49ff43/mace/modules/symmetric_contraction.py def _symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degree: int -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: mul = irreps_in.muls[0] assert all(mul == m for m in irreps_in.muls) assert all(mul == m for m in irreps_out.muls) @@ -125,7 +126,7 @@ def _symmetric_contraction( output_operand = degree + 1 abc = "abcdefgh"[:degree] - d = stp.SegmentedTensorProduct.from_subscripts( + d = cue.SegmentedTensorProduct.from_subscripts( f"u_{'_'.join(f'{a}' for a in abc)}_i+{abc}ui" ) @@ -145,13 +146,19 @@ def _symmetric_contraction( d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) - return cue.EquivariantTensorProduct( - [d], + + assert d.num_operands >= 3 + [w, x], y = d.operands[:2], d.operands[-1] + + return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in.new_scalars(w.size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial( + [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] + ), ) diff --git a/cuequivariance/cuequivariance/irreps_array/__init__.py b/cuequivariance/cuequivariance/group_theory/irreps_array/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/__init__.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/__init__.py diff --git a/cuequivariance/cuequivariance/irreps_array/context_decorator.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py similarity index 85% rename from cuequivariance/cuequivariance/irreps_array/context_decorator.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py index 3ee8e902..00e614d2 100644 --- 
a/cuequivariance/cuequivariance/irreps_array/context_decorator.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py @@ -15,13 +15,14 @@ from functools import wraps from typing import Optional, Type, Union -import cuequivariance as cue -import cuequivariance.irreps_array as irreps_array -from cuequivariance.irreps_array.context_irrep_class import ( +import cuequivariance as cue # noqa: F401 +import cuequivariance.group_theory.irreps_array as irreps_array +from cuequivariance.group_theory import Irrep +from cuequivariance.group_theory.irreps_array.context_irrep_class import ( pop_irrep_scope, push_irrep_scope, ) -from cuequivariance.irreps_array.context_layout import ( +from cuequivariance.group_theory.irreps_array.context_layout import ( pop_layout_scope, push_layout_scope, ) @@ -51,7 +52,7 @@ class assume: def __init__( self, - irrep_class: Optional[Union[str, Type[cue.Irrep]]] = None, + irrep_class: Optional[Union[str, Type[Irrep]]] = None, layout: Optional[irreps_array.IrrepsLayout] = None, ): if isinstance(irrep_class, irreps_array.IrrepsLayout) and layout is None: diff --git a/cuequivariance/cuequivariance/irreps_array/context_irrep_class.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py similarity index 88% rename from cuequivariance/cuequivariance/irreps_array/context_irrep_class.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py index 84b4c370..8be62c8c 100644 --- a/cuequivariance/cuequivariance/irreps_array/context_irrep_class.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py @@ -15,11 +15,12 @@ from typing import Type, Union import cuequivariance as cue +from cuequivariance.group_theory.representations import Irrep -_irrep_class: Union[None, str, Type[cue.Irrep]] = None +_irrep_class: Union[None, str, Type[Irrep]] = None -def get_irrep_scope(raising: bool = True) -> Type[cue.Irrep]: +def get_irrep_scope(raising: bool = True) -> Type[Irrep]: if raising and _irrep_class is None: raise ValueError( "No irrep class set in the context. Please specify the irrep class explicitly or use ``with cue.assume(irrep):``." diff --git a/cuequivariance/cuequivariance/irreps_array/context_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py similarity index 94% rename from cuequivariance/cuequivariance/irreps_array/context_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py index 766c5df6..62a33246 100644 --- a/cuequivariance/cuequivariance/irreps_array/context_layout.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py @@ -14,7 +14,7 @@ # limitations under the License. 
from typing import Union -from cuequivariance.irreps_array import IrrepsLayout +from cuequivariance.group_theory.irreps_array import IrrepsLayout _layout: Union[None, IrrepsLayout] = None diff --git a/cuequivariance/cuequivariance/irreps_array/irrep_utils.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py similarity index 84% rename from cuequivariance/cuequivariance/irreps_array/irrep_utils.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py index 3ecb66a6..ec119d90 100644 --- a/cuequivariance/cuequivariance/irreps_array/irrep_utils.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py @@ -15,21 +15,21 @@ from typing import Iterable, Type, Union import cuequivariance as cue -from cuequivariance import irreps_array +from cuequivariance.group_theory import Irrep, irreps_array def into_list_of_irrep( - irrep_class: Type[cue.Irrep], + irrep_class: Type[Irrep], input: Union[ str, - cue.Irrep, + Irrep, irreps_array.MulIrrep, - Iterable[Union[str, cue.Irrep, irreps_array.MulIrrep]], + Iterable[Union[str, Irrep, irreps_array.MulIrrep]], ], -) -> list[cue.Irrep]: +) -> list[Irrep]: if isinstance(input, str): return [rep for _, rep in cue.Irreps(irrep_class, input)] - if isinstance(input, cue.Irrep): + if isinstance(input, Irrep): return [input] if isinstance(input, irreps_array.MulIrrep): return [input.ir] @@ -41,7 +41,7 @@ def into_list_of_irrep( output = [] for rep in input: - if isinstance(rep, cue.Irrep): + if isinstance(rep, Irrep): output.append(rep) elif isinstance(rep, irreps_array.MulIrrep): output.append(rep.ir) diff --git a/cuequivariance/cuequivariance/irreps_array/irreps.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irreps.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps.py diff --git a/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py similarity index 98% rename from cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py index b92b7780..0abe4081 100644 --- a/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py @@ -19,10 +19,11 @@ import numpy as np import cuequivariance as cue +from cuequivariance.group_theory.representations import Rep @dataclass(init=False, frozen=True) -class IrrepsAndLayout(cue.Rep): +class IrrepsAndLayout(Rep): r""" A group representation (:class:`Rep`) made from the combination of :class:`Irreps` and :class:`IrrepsLayout` into a single object. 
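The relocated context helpers keep their public behavior. For illustration, a minimal sketch of how the `assume` scope and `IrrepsAndLayout` fit together; the group and irreps strings below are arbitrary examples:

import cuequivariance as cue

# Inside the scope, the irrep class (and optionally the layout) does not
# need to be repeated at every call site.
with cue.assume("SO3", cue.ir_mul):
    irreps = cue.Irreps("2x0 + 1x1")  # irrep class taken from the scope

# IrrepsAndLayout bundles irreps with an explicit data layout into a cue.Rep.
rep = cue.IrrepsAndLayout(irreps, cue.ir_mul)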
diff --git a/cuequivariance/cuequivariance/irreps_array/irreps_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_layout.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irreps_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps_layout.py diff --git a/cuequivariance/cuequivariance/irreps_array/misc_ui.py b/cuequivariance/cuequivariance/group_theory/irreps_array/misc_ui.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/misc_ui.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/misc_ui.py diff --git a/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py b/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py similarity index 99% rename from cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py index 2c51c629..ea1f34f3 100644 --- a/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py @@ -20,7 +20,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep # This class is inspired by https://github.com/e3nn/e3nn-jax/blob/245e17eb23deaccad9f2c9cfd40fe40515e3c074/e3nn_jax/_src/irreps_array.py diff --git a/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py b/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py similarity index 97% rename from cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py index 854ed16f..79147099 100644 --- a/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py @@ -21,9 +21,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance import irreps_array -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep -from cuequivariance.misc.linalg import ( +from cuequivariance.etc.linalg import ( basis_intersection, perm_compose, perm_inverse, @@ -31,6 +29,8 @@ round_to_sqrt_rational, sparsify_matrix, ) +from cuequivariance.group_theory import Irrep, irreps_array +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def reduced_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, **irreps_dict, ) -> irreps_array.NumpyIrrepsArray: @@ -150,7 +150,7 @@ def reduced_symmetric_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, ) -> irreps_array.NumpyIrrepsArray: r"""Reduce a symmetric tensor product, usually called for a single irrep. 
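Since `keep_ir` now accepts plain `Irrep` instances from the relocated representations module, a call can look like the following minimal sketch; the irreps and degree are chosen arbitrarily, and it assumes SO3 irreps construct as `cue.SO3(l)`:

import cuequivariance as cue

# Basis of the degree-2 symmetric power of "0 + 1", keeping two output irreps.
irreps = cue.Irreps("SO3", "0 + 1")
basis = cue.reduced_symmetric_tensor_product_basis(
    irreps, 2, keep_ir=[cue.SO3(0), cue.SO3(2)]
)
print(basis.irreps)  # the resulting NumpyIrrepsArray carries its output irreps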
@@ -197,7 +197,7 @@ def reduced_antisymmetric_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, ) -> irreps_array.NumpyIrrepsArray: r"""Reduce an antisymmetric tensor product. @@ -240,7 +240,7 @@ def reduced_antisymmetric_tensor_product_basis( def _entrypoint( irreps_tuple: Tuple[irreps_array.Irreps, ...], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - keep_ir: Optional[FrozenSet[cue.Irrep]], + keep_ir: Optional[FrozenSet[Irrep]], layout: irreps_array.IrrepsLayout, epsilon: float, _use_optimized_implementation: bool, @@ -276,7 +276,7 @@ def _entrypoint( def _main_cached_recursive( irreps_tuple: Tuple[irreps_array.Irreps, ...], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - keep_ir: Optional[FrozenSet[cue.Irrep]], + keep_ir: Optional[FrozenSet[Irrep]], epsilon: float, _use_optimized_implementation: bool, ) -> irreps_array.NumpyIrrepsArray: @@ -469,7 +469,7 @@ def _optimized_reduced_symmetric_tensor_product_basis( degree: int, *, epsilon: float, - keep_ir: Optional[List[cue.Irrep]] = None, + keep_ir: Optional[List[Irrep]] = None, ): r"""Reduce a symmetric tensor product. @@ -682,7 +682,7 @@ def germinate_perm_repr( def reduce_basis_product( basis1: irreps_array.NumpyIrrepsArray, basis2: irreps_array.NumpyIrrepsArray, - keep_ir: Optional[List[cue.Irrep]] = None, + keep_ir: Optional[List[Irrep]] = None, ) -> irreps_array.NumpyIrrepsArray: """Reduce the product of two basis.""" basis1 = basis1.regroup() @@ -692,7 +692,7 @@ def reduce_basis_product( layout = basis1.layout if keep_ir is not None: - assert all(isinstance(ir, cue.Irrep) for ir in keep_ir) + assert all(isinstance(ir, Irrep) for ir in keep_ir) keep_ir = frozenset(keep_ir) assert basis1.array.dtype == np.float64 @@ -710,7 +710,7 @@ def reduce_basis_product( if cache_key in _cache_reduce_basis_product: return _cache_reduce_basis_product[cache_key] - new_irreps: List[Tuple[int, cue.Irrep]] = [] + new_irreps: List[Tuple[int, Irrep]] = [] new_list = [] for (mul1, ir1), x1 in zip(basis1.irreps, basis1.segments): @@ -764,7 +764,7 @@ def constrain_rotation_basis_by_permutation_basis( (permutation_basis.shape[0], prod(permutation_basis.shape[1:])), ) # (free, dim) - new_irreps: List[Tuple[int, cue.Irrep]] = [] + new_irreps: List[Tuple[int, Irrep]] = [] new_list: List[np.ndarray] = [] for ir in sorted({ir for mul, ir in rotation_basis.irreps}): diff --git a/cuequivariance/cuequivariance/representation/__init__.py b/cuequivariance/cuequivariance/group_theory/representations/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/representation/__init__.py rename to cuequivariance/cuequivariance/group_theory/representations/__init__.py diff --git a/cuequivariance/cuequivariance/representation/irrep.py b/cuequivariance/cuequivariance/group_theory/representations/irrep.py similarity index 99% rename from cuequivariance/cuequivariance/representation/irrep.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep.py index e4338596..99cbd940 100644 --- a/cuequivariance/cuequivariance/representation/irrep.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep.py @@ -22,7 +22,7 @@ import numpy as np import cuequivariance as cue # noqa: F401 -from cuequivariance.representation import Rep +from cuequivariance.group_theory.representations import Rep 
# This class is inspired from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irrep.py diff --git a/cuequivariance/cuequivariance/representation/irrep_o3.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py similarity index 98% rename from cuequivariance/cuequivariance/representation/irrep_o3.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py index 524aa56e..cf88f992 100644 --- a/cuequivariance/cuequivariance/representation/irrep_o3.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py @@ -21,7 +21,7 @@ import numpy as np -from cuequivariance.representation import SO3, Irrep +from cuequivariance.group_theory.representations import SO3, Irrep # This class is an adaptation of https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/o3_real.py diff --git a/cuequivariance/cuequivariance/representation/irrep_so3.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py similarity index 97% rename from cuequivariance/cuequivariance/representation/irrep_so3.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py index 10d47d61..e7fcadca 100644 --- a/cuequivariance/cuequivariance/representation/irrep_so3.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py @@ -21,8 +21,8 @@ import numpy as np -from cuequivariance.misc.linalg import round_to_sqrt_rational -from cuequivariance.representation import SU2, Irrep +from cuequivariance.etc.linalg import round_to_sqrt_rational +from cuequivariance.group_theory.representations import SU2, Irrep # This function is copied from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/so3_real.py diff --git a/cuequivariance/cuequivariance/representation/irrep_su2.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py similarity index 99% rename from cuequivariance/cuequivariance/representation/irrep_su2.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py index d378bbea..11e8dea6 100644 --- a/cuequivariance/cuequivariance/representation/irrep_su2.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py @@ -22,7 +22,7 @@ import numpy as np -from cuequivariance.representation import Irrep +from cuequivariance.group_theory.representations import Irrep # This class is an adaptation of https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/su2.py diff --git a/cuequivariance/cuequivariance/representation/rep.py b/cuequivariance/cuequivariance/group_theory/representations/rep.py similarity index 100% rename from cuequivariance/cuequivariance/representation/rep.py rename to cuequivariance/cuequivariance/group_theory/representations/rep.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py similarity index 79% rename from cuequivariance/cuequivariance/segmented_tensor_product/__init__.py rename to cuequivariance/cuequivariance/segmented_polynomials/__init__.py index 67519aeb..42cadb24 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,23 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. from .subscripts import Subscripts -from .operand import Operand from .path import Path +from .segmented_operand import SegmentedOperand from .segmented_tensor_product import SegmentedTensorProduct from .dot import dot, trace from .evaluate import compute_last_operand, primitive_compute_last_operand from .dispatch import dispatch +from .operation import Operation +from .segmented_polynomial import SegmentedPolynomial + __all__ = [ "Subscripts", - "Operand", "Path", + "SegmentedOperand", "SegmentedTensorProduct", "dot", "trace", "compute_last_operand", "primitive_compute_last_operand", "dispatch", + "Operation", + "SegmentedPolynomial", ] diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dimensions_dict.py b/cuequivariance/cuequivariance/segmented_polynomials/dimensions_dict.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/dimensions_dict.py rename to cuequivariance/cuequivariance/segmented_polynomials/dimensions_dict.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py similarity index 90% rename from cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py rename to cuequivariance/cuequivariance/segmented_polynomials/dispatch.py index ecdce832..40cf2ec3 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py @@ -16,14 +16,15 @@ import math from typing import Generator, Tuple -import cuequivariance.segmented_tensor_product as stp +# we cannot import cuequivariance as cue because of circular import +from cuequivariance.segmented_polynomials import SegmentedTensorProduct, Subscripts def dispatch( - descriptor: stp.SegmentedTensorProduct, - targets: list[stp.Subscripts], + descriptor: SegmentedTensorProduct, + targets: list[Subscripts], permutation_mode: str, -) -> Generator[Tuple[stp.SegmentedTensorProduct, Tuple[int, ...]], None, None]: +) -> Generator[Tuple[SegmentedTensorProduct, Tuple[int, ...]], None, None]: """Dispatch a descriptor to a target subscripts. Args: @@ -41,7 +42,7 @@ def dispatch( and dispatch the flattened descriptor to the target subscripts. The function will yield all the possible dispatches found. """ - targets = [stp.Subscripts(subscripts) for subscripts in targets] + targets = [Subscripts(subscripts) for subscripts in targets] targets = [ subscripts for subscripts in targets diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_polynomials/dot.py similarity index 85% rename from cuequivariance/cuequivariance/segmented_tensor_product/dot.py rename to cuequivariance/cuequivariance/segmented_polynomials/dot.py index e54d4b19..0d43669b 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/dot.py @@ -14,13 +14,13 @@ # limitations under the License. 
from __future__ import annotations -import copy +import dataclasses import itertools from typing import Any, Sequence import numpy as np -from cuequivariance import segmented_tensor_product as stp +import cuequivariance as cue def stable_unique(xs: Sequence[Any]) -> Sequence[Any]: @@ -34,10 +34,10 @@ def stable_unique(xs: Sequence[Any]) -> Sequence[Any]: def dot( - x: stp.SegmentedTensorProduct, - y: stp.SegmentedTensorProduct, + x: cue.SegmentedTensorProduct, + y: cue.SegmentedTensorProduct, *contraction: tuple[int, int], -) -> stp.SegmentedTensorProduct: +) -> cue.SegmentedTensorProduct: """ Compute the dot product of two segmented tensor products. @@ -50,7 +50,7 @@ def dot( The segmented tensor product resulting from the dot product. """ for oidx, oidy in contraction: - if x.operands[oidx] != y.operands[oidy]: + if x.operands_and_subscripts[oidx] != y.operands_and_subscripts[oidy]: raise ValueError("Operands to contract must be the same.") x_keep = [ @@ -60,13 +60,15 @@ def dot( oid for oid in range(y.num_operands) if all(oid != j for _, j in contraction) ] - d = stp.SegmentedTensorProduct( + d = cue.SegmentedTensorProduct( coefficient_subscripts=stable_unique( x.coefficient_subscripts + y.coefficient_subscripts ) ) - d.operands = copy.deepcopy( - [x.operands[i] for i in x_keep] + [y.operands[i] for i in y_keep] + d = dataclasses.replace( + d, + operands_and_subscripts=[x.operands_and_subscripts[i] for i in x_keep] + + [y.operands_and_subscripts[i] for i in y_keep], ) formula = f"{x.coefficient_subscripts} , {y.coefficient_subscripts} -> {d.coefficient_subscripts}" @@ -106,8 +108,8 @@ def dot( def trace( - d: stp.SegmentedTensorProduct, *contraction: tuple[int, int] -) -> stp.SegmentedTensorProduct: + d: cue.SegmentedTensorProduct, *contraction: tuple[int, int] +) -> cue.SegmentedTensorProduct: """ Compute the trace of a segmented tensor product. 
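With `operands_and_subscripts`, `dot` now requires contracted operands to agree in both segments and mode names. A minimal sketch, with arbitrary segment sizes:

import cuequivariance as cue
from cuequivariance.segmented_polynomials import dot

# A diagonal bilinear descriptor: two operands sharing the mode "u".
d = cue.SegmentedTensorProduct.from_subscripts("u,u")
d.add_segment(0, (3,))
d.add_segment(1, (3,))
d.add_path(0, 0, c=1.0)

# Contract operand 1 of the first descriptor with operand 0 of the second;
# the remaining operands form the resulting descriptor.
composed = dot(d, d, (1, 0))
print(composed.subscripts)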
@@ -130,21 +132,17 @@ def trace( mapping = { chj: chi for i, j in contraction - for chi, chj in zip(d.operands[i].subscripts, d.operands[j].subscripts) + for chi, chj in zip(d.subscripts.operands[i], d.subscripts.operands[j]) } f = lambda subscripts: "".join(mapping.get(ch, ch) for ch in subscripts) # noqa coefficients_subscripts_renamed = f(d.coefficient_subscripts) coefficients_subscripts_compressed = stable_unique(coefficients_subscripts_renamed) - dout = stp.SegmentedTensorProduct( + dout = cue.SegmentedTensorProduct( coefficient_subscripts=coefficients_subscripts_compressed, - operands=[ - stp.Operand( - subscripts=f(ope.subscripts), - segments=ope.segments, - ) - for ope in (d.operands[i] for i in keep) + operands_and_subscripts=[ + (ope, f(ss)) for ope, ss in (d.operands_and_subscripts[i] for i in keep) ], ) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py b/cuequivariance/cuequivariance/segmented_polynomials/evaluate.py similarity index 98% rename from cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py rename to cuequivariance/cuequivariance/segmented_polynomials/evaluate.py index 289a722d..631890ec 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/evaluate.py @@ -19,11 +19,11 @@ import numpy as np -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue def compute_last_operand( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, *inputs: np.ndarray, segment_axes: Union[int, list[int]] = -1, dtype: Optional[np.dtype] = None, diff --git a/cuequivariance/cuequivariance/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py similarity index 92% rename from cuequivariance/cuequivariance/operation.py rename to cuequivariance/cuequivariance/segmented_polynomials/operation.py index 429e92a4..f26c724e 100644 --- a/cuequivariance/cuequivariance/operation.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/operation.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import dataclasses import itertools from collections import defaultdict @@ -21,6 +22,7 @@ OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +@dataclasses.dataclass(init=False, frozen=True) class Operation: """Descriptor mapping input/output buffers to tensor product operands. @@ -45,17 +47,19 @@ class Operation: buffers: tuple[int, ...] - def __init__(self, buffers: tuple[int, ...]): + def __init__(self, buffers: tuple[int, ...] 
| Operation): + if isinstance(buffers, Operation): + buffers = buffers.buffers assert len(buffers) > 0, buffers assert all(isinstance(b, int) for b in buffers), buffers assert all(i >= 0 for i in buffers), buffers - self.buffers = tuple(int(b) for b in buffers) + object.__setattr__(self, "buffers", tuple(int(b) for b in buffers)) def __repr__(self): return f"Operation({self.buffers})" - def to_string(self, num_inputs: int) -> str: - return " ".join(IVARS[b] if b < num_inputs else OVARS[b] for b in self.buffers) + def to_letters(self, num_inputs: int) -> list[str]: + return [IVARS[b] if b < num_inputs else OVARS[b] for b in self.buffers] @staticmethod def list_to_string( @@ -65,12 +69,12 @@ def list_to_string( o = ", ".join(OVARS[num_inputs : num_inputs + num_outputs]) s = f"({i}) -> ({o})" for op in operations: - s += "\n " + op.to_string(num_inputs) + s += "\n " + " ".join(op.to_letters(num_inputs)) return s def __lt__(self, value): assert isinstance(value, Operation) - return self.buffers < value.buffers + return (len(self.buffers), self.buffers) < (len(value.buffers), value.buffers) def __hash__(self) -> int: return hash(self.buffers) @@ -82,6 +86,12 @@ def __eq__(self, value): def permute_operands(self, permutation: tuple[int, ...]) -> Operation: return Operation(tuple(self.buffers[p] for p in permutation)) + def move_operand_last(self, operand: int) -> Operation: + buffers = list(self.buffers) + b = buffers.pop(operand) + buffers.append(b) + return Operation(tuple(buffers)) + def input_operands_buffers(self, num_inputs: int) -> list[tuple[int, int]]: return [(op, i) for op, i in enumerate(self.buffers) if i < num_inputs] diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/path.py b/cuequivariance/cuequivariance/segmented_polynomials/path.py similarity index 90% rename from cuequivariance/cuequivariance/segmented_tensor_product/path.py rename to cuequivariance/cuequivariance/segmented_polynomials/path.py index be980265..8b59ea11 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/path.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/path.py @@ -39,7 +39,8 @@ class Path: def __init__(self, indices, coefficients): super().__setattr__("indices", tuple(int(i) for i in indices)) super().__setattr__( - "coefficients", np.asarray(coefficients, dtype=np.float64, order="C").copy() + "coefficients", + np.asarray(coefficients, dtype=np.float64, order="C", copy=True), ) def assert_valid(self): @@ -75,6 +76,13 @@ def __eq__(self, other: Path) -> bool: self.coefficients, other.coefficients ) + def __lt__(self, other: Path) -> bool: + k1 = (self.indices, self.coefficients.shape) + k2 = (other.indices, other.coefficients.shape) + if k1 != k2: + return k1 < k2 + return tuple(self.coefficients.flatten()) < tuple(other.coefficients.flatten()) + def permute_operands(self, perm: tuple[int, ...]) -> Path: """ Apply a permutation to the operands. diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py new file mode 100644 index 00000000..1a6644df --- /dev/null +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses +import math + +from cuequivariance.segmented_polynomials.dimensions_dict import format_dimensions_dict + + +@dataclasses.dataclass(init=False, frozen=True) +class SegmentedOperand: + """A segmented operand is a list of segment shapes.""" + + ndim: int + segments: tuple[tuple[int, ...], ...] + _dims: dict[int, set[int]] + + def __init__( + self, + segments: list[tuple[int, ...]] | None = None, + *, + ndim: int | None = None, + _dims: dict[int, set[int]] | None = None, + ): + if segments is None: + segments = [] + object.__setattr__(self, "segments", tuple(segments)) + + if ndim is None: + assert len(self.segments) > 0 + ndim = len(self.segments[0]) + object.__setattr__(self, "ndim", ndim) + + if _dims is None: + _dims = dict() + for segment in self.segments: + for i, d in enumerate(segment): + _dims.setdefault(i, set()).add(d) + else: + _dims = _dims.copy() + object.__setattr__(self, "_dims", _dims) + + @classmethod + def empty_segments(cls, num_segments: int) -> SegmentedOperand: + """Create an operand with ndim=0.""" + return cls(ndim=0, segments=[()] * num_segments, _dims=dict()) + + @classmethod + def stack(cls, operands: list[SegmentedOperand]) -> SegmentedOperand: + """Stack a list of operands together.""" + assert len(operands) > 0 + ndim = operands[0].ndim + assert all(ope.ndim == ndim for ope in operands) + + _dims = dict() + for ope in operands: + for i, d in ope.get_dimensions_dict().items(): + _dims.setdefault(i, set()).update(d) + + return cls( + ndim=ndim, + segments=sum([list(ope.segments) for ope in operands], []), + _dims=_dims, + ) + + def copy(self) -> SegmentedOperand: + """Copy the operand.""" + return SegmentedOperand( + ndim=self.ndim, + segments=self.segments, + _dims=self._dims, + ) + + def assert_valid(self): + """Assert that the operand is valid.""" + for segment in self.segments: + if len(segment) != self.ndim: + raise ValueError( + f"segment {segment} has {len(segment)} dimensions, expected {self.ndim}." + ) + + if not all(isinstance(dim, int) and dim > 0 for dim in segment): + raise ValueError(f"segment {segment} is not valid.") + + for i, d in enumerate(segment): + if d not in self.get_dims(i): + raise ValueError( + f"dimension {d} not in {i} dimensions {self.get_dims(i)}." + ) + + def insert_segment(self, index: int, segment: tuple[int, ...]): + """Insert a segment at a given index.""" + if len(segment) != self.ndim: + raise ValueError( + f"segment has {len(segment)} dimensions, expected {self.ndim}." + ) + + if index < 0: + index = len(self.segments) + index + + if index < 0 or index > len(self.segments): + raise ValueError( + f"index {index} is out of bounds for segments {self.segments}."
+ ) + + segment = tuple(int(d) for d in segment) + object.__setattr__( + self, + "segments", + self.segments[:index] + (segment,) + self.segments[index:], + ) + + for i, d in enumerate(segment): + self._dims.setdefault(i, set()).add(d) + + def add_segment(self, segment: tuple[int, ...]) -> int: + """Add a segment to the operand.""" + self.insert_segment(len(self.segments), segment) + return len(self.segments) - 1 + + def __hash__(self) -> int: + return hash((self.ndim, self.segments)) + + def __eq__(self, other: SegmentedOperand) -> bool: + assert isinstance(other, SegmentedOperand) + return self.ndim == other.ndim and self.segments == other.segments + + def __lt__(self, other: SegmentedOperand) -> bool: + assert isinstance(other, SegmentedOperand) + return (self.ndim, self.segments) < (other.ndim, other.segments) + + def __repr__(self) -> str: + dims = format_dimensions_dict(self.get_dimensions_dict()) + return f"Operand(ndim={self.ndim} num_segments={self.num_segments} dims={dims})" + + def __getitem__(self, index: int) -> tuple[int, ...]: + return self.segments[index] + + def __len__(self) -> int: + return self.num_segments + + def __iter__(self): + return iter(self.segments) + + @property + def num_segments(self) -> int: + """The number of segments in the operand.""" + return len(self.segments) + + @property + def size(self) -> int: + """The total size of the operand.""" + if self.all_same_segment_shape(): + return self.num_segments * self.segment_size + + return sum(math.prod(segment) for segment in self.segments) + + def segment_slices(self) -> list[slice]: + """Return slice object for each segment.""" + offset = 0 + slices = [] + for segment in self.segments: + slices.append(slice(offset, offset + math.prod(segment))) + offset += math.prod(segment) + return slices + + def get_dimensions_dict(self) -> dict[int, set[int]]: + """Return a dictionary of dimensions for each channel.""" + return self._dims.copy() + + def get_dims(self, i: int) -> set[int]: + """Return the dimensions for a given channel.""" + return self._dims.get(i, set()).copy() + + def all_same_segment_shape(self) -> bool: + """Check if all segments have the same shape. Returns False if there are no segments.""" + return all(len(dd) == 1 for dd in self._dims.values()) and self.num_segments > 0 + + @property + def segment_shape(self) -> tuple[int, ...]: + """The shape of the segments if they are all the same.""" + if not self.all_same_segment_shape(): + raise ValueError("Segments do not have the same shape.") + return self.segments[0] + + @property + def segment_size(self) -> int: + """The size of the segments if they are all the same.""" + if not self.all_same_segment_shape(): + raise ValueError("Segments do not have the same shape.") + return math.prod(self.segments[0]) + + def __add__(self, other: SegmentedOperand) -> SegmentedOperand: + if self.ndim != other.ndim: + raise ValueError("ndim do not match.") + return SegmentedOperand( + ndim=self.ndim, + segments=self.segments + other.segments, + _dims={i: self.get_dims(i) | other.get_dims(i) for i in range(self.ndim)}, + ) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py new file mode 100644 index 00000000..c876a62a --- /dev/null +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -0,0 +1,576 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import copy +import dataclasses +import itertools +from typing import Callable, Sequence + +import numpy as np + +import cuequivariance as cue +from cuequivariance.segmented_polynomials.operation import IVARS, OVARS + +from .dimensions_dict import format_dimensions_dict + + +@dataclasses.dataclass(init=False, frozen=True) +class SegmentedPolynomial: + """A polynomial representation using segmented tensor products. + + This class represents a polynomial using a collection of segmented tensor products, where each product + is associated with an operation that specifies how inputs are combined. The polynomial maps a set of + input tensors to output tensors through these tensor products. + + Args: + inputs (tuple of SegmentedOperand): Input operands. + outputs (tuple of SegmentedOperand): Output operands. + operations (list of tuple of Operation and SegmentedTensorProduct): List of operation and tensor product pairs + that define the polynomial transformation. + """ + + inputs: tuple[cue.SegmentedOperand, ...] + outputs: tuple[cue.SegmentedOperand, ...] + operations: tuple[tuple[cue.Operation, cue.SegmentedTensorProduct], ...] + + def __init__( + self, + inputs: Sequence[cue.SegmentedOperand], + outputs: Sequence[cue.SegmentedOperand], + operations: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], + ): + inputs = tuple(inputs) + outputs = tuple(outputs) + operands = inputs + outputs + + tmp = [] + for opt, stp in operations: + opt = cue.Operation(opt) + assert isinstance(opt, cue.Operation) + assert isinstance(stp, cue.SegmentedTensorProduct) + assert len(opt.buffers) == stp.num_operands + for buffer_id, operand in zip(opt.buffers, stp.operands): + assert operand == operands[buffer_id] + + out_oid, bid = opt.output_operand_buffer(len(inputs)) + tmp.append( + (bid, opt.move_operand_last(out_oid), stp.move_operand_last(out_oid)) + ) + tmp = sorted(tmp) + operations = [(opt, stp) for _, opt, stp in tmp] + + object.__setattr__(self, "inputs", inputs) + object.__setattr__(self, "outputs", outputs) + object.__setattr__(self, "operations", tuple(operations)) + + @classmethod + def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): + return cls( + stp.operands[:-1], + (stp.operands[-1],), + ((cue.Operation(tuple(range(stp.num_operands))), stp),), + ) + + @classmethod + def from_default_buffers( + cls, + inputs: Sequence[cue.SegmentedOperand | None], + outputs: Sequence[cue.SegmentedOperand | None], + tensor_products: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], + ): + buffers = list(inputs) + list(outputs) + for ope, stp in tensor_products: + ope = cue.Operation(ope) + assert isinstance(stp, cue.SegmentedTensorProduct) + assert len(ope.buffers) == stp.num_operands + for buffer_id, operand in zip(ope.buffers, stp.operands): + buffers[buffer_id] = operand + + return cls(buffers[: len(inputs)],
buffers[len(inputs) :], tensor_products) + + @classmethod + def from_stps( + cls, + inputs: Sequence[cue.SegmentedOperand | None], + outputs: Sequence[cue.SegmentedOperand | None], + tensor_products: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], + ) -> SegmentedPolynomial: + """Stack segmented tensor products together.""" + inputs, outputs = list(inputs), list(outputs) + return cls.stack( + [ + cls.from_default_buffers(inputs, outputs, [(ope, stp)]) + for ope, stp in tensor_products + ], + [ope is None for ope in inputs + outputs], + ) + + def __hash__(self) -> int: + return hash((self.inputs, self.outputs, self.operations)) + + def __eq__(self, value) -> bool: + assert isinstance(value, SegmentedPolynomial) + return ( + self.inputs == value.inputs + and self.outputs == value.outputs + and self.operations == value.operations + ) + + def __lt__(self, value) -> bool: + assert isinstance(value, SegmentedPolynomial) + return ( + self.inputs, + self.outputs, + self.operations, + ) < ( + value.inputs, + value.outputs, + value.operations, + ) + + def __mul__(self, factor: float) -> SegmentedPolynomial: + return SegmentedPolynomial( + self.inputs, + self.outputs, + tuple((ope, factor * stp) for ope, stp in self.operations), + ) + + def __rmul__(self, factor: float) -> SegmentedPolynomial: + return self.__mul__(factor) + + def __repr__(self): + def sfmt(shape: tuple[int, ...]) -> str: + return "(" + ",".join(str(d) for d in shape) + ")" + + buffer_names = [] + for ope in self.operands: + if ope.all_same_segment_shape(): + buffer_names.append( + f"[{ope.size}:{ope.num_segments}⨯{sfmt(ope.segment_shape)}]" + ) + else: + txts = [] + n = 20 + for s in ope.segments: + txts.append(sfmt(s)) + if len("+".join(txts)) > n: + txts.pop() + break + if len(txts) < len(ope.segments): + txts.append("...") + buffer_names.append(f"[{ope.size}:{'+'.join(txts)}]") + return self.to_string(buffer_names) + + def to_string(self, buffer_names: list[str] | None = None) -> str: + buffer_txts = ( + IVARS[: self.num_inputs] + + OVARS[self.num_inputs : self.num_inputs + self.num_outputs] + ) + if buffer_names is not None: + buffer_txts = [ + f"{symbol}={name}" for symbol, name in zip(buffer_txts, buffer_names) + ] + + header = ( + " ".join(buffer_txts[: self.num_inputs]) + + " -> " + + " ".join(buffer_txts[self.num_inputs :]) + ) + + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: + items = [ + f"{buffer}[{ss}]" + for buffer, ss in zip( + ope.to_letters(self.num_inputs), stp.subscripts.operands + ) + ] + out = items[-1] + items = [f"[{stp.coefficient_subscripts}]"] + items[:-1] + return "·".join(items) + "➜" + out + + lines = ["│ " + f(ope, stp) for ope, stp in self.operations] + if len(lines) > 0: + lines[-1] = "╰─" + lines[-1][2:] + + n = max(len(line) for line in lines) + lines = [ + line + + " " + + "─" * (n - len(line)) + + "─ " + + f"num_paths={stp.num_paths} {format_dimensions_dict(stp.get_dimensions_dict())}" + for line, (_, stp) in zip(lines, self.operations) + ] + lines = ["╭ " + header] + lines + + lines = [line.rstrip() for line in lines] + return "\n".join(lines) + + def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: + inferred_shape = np.broadcast_shapes(*[x.shape[:-1] for x in inputs]) + inferred_dtype = np.result_type(*[x.dtype for x in inputs]) + outputs = [ + np.zeros(inferred_shape + (ope.size,), dtype=inferred_dtype) + for ope in self.outputs + ] + for ope, stp in self.operations: + oid, bid = ope.output_operand_buffer(self.num_inputs) + 
outputs[bid - self.num_inputs] += ( + cue.segmented_polynomials.compute_last_operand( + stp.move_operand_last(oid), + *[inputs[bid] for bid in ope.input_buffers(self.num_inputs)], + dtype=inferred_dtype, + ) + ) + return outputs + + @property + def operands(self) -> tuple[cue.SegmentedOperand, ...]: + return self.inputs + self.outputs + + @property + def num_inputs(self) -> int: + return len(self.inputs) + + @property + def num_outputs(self) -> int: + return len(self.outputs) + + @property + def num_operands(self) -> int: + """Number of operands in the polynomial.""" + return self.num_inputs + self.num_outputs + + def map_tensor_products( + self, + f: Callable[ + [cue.Operation, cue.SegmentedTensorProduct], + tuple[cue.Operation, cue.SegmentedTensorProduct] | None, + ], + ) -> SegmentedPolynomial: + new_tensor_products = [f(ope, stp) for ope, stp in self.operations] + new_tensor_products = tuple( + ope_stp for ope_stp in new_tensor_products if ope_stp is not None + ) + return SegmentedPolynomial.from_default_buffers( + self.inputs, self.outputs, new_tensor_products + ) + + def fuse_stps(self) -> SegmentedPolynomial: + """Fuse segmented tensor products with identical operations and operands.""" + poly = self.map_tensor_products( + lambda ope, stp: (ope, stp.canonicalize_subscripts()) + ) + + groups = itertools.groupby( + poly.operations, + key=lambda x: ( + x[0], + x[1].operands_and_subscripts, + x[1].coefficient_subscripts, + ), + ) + new_tensor_products = tuple( + ( + ope, + cue.SegmentedTensorProduct( + operands_and_subscripts=operands_and_subscripts, + coefficient_subscripts=coefficient_subscripts, + paths=[path for _, stp in elements for path in stp.paths], + ).consolidate_paths(), + ) + for ( + ope, + operands_and_subscripts, + coefficient_subscripts, + ), elements in groups + ) + return SegmentedPolynomial(self.inputs, self.outputs, new_tensor_products) + + def consolidate(self) -> SegmentedPolynomial: + """Consolidate the segmented tensor products.""" + + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): + stp = ( + stp.consolidate_modes() + .squeeze_modes() + .remove_empty_segments() + .consolidate_paths() + ) + if stp.num_paths == 0: + return None + return ope, stp + + return self.fuse_stps().map_tensor_products(f) + + def used_inputs(self) -> list[bool]: + """Inputs used in the polynomial. (List of boolean values)""" + return [ + any(buffer in ope.buffers for ope, _ in self.operations) + for buffer in range(self.num_inputs) + ] + + def used_outputs(self) -> list[bool]: + """Outputs used in the polynomial. (List of boolean values)""" + return [ + any(buffer in ope.buffers for ope, _ in self.operations) + for buffer in range(self.num_inputs, self.num_inputs + self.num_outputs) + ] + + def used_buffers(self) -> list[bool]: + """Buffers used in the polynomial. 
(List of boolean values)""" + return self.used_inputs() + self.used_outputs() + + def select_buffers(self, keep: list[bool]) -> SegmentedPolynomial: + """Select the buffers of the polynomial.""" + assert len(keep) == self.num_operands + + # Create a mapping from old buffer indices to new buffer indices + new_index = [] + i = 0 + for u in keep: + if u: + new_index.append(i) + i += 1 + else: + new_index.append(None) + + # Filter tensor products that write to buffers we want to keep + # and remap the buffer indices + new_tensor_products = [] + for ope, stp in self.operations: + # Check if the operation writes to a buffer we want to keep + bid = ope.output_buffer(self.num_inputs) + if keep[bid]: + # Check if all input buffers needed by this operation are kept + if not all(keep[buffer] for buffer in ope.buffers): + raise ValueError( + f"Operation {ope} writes to buffer {bid} which is kept, but requires input buffers that are being dropped" + ) + + # Remap buffer indices + new_ope = cue.Operation([new_index[buffer] for buffer in ope.buffers]) + new_tensor_products.append((new_ope, stp)) + + return SegmentedPolynomial( + [x for x, k in zip(self.inputs, keep[: self.num_inputs]) if k], + [x for x, k in zip(self.outputs, keep[self.num_inputs :]) if k], + new_tensor_products, + ) + + def select_outputs(self, keep: list[bool]) -> SegmentedPolynomial: + """Select the outputs of the polynomial.""" + assert len(keep) == self.num_outputs + return self.select_buffers([True] * self.num_inputs + keep) + + def remove_unused_buffers(self) -> SegmentedPolynomial: + """Remove unused buffers from the polynomial.""" + return self.select_buffers(self.used_buffers()) + + def compute_only(self, keep: list[bool]) -> SegmentedPolynomial: + """Compute only the selected outputs of the polynomial.""" + assert len(keep) == self.num_outputs + return SegmentedPolynomial( + self.inputs, + self.outputs, # on purpose, we keep all outputs + [ + (ope, stp) + for ope, stp in self.operations + if keep[ope.output_buffer(self.num_inputs) - self.num_inputs] + ], + ) + + @classmethod + def stack( + cls, polys: list[SegmentedPolynomial], stacked: list[bool] + ) -> SegmentedPolynomial: + """Stack segmented polynomials together.""" + assert len(polys) > 0 + num_inputs = polys[0].num_inputs + num_outputs = polys[0].num_outputs + assert all(pol.num_inputs == num_inputs for pol in polys) + assert all(pol.num_outputs == num_outputs for pol in polys) + assert len(stacked) == num_inputs + num_outputs + + operands = [] + for bid in range(num_inputs + num_outputs): + if stacked[bid]: + operands.append( + cue.SegmentedOperand.stack( + [ + pol.operands[bid] + for pol in polys + if pol.operands[bid] + is not None # special case for .from_stps + ] + ) + ) + else: + ope = polys[0].operands[bid] + assert all(pol.operands[bid] == ope for pol in polys) + operands.append(ope) + + tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] + for index, pol in enumerate(polys): + for ope, stp in pol.operations: + stp = copy.deepcopy(stp) + for oid, buffer in enumerate(ope.buffers): + if stacked[buffer]: + for p in reversed(polys[:index]): + stp.insert_segments(oid, 0, p.buffer_segments(buffer)) + for p in polys[index + 1 :]: + stp.insert_segments(oid, -1, p.buffer_segments(buffer)) + tensor_products.append((ope, stp)) + + return cls( + operands[:num_inputs], operands[num_inputs:], tensor_products + ).consolidate() + + def squeeze_modes(self) -> SegmentedPolynomial: + """Squeeze the modes of the segmented tensor products.""" + return 
SegmentedPolynomial.from_default_buffers( + self.inputs, + self.outputs, + [(ope, stp.squeeze_modes()) for ope, stp in self.operations], + ) + + def flatten_coefficient_modes(self) -> SegmentedPolynomial: + """Flatten the coefficient modes of the segmented tensor products.""" + return SegmentedPolynomial.from_default_buffers( + self.inputs, + self.outputs, + [(ope, stp.flatten_coefficient_modes()) for ope, stp in self.operations], + ) + + def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: + """Compute the Jacobian-vector product of the polynomial.""" + assert len(has_tangent) == self.num_inputs + + # Symmetrizing the polynomial helps identify simplifications by group_by_operational_symmetries + sym_poly = self.symmetrize_for_identical_operands() + + new_tps = [] + for ope, stp in sym_poly.operations: + jvps = ope.jvp(has_tangent) + permutations: list[tuple[int, ...]] = stp.symmetries() + for multiplicator, ope in cue.Operation.group_by_operational_symmetries( + permutations, jvps + ): + new_tps.append((ope, multiplicator * stp)) + return SegmentedPolynomial( + list(self.inputs) + [x for has, x in zip(has_tangent, self.inputs) if has], + self.outputs, + new_tps, + ) + + def transpose( + self, + is_undefined_primal: list[bool], + has_cotangent: list[bool], + ) -> SegmentedPolynomial: + """Transpose the polynomial.""" + assert len(is_undefined_primal) == self.num_inputs + assert len(has_cotangent) == self.num_outputs + + new_tps = [] + for ope, stp in self.operations: + ope = ope.transpose(is_undefined_primal, has_cotangent) + if ope is not None: + new_tps.append((ope, stp)) + return SegmentedPolynomial( + # defined inputs + [x for undef, x in zip(is_undefined_primal, self.inputs) if not undef] + # cotangent outputs + + [x for has, x in zip(has_cotangent, self.outputs) if has], + # undefined inputs + [x for undef, x in zip(is_undefined_primal, self.inputs) if undef], + new_tps, + ) + + def backward( + self, requires_gradient: list[bool], has_cotangent: list[bool] + ) -> SegmentedPolynomial: + """Compute the backward pass of the polynomial.""" + return self.jvp(requires_gradient).transpose( + is_undefined_primal=[False] * self.num_inputs + + [True] * sum(requires_gradient), + has_cotangent=has_cotangent, + ) + + def flops(self, batch_size: int = 1) -> int: + """Compute the number of floating point operations in the polynomial.""" + n = 0 + for ope, stp in self.operations: + oid, _ = ope.output_operand_buffer(self.num_inputs) + n += stp.flops(oid) + return batch_size * n + + def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the polynomial.""" + assert len(batch_sizes) == self.num_operands + return sum(Z * ope.size for Z, ope in zip(batch_sizes, self.operands)) + + def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: + segments = None + for ope, stp in self.operations: + if buffer in ope.buffers: + ope = stp.operands[ope.buffers.index(buffer)] + if segments is None: + segments = ope.segments + elif segments != ope.segments: + raise ValueError( + f"Buffer {buffer} has inconsistent segments: {segments} vs {ope.segments}" + ) + if segments is None: + raise ValueError(f"Buffer {buffer} is not used") + return segments + + def symmetrize_for_identical_operands(self) -> SegmentedPolynomial: + """Symmetrize the paths of the segmented tensor products for identical operands. + + This operation increases the number of paths in the segmented tensor products. 
+ """ + + symmetrized_tensor_products = [] + for ope, stp in self.operations: + for set_of_operands in ope.operands_with_identical_buffers(): + stp = stp.symmetrize_operands(set_of_operands) + stp = stp.sort_paths() + symmetrized_tensor_products.append((ope, stp)) + + return SegmentedPolynomial( + self.inputs, self.outputs, symmetrized_tensor_products + ) + + def unsymmetrize_for_identical_operands(self) -> SegmentedPolynomial: + """Unsymmetrize the paths of the segmented tensor products for identical operands. + + This operation decreases the number of paths in the segmented tensor products. + """ + + def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): + for set_of_operands in ope.operands_with_identical_buffers(): + stp = stp.sort_indices_for_identical_operands(set_of_operands) + stp = stp.sort_paths() + return ope, stp + + return self.map_tensor_products(optimize_paths) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py similarity index 78% rename from cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py rename to cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index 2eb0af2c..7682c239 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -31,15 +31,19 @@ import opt_einsum import cuequivariance as cue # noqa: F401 -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.misc.linalg import round_to_rational, round_to_sqrt_rational +from cuequivariance.etc.linalg import round_to_rational, round_to_sqrt_rational +from cuequivariance.etc.permutations import ( + generate_permutations_from, + inverse_permutation, +) +from cuequivariance.segmented_polynomials import Path, Subscripts from .dimensions_dict import format_dimensions_dict logger = logging.getLogger(__name__) -@dataclasses.dataclass(init=False, frozen=False) +@dataclasses.dataclass(init=False, frozen=True) class SegmentedTensorProduct: """ Irreps-agnostic and dataflow-agnostic descriptor of a segmented tensor product @@ -56,30 +60,66 @@ class SegmentedTensorProduct: .. rubric:: Methods """ - operands: list[stp.Operand] - paths: list[stp.Path] + operands_and_subscripts: tuple[tuple[cue.SegmentedOperand, str], ...] coefficient_subscripts: str + paths: tuple[Path, ...] 
################################ Initializers ################################ + # From here we can use object.__setattr__ to modify the attributes def __init__( self, - *, - operands: Optional[list[stp.Operand]] = None, - paths: Optional[list[stp.Path]] = None, + operands_and_subscripts: Sequence[tuple[cue.SegmentedOperand | None, str]] + | None = None, coefficient_subscripts: str = "", + *, + paths: Sequence[Path] | None = None, ): - if operands is None: - operands = [] - self.operands = operands - + if operands_and_subscripts is None: + operands_and_subscripts = [] if paths is None: paths = [] - self.paths = paths - self.coefficient_subscripts = coefficient_subscripts + + operands_and_subscripts = tuple( + ( + ope.copy() if ope is not None else cue.SegmentedOperand(ndim=len(ss)), + str(ss), + ) + for ope, ss in operands_and_subscripts + ) + + object.__setattr__(self, "operands_and_subscripts", operands_and_subscripts) + object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) + object.__setattr__(self, "paths", tuple(paths)) + + def set_operand(self, oid: int, operand: cue.SegmentedOperand): + assert oid < len(self.operands_and_subscripts) + object.__setattr__( + self, + "operands_and_subscripts", + self.operands_and_subscripts[:oid] + + ((operand.copy(), self.operands_and_subscripts[oid][1]),) + + self.operands_and_subscripts[oid + 1 :], + ) + + def set_paths(self, paths: list[Path]): + object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) + + def insert_path_(self, path_index: int, path: Path): + object.__setattr__( + self, + "paths", + self.paths[:path_index] + (copy.deepcopy(path),) + self.paths[path_index:], + ) + + # until here. Below we use dataclasses.replace or the setters to modify the attributes + + @property + def operands(self) -> tuple[cue.SegmentedOperand, ...]: + return tuple(ope for ope, _ in self.operands_and_subscripts) def assert_valid(self): - assert stp.Subscripts.is_valid(self.subscripts) + assert Subscripts.is_valid(self.subscripts) for m in self.subscripts.modes(): if self.subscripts.count(m) == 1: @@ -112,7 +152,7 @@ def assert_valid(self): ) @classmethod - def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct: + def from_subscripts(cls, subscripts: Subscripts) -> SegmentedTensorProduct: r""" Create a descriptor from a subscripts string. 
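Judging from the descriptors above (escn, symmetric_contraction), the imperative builders `add_segment` and `add_path` keep working on the now-frozen dataclass through these setters. A sketch, with arbitrary sizes, of building a descriptor and wrapping it with the new `SegmentedPolynomial.eval_last_operand`:

import numpy as np
import cuequivariance as cue

d = cue.SegmentedTensorProduct.from_subscripts("uv,u,v")
d.add_segment(0, (2, 3))
d.add_segment(1, (2,))
d.add_segment(2, (3,))
d.add_path(0, 0, 0, c=1.0)

# Wrap the descriptor into a polynomial whose last operand is the output,
# then evaluate it with NumPy via SegmentedPolynomial.__call__.
poly = cue.SegmentedPolynomial.eval_last_operand(d)
w = np.random.randn(d.operands[0].size)  # size 6
x = np.random.randn(d.operands[1].size)  # size 2
[y] = poly(w, x)  # single output of size d.operands[2].size == 3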
@@ -126,11 +166,15 @@ def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct:
         >>> print(d)
         uv,ui,vj+ij operands=[(2, 3)],[(2, 5)],[(3, 4)] paths=[op0[0]*op1[0]*op2[0]*c c.shape=(5, 4) c.nnz=20]
         """
-        subscripts = stp.Subscripts(subscripts)
-        operands = [stp.Operand(subscripts=operand) for operand in subscripts.operands]
+        subscripts = Subscripts(subscripts)
+        operands = [
+            cue.SegmentedOperand(ndim=len(operand)) for operand in subscripts.operands
+        ]
         return cls(
-            operands=operands, paths=[], coefficient_subscripts=subscripts.coefficients
+            operands_and_subscripts=list(zip(operands, subscripts.operands)),
+            paths=[],
+            coefficient_subscripts=subscripts.coefficients,
         )

     @classmethod
@@ -143,7 +187,9 @@ def empty_segments(cls, num_segments: list[int]) -> SegmentedTensorProduct:
         ,, sizes=2,3,4 num_segments=2,3,4 num_paths=0
         """
         return cls(
-            operands=[stp.Operand.empty_segments(num) for num in num_segments],
+            operands_and_subscripts=[
+                (cue.SegmentedOperand.empty_segments(num), "") for num in num_segments
+            ],
             paths=[],
             coefficient_subscripts="",
         )
@@ -179,7 +225,33 @@ def from_base64(cls, data: str) -> SegmentedTensorProduct:

     def __hash__(self) -> int:
         return hash(
-            (tuple(self.operands), tuple(self.paths), self.coefficient_subscripts)
+            (self.operands_and_subscripts, self.paths, self.coefficient_subscripts)
         )
+
+    def __eq__(self, value: SegmentedTensorProduct) -> bool:
+        assert isinstance(value, SegmentedTensorProduct)
+        return (
+            self.operands_and_subscripts == value.operands_and_subscripts
+            and self.paths == value.paths
+            and self.coefficient_subscripts == value.coefficient_subscripts
+        )
+
+    def __lt__(self, value: SegmentedTensorProduct) -> bool:
+        assert isinstance(value, SegmentedTensorProduct)
+        return (
+            self.num_operands,
+            self.num_paths,
+            self.subscripts,
+            self.operands,
+            self.paths,
+            self.coefficient_subscripts,
+        ) < (
+            value.num_operands,
+            value.num_paths,
+            value.subscripts,
+            value.operands,
+            value.paths,
+            value.coefficient_subscripts,
+        )

     def __repr__(self) -> str:
@@ -188,7 +260,7 @@ def __repr__(self) -> str:
                 "[" + ",".join(map(str, operand.segments)) + "]"
                 for operand in self.operands
             )
-            return f"{self.subscripts} operands={operands} paths={self.paths}"
+            return f"{self.subscripts} operands={operands} paths={list(self.paths)}"
         sizes = ",".join(f"{operand.size}" for operand in self.operands)
         num_segments = ",".join(f"{len(operand)}" for operand in self.operands)
         output = f"{self.subscripts} sizes={sizes} num_segments={num_segments} num_paths={self.num_paths}"
@@ -200,7 +272,7 @@
     @property
     def num_operands(self) -> int:
         """Number of operands."""
-        return len(self.operands)
+        return len(self.operands_and_subscripts)

     @property
     def num_paths(self) -> int:
@@ -208,10 +280,10 @@
         return len(self.paths)

     @property
-    def subscripts(self) -> stp.Subscripts:
+    def subscripts(self) -> Subscripts:
         """Subscripts of the tensor product."""
-        return stp.Subscripts.from_operands(
-            [operand.subscripts for operand in self.operands],
+        return Subscripts.from_operands(
+            [subscripts for _, subscripts in self.operands_and_subscripts],
             self.coefficient_subscripts,
         )
@@ -262,7 +334,7 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str:
         ...     cue.Irreps("SO3", "4x0+4x1"),
         ...     cue.Irreps("SO3", "4x0+4x1"),
         ...     cue.Irreps("SO3", "4x0+4x1")
-        ...     ).d
+        ...     ).polynomial.operations[0][1]
         >>> d = d.flatten_coefficient_modes()
         >>> print(d.to_text())
         uvw,u,v,w sizes=320,16,16,16 num_segments=5,4,4,4 num_paths=16 u=4 v=4 w=4
@@ -283,9 +355,9 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str:
         """
         out = f"{self}"
         dims = self.get_dimensions_dict()
-        for oid, operand in enumerate(self.operands):
-            out += f"\noperand #{oid} subscripts={operand.subscripts}"
-            for i, ch in enumerate(operand.subscripts):
+        for oid, (operand, subscripts) in enumerate(self.operands_and_subscripts):
+            out += f"\noperand #{oid} subscripts={subscripts}"
+            for i, ch in enumerate(subscripts):
                 if len(dims[ch]) == 1:
                     out += f"\n | {ch}: [{operand.segments[0][i]}] * {len(operand.segments)}"
                 else:
@@ -295,8 +367,8 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str:
                         + "]"
                     )

-        out += f"\nFlop cost: {' '.join(f'{oid}->{self.flop_cost(oid)}' for oid in range(self.num_operands))}"
-        out += f"\nMemory cost: {self.memory_cost('global')}"
+        out += f"\nFlop cost: {' '.join(f'{oid}->{self.flops(oid)}' for oid in range(self.num_operands))}"
+        out += f"\nMemory cost: {self.memory([1] * self.num_operands)}"

         if len(self.paths) > 0:
             out += "\nPath indices: " + ", ".join(
@@ -343,21 +415,20 @@ def to_dict(self, extended: bool = False) -> dict[str, Any]:
             "coefficient_subscripts": self.coefficient_subscripts,
             "operands": [
                 {
-                    "subscripts": ope.subscripts,
+                    "subscripts": ss,
                     "segments": ope.segments,
                     "size": ope.size,
                     "segment_offsets": [sl.start for sl in slices],
                     "segment_sizes": [sl.stop - sl.start for sl in slices],
-                    "flop_cost": self.flop_cost(oid),
+                    "flops": self.flops(oid),
                 }
-                for oid, ope, slices in zip(
-                    range(self.num_operands), self.operands, segment_slices
+                for oid, (ope, ss), slices in zip(
+                    range(self.num_operands),
+                    self.operands_and_subscripts,
+                    segment_slices,
                 )
             ],
-            "memory_cost": {
-                algorithm: self.memory_cost(algorithm)
-                for algorithm in ["sequential", "global"]
-            },
+            "memory": self.memory([1] * self.num_operands),
             "paths": paths,
         }
         return extended_dict
@@ -379,18 +450,19 @@ def to_base64(self, extended: bool = False) -> str:
         ...     cue.Irreps("SO3", "4x0+4x1"),
         ...     cue.Irreps("SO3", "4x0+4x1"),
         ...     cue.Irreps("SO3", "4x0+4x1")
-        ...     ).d
+        ...     ).polynomial.operations[0][1]
         >>> print(d.to_base64())
         eJytkstuwjAQRX/F8r...lTF2zlX91/fHyvj2Z4=
         """
         return base64.b64encode(self.to_bytes(extended)).decode("ascii")

+    @functools.cache
     def get_dimensions_dict(self) -> dict[str, set[int]]:
         """Get the dimensions of the tensor product."""
         dims: dict[str, set[int]] = {ch: set() for ch in self.subscripts.modes()}
-        for operand in self.operands:
-            for m, dd in operand.get_dimensions_dict().items():
-                dims[m].update(dd)
+        for operand, subscripts in self.operands_and_subscripts:
+            for i, dd in operand.get_dimensions_dict().items():
+                dims[subscripts[i]].update(dd)
        # Note: no need to go through the coefficients since they must be contracted with the operands
         return dims
@@ -403,7 +475,7 @@ def get_dims(self, m: str) -> set[int]:
         ...     cue.Irreps("SO3", "4x0+8x1"),
         ...     cue.Irreps("SO3", "3x0+3x1"),
         ...     cue.Irreps("SO3", "5x0+7x1")
-        ...     ).d
+        ...     
).polynomial.operations[0][1] >>> d.get_dims("u") {8, 4} >>> d.get_dims("v") @@ -414,7 +486,7 @@ def get_dims(self, m: str) -> set[int]: return self.get_dimensions_dict().get(m, set()) def get_path_dimensions_dict( - self, path: Union[int, stp.Path], *, returns_sets: bool = False + self, path: Union[int, Path], *, returns_sets: bool = False ) -> dict[str, Union[int, set[int]]]: """Get the dimensions of a specific path.""" if isinstance(path, int): @@ -424,7 +496,7 @@ def get_path_dimensions_dict( m: {d} for m, d in zip(self.coefficient_subscripts, path.coefficients.shape) } for oid, sid in enumerate(path.indices): - for m, d in zip(self.operands[oid].subscripts, self.operands[oid][sid]): + for m, d in zip(self.subscripts.operands[oid], self.operands[oid][sid]): dims.setdefault(m, set()).add(d) if returns_sets: @@ -433,7 +505,7 @@ def get_path_dimensions_dict( return {m: next(iter(dd)) for m, dd in dims.items()} def get_path_dim( - self, path: Union[int, stp.Path], m: str, *, returns_set=False + self, path: Union[int, Path], m: str, *, returns_set=False ) -> Union[int, set[int]]: """Get the dimension of a specific mode in a specific path.""" if isinstance(path, int): @@ -442,14 +514,14 @@ def get_path_dim( m, set() if returns_set else 0 ) - def segment_slice(self, operand: int, path: Union[int, stp.Path]) -> slice: + def segment_slice(self, operand: int, path: Union[int, Path]) -> slice: """Get the slice of the segment in the given operand selected by the given path.""" if isinstance(path, int): path = self.paths[path] return self.operands[operand].segment_slices()[path.indices[operand]] def get_segment_shape( - self, operand: int, path: Union[int, stp.Path] + self, operand: int, path: Union[int, Path] ) -> tuple[int, ...]: """Get the shape of the segment in the given operand selected by the given path.""" if isinstance(path, int): @@ -494,33 +566,48 @@ def compressed_path_segment(self, operand: int) -> np.ndarray: return np.append(0, np.bincount(i, minlength=n).cumsum()) - @functools.lru_cache(maxsize=None) + def operands_with_identical_segments(self) -> frozenset[frozenset[int]]: + """Groups of operands sharing the same segments.""" + operand_to_oid = collections.defaultdict(list) + for oid, ope in enumerate(self.operands): + operand_to_oid[ope].append(oid) + return frozenset(map(frozenset, operand_to_oid.values())) + + @functools.cache def symmetries(self) -> list[tuple[int, ...]]: """List of permutations that leave the tensor product invariant.""" - def functionally_equivalent( - d1: SegmentedTensorProduct, d2: SegmentedTensorProduct - ) -> bool: - d1 = d1.consolidate_paths().sort_paths() - d2 = d2.consolidate_paths().sort_paths() - return d1 == d2 + d = self.consolidate_paths() - ps = [] - for p in itertools.permutations(range(self.num_operands)): - if functionally_equivalent(self, self.permute_operands(p)): - ps.append(p) - return ps + ps = set() + for group in self.operands_with_identical_segments(): + group = sorted(group) + for p in itertools.permutations(range(len(group))): + p = tuple( + group[p[group.index(i)]] if i in group else i + for i in range(self.num_operands) + ) + if p in ps: + continue + if d == d.permute_operands(p).consolidate_paths(): + ps.add(p) + ps.add(inverse_permutation(p)) + ps = generate_permutations_from(ps) + return sorted(ps) def coefficients_equal_one(self) -> bool: """Check if all coefficients are equal to one.""" return np.all(self.stacked_coefficients == 1) - def flop_cost(self, operand: int, algorithm: str = "optimal") -> int: + def flops( + self, 
operand: int, batch_size: int = 1, algorithm: str = "optimal" + ) -> int: """ Compute the number of flops needed to compute the specified operand. Args: operand (int): The operand for which to compute the flop cost. + batch_size (int, optional): The batch size for the computation. Defaults to 1. algorithm (str, optional): The algorithm to use to compute the cost. Can be 'optimal' or 'naive'. Returns: @@ -529,12 +616,12 @@ def flop_cost(self, operand: int, algorithm: str = "optimal") -> int: d = self.move_operand_last(operand) subscripts = ( d.coefficient_subscripts - + "".join("," + operand.subscripts for operand in d.operands[:-1]) + + "".join("," + ss for ss in d.subscripts.operands[:-1]) + "->" - + d.operands[-1].subscripts + + d.subscripts.operands[-1] ) - @functools.lru_cache(maxsize=None) + @functools.cache def compute_cost(segment_shapes: tuple[tuple[int, ...], ...]) -> int: _, info = opt_einsum.contract_path( subscripts, *segment_shapes, optimize="optimal", shapes=True @@ -554,31 +641,12 @@ def compute_cost(segment_shapes: tuple[tuple[int, ...], ...]) -> int: dims = d.get_path_dimensions_dict(path) coeff_shape = tuple(dims[ch] for ch in d.coefficient_subscripts) cost += compute_cost((coeff_shape,) + shapes) - return cost - - def memory_cost(self, algorithm: str = "sequential") -> int: - """ - Compute the number of memory accesses needed to compute the specified operand. + return cost * batch_size - Args: - operand (int): The operand for which to compute the memory cost. - algorithm (str, optional): The algorithm to use to compute the cost. Can be 'sequential' or 'global'. - - Returns: - int: The number of memory accesses needed to compute the specified operand. - """ - if algorithm == "sequential": - return sum( - sum( - math.prod(self.get_segment_shape(oid, path)) - for oid in range(self.num_operands) - ) - for path in self.paths - ) - elif algorithm == "global": - return sum(operand.size for operand in self.operands) - else: - raise ValueError(f"unknown algorithm {algorithm}.") + def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the tensor product.""" + assert len(batch_sizes) == self.num_operands + return sum(Z * ope.size for Z, ope in zip(batch_sizes, self.operands)) ################################ Modifiers ################################ @@ -616,7 +684,7 @@ def insert_path( dims.setdefault(m, set()).add(d) for oid, s in enumerate(segments): - subscripts = self.operands[oid].subscripts + subscripts = self.subscripts.operands[oid] if isinstance(s, int): if not ( -self.operands[oid].num_segments @@ -653,21 +721,19 @@ def insert_path( dims = {m: next(iter(dd)) for m, dd in dims.items()} - self.paths.insert( - path_index, - stp.Path( - [ - ( - (s + self.operands[oid].num_segments) - % self.operands[oid].num_segments - if isinstance(s, int) - else self.add_segment(oid, dims) - ) - for oid, s in enumerate(segments) - ], - coefficients, - ), + path = Path( + [ + ( + (s + self.operands[oid].num_segments) + % self.operands[oid].num_segments + if isinstance(s, int) + else self.add_segment(oid, dims) + ) + for oid, s in enumerate(segments) + ], + coefficients, ) + self.insert_path_(path_index, path) return path_index def add_path( @@ -712,69 +778,40 @@ def add_path( """ return self.insert_path(len(self.paths), *segments, c=c, dims=dims) - def insert_segment( - self, - operand: int, - sid: int, - segment: Union[tuple[int, ...], dict[str, int]], - ): - """Insert a segment at a specific index.""" - operand = _canonicalize_index("operand", operand, 
self.num_operands) - sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) - - self.operands[operand].insert_segment(sid, segment) - self.paths = [ - stp.Path( - [ - s if s < sid or oid != operand else s + 1 - for oid, s in enumerate(path.indices) - ], - path.coefficients, - ) - for path in self.paths - ] - - def insert_segments( - self, - operand: int, - sid: int, - segments: Union[list[tuple[int, ...]], stp.Operand], - ): + def insert_segments(self, operand: int, sid: int, segments: list[tuple[int, ...]]): """Insert segments at a specific index.""" operand = _canonicalize_index("operand", operand, self.num_operands) sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) - if not isinstance(segments, stp.Operand): - segments = stp.Operand( - subscripts=self.operands[operand].subscripts, segments=segments - ) - o = self.operands[operand] - if o.subscripts != segments.subscripts: - raise ValueError( - f"expected segments subscripts {segments.subscripts} to match operand subscripts {o.subscripts}." - ) - - self.operands[operand] = stp.Operand( - subscripts=o.subscripts, - segments=o.segments[:sid] + segments.segments + o.segments[sid:], - _dims={m: o.get_dims(m) | segments.get_dims(m) for m in o.subscripts}, + n = cue.SegmentedOperand(ndim=o.ndim, segments=segments) + self.set_operand( + operand, + cue.SegmentedOperand( + ndim=o.ndim, + segments=o.segments[:sid] + n.segments + o.segments[sid:], + _dims={m: o.get_dims(m) | n.get_dims(m) for m in range(o.ndim)}, + ), + ) + self.set_paths( + [ + Path( + [ + s if s < sid or oid != operand else s + n.num_segments + for oid, s in enumerate(path.indices) + ], + path.coefficients, + ) + for path in self.paths + ], ) - self.paths = [ - stp.Path( - [ - s if s < sid or oid != operand else s + segments.num_segments - for oid, s in enumerate(path.indices) - ], - path.coefficients, - ) - for path in self.paths - ] def add_segment( self, operand: int, segment: Union[tuple[int, ...], dict[str, int]] ) -> int: """Add a segment to the descriptor.""" + if isinstance(segment, dict): + segment = tuple(segment[m] for m in self.subscripts.operands[operand]) return self.operands[operand].add_segment(segment) def add_segments( @@ -795,7 +832,7 @@ def canonicalize_subscripts(self) -> SegmentedTensorProduct: This is useful to identify equivalent descriptors. """ - subscripts = stp.Subscripts.canonicalize(self.subscripts) + subscripts = Subscripts.canonicalize(self.subscripts) return self.add_or_rename_modes(subscripts) def add_or_rename_modes( @@ -812,15 +849,16 @@ def add_or_rename_modes( Returns: SegmentedTensorProduct: The new descriptor with the renamed modes. """ - subscripts = stp.Subscripts(subscripts) + subscripts = Subscripts(subscripts) if subscripts.is_equivalent(self.subscripts): d = SegmentedTensorProduct.from_subscripts(subscripts) for oid, operand in enumerate(self.operands): d.add_segments(oid, operand.segments) for path in self.paths: - d.paths.append( - stp.Path(indices=path.indices, coefficients=path.coefficients) + d.insert_path_( + len(d.paths), + Path(indices=path.indices, coefficients=path.coefficients), ) return d @@ -839,7 +877,7 @@ def add_or_rename_modes( ) mapping = mappings[0] - if not all(ch in mapping for ope in self.operands for ch in ope.subscripts): + if not all(ch in mapping for ss in self.subscripts.operands for ch in ss): raise ValueError( f"expected all segment modes to be in the mapping {mapping}." 
) @@ -852,24 +890,25 @@ def add_or_rename_modes( for oid in range(self.num_operands): for sid in range(self.operands[oid].num_segments): dims = collections.defaultdict(lambda: 1) - for m, d in zip(self.operands[oid].subscripts, self.operands[oid][sid]): + for m, d in zip(self.subscripts.operands[oid], self.operands[oid][sid]): dims[mapping[m]] = d D.add_segment(oid, dims) for path in self.paths: dims: dict[str, int] = dict() for oid, sid in enumerate(path.indices): - for m, d in zip(D.operands[oid].subscripts, D.operands[oid][sid]): + for m, d in zip(D.subscripts.operands[oid], D.operands[oid][sid]): dims[m] = d - D.paths.append( - stp.Path( + D.insert_path_( + len(D.paths), + Path( indices=path.indices, coefficients=np.reshape( path.coefficients, tuple(dims[m] for m in D.coefficient_subscripts), ), - ) + ), ) return D @@ -887,7 +926,7 @@ def add_or_transpose_modes( Returns: SegmentedTensorProduct: The new descriptor with the transposed modes. """ - subscripts = stp.Subscripts.complete_wildcards(subscripts, self.subscripts) + subscripts = Subscripts.complete_wildcards(subscripts, self.subscripts) if dims is None: dims = dict() @@ -911,8 +950,8 @@ def add_or_transpose_modes( d = SegmentedTensorProduct.from_subscripts(subscripts) for oid in range(self.num_operands): old, new = ( - self.operands[oid].subscripts, - d.operands[oid].subscripts, + self.subscripts.operands[oid], + d.subscripts.operands[oid], ) perm = [old.index(ch) if ch in old else ch for ch in new] for sid in range(self.operands[oid].num_segments): @@ -925,13 +964,15 @@ def add_or_transpose_modes( ) old, new = self.coefficient_subscripts, d.coefficient_subscripts perm = [old.index(ch) for ch in new] - for path in self.paths: - d.paths.append( - stp.Path( + d.set_paths( + [ + Path( indices=path.indices, coefficients=np.transpose(path.coefficients, perm), ) - ) + for path in self.paths + ] + ) return d def append_modes_to_all_operands( @@ -953,7 +994,7 @@ def append_modes_to_all_operands( ) if not all(ch in dims for ch in modes): raise ValueError(f"expected dimensions for all new modes {modes}.") - subscripts = stp.Subscripts.from_operands( + subscripts = Subscripts.from_operands( [ope + modes for ope in self.subscripts.operands], coefficients=self.subscripts.coefficients, ) @@ -961,10 +1002,10 @@ def append_modes_to_all_operands( def permute_operands(self, perm: tuple[int, ...]) -> SegmentedTensorProduct: """Permute the operands of the descriptor.""" - assert set(perm) == set(range(self.num_operands)) + # assert set(perm) == set(range(self.num_operands)) # removed for performance return dataclasses.replace( self, - operands=[self.operands[i] for i in perm], + operands_and_subscripts=[self.operands_and_subscripts[i] for i in perm], paths=[path.permute_operands(perm) for path in self.paths], ) @@ -991,13 +1032,13 @@ def permute_segments( ) -> SegmentedTensorProduct: """Permute the segments of an operand.""" operand = _canonicalize_index("operand", operand, self.num_operands) - new_operands = self.operands.copy() - new_operands[operand] = stp.Operand( + new_operands = list(self.operands) + new_operands[operand] = cue.SegmentedOperand( + ndim=self.operands[operand].ndim, segments=[self.operands[operand][i] for i in perm], - subscripts=self.operands[operand].subscripts, ) new_paths = [ - stp.Path( + Path( indices=tuple( sid if oid != operand else perm.index(sid) for oid, sid in enumerate(path.indices) @@ -1006,7 +1047,11 @@ def permute_segments( ) for path in self.paths ] - return dataclasses.replace(self, operands=new_operands, 
paths=new_paths) + return dataclasses.replace( + self, + operands_and_subscripts=list(zip(new_operands, self.subscripts.operands)), + paths=new_paths, + ) def sort_paths( self, operands_ordering: Optional[Union[int, Sequence[int]]] = None @@ -1068,13 +1113,14 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: d.add_segments( oid, [ - filter_shape(segment, operand.subscripts) + filter_shape(segment, self.subscripts.operands[oid]) for segment in operand.segments ], ) for path in self.paths: - d.paths.append( - stp.Path( + d.insert_path_( + len(d.paths), + Path( indices=path.indices, coefficients=np.reshape( path.coefficients, @@ -1082,7 +1128,7 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: path.coefficients.shape, self.coefficient_subscripts ), ), - ) + ), ) logger.debug(f"Squeezed {self} to {d}") return d @@ -1104,10 +1150,10 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: if mode not in self.subscripts: return self - for oid, operand in enumerate(self.operands): - if mode in operand.subscripts and not operand.subscripts.startswith(mode): + for oid, ss in enumerate(self.subscripts.operands): + if mode in ss and not ss.startswith(mode): raise ValueError( - f"mode {mode} is not the first mode in operand {oid} ({operand.subscripts})." + f"mode {mode} is not the first mode in operand {oid} ({ss})." ) if not all(dim % size == 0 for dim in self.get_dims(mode)): @@ -1119,7 +1165,7 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: ): return ( self.add_or_transpose_modes( - stp.Subscripts.from_operands( + Subscripts.from_operands( self.subscripts.operands, mode + self.coefficient_subscripts.replace(mode, ""), ) @@ -1131,14 +1177,14 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: d = SegmentedTensorProduct.from_subscripts(self.subscripts) offsets_per_operand = [] - for oid, operand in enumerate(self.operands): - if mode not in operand.subscripts: + for oid, (operand, ss) in enumerate(self.operands_and_subscripts): + if mode not in ss: for segment in operand: d.add_segment(oid, segment) offsets_per_operand.append(None) continue - assert operand.subscripts.startswith(mode) + assert ss.startswith(mode) offsets = [] for segment in operand: @@ -1162,7 +1208,9 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: coefficients = path.coefficients if self.coefficient_subscripts.startswith(mode): coefficients = np.split(coefficients, num_subdivisions, axis=0)[i] - d.paths.append(stp.Path(indices=indices, coefficients=coefficients)) + d.insert_path_( + len(d.paths), Path(indices=indices, coefficients=coefficients) + ) logger.debug(f"Split {mode} in {self}: got {d}") return d @@ -1195,7 +1243,7 @@ def normalize_paths_for_operand(self, operand: int) -> SegmentedTensorProduct: # ^^^^^^^-- (0, "uvw") (1, "iu") (2, "jv") (3, "kw") (4, "ijk") if i != operand # e.g. discard (3, "kw") for u in subscripts # u v w i u j v i j k - if u not in self.operands[operand].subscripts # e.g. discard k and w + if u not in self.subscripts.operands[operand] # e.g. discard k and w } # e.g. 
{u, v, i, j} for path in self.paths: @@ -1211,7 +1259,7 @@ def normalize_paths_for_operand(self, operand: int) -> SegmentedTensorProduct: c = path.coefficients / np.sqrt(total_variance) else: c = 0.0 * path.coefficients - new_paths.append(stp.Path(indices=path.indices, coefficients=c)) + new_paths.append(Path(indices=path.indices, coefficients=c)) d = dataclasses.replace(self, paths=new_paths) logger.debug(f"Normalized paths for operand {operand} in {self}: got {d}") @@ -1228,23 +1276,36 @@ def fuse_paths_with_same_indices(self) -> SegmentedTensorProduct: """Fuse paths with the same indices.""" paths = dict() for path in self.paths: - indices = tuple(path.indices) - if indices in paths: - paths[indices] += path.coefficients + if path.indices in paths: + paths[path.indices] += path.coefficients else: - paths[indices] = path.coefficients + paths[path.indices] = path.coefficients return dataclasses.replace( self, paths=[ - stp.Path(indices=indices, coefficients=coefficients) + Path(indices=indices, coefficients=coefficients) for indices, coefficients in paths.items() ], ) + @functools.cache def consolidate_paths(self) -> SegmentedTensorProduct: """Consolidate the paths by merging duplicates and removing zeros.""" - return self.fuse_paths_with_same_indices().remove_zero_paths() + # equivalent to self.fuse_paths_with_same_indices().remove_zero_paths().sort_paths() + paths = dict() + for path in self.paths: + if path.indices in paths: + paths[path.indices] += path.coefficients + else: + paths[path.indices] = path.coefficients + paths = [ + Path(indices=indices, coefficients=coefficients) + for indices, coefficients in paths.items() + if not np.all(coefficients == 0) + ] + paths = sorted(paths, key=lambda path: path.indices) + return dataclasses.replace(self, paths=paths) def sort_indices_for_identical_operands( self, operands: Sequence[int] @@ -1268,7 +1329,7 @@ def f(ii: tuple[int, ...]) -> tuple[int, ...]: return dataclasses.replace( self, paths=[ - stp.Path( + Path( indices=f(path.indices), coefficients=path.coefficients, ) @@ -1276,13 +1337,28 @@ def f(ii: tuple[int, ...]) -> tuple[int, ...]: ], ).consolidate_paths() - def symmetrize_operands(self, operands: Sequence[int]) -> SegmentedTensorProduct: + def symmetrize_operands( + self, operands: Sequence[int], force: bool = False + ) -> SegmentedTensorProduct: """Symmetrize the specified operands permuting the indices.""" operands = sorted(set(operands)) if len(operands) < 2: return self permutations = list(itertools.permutations(range(len(operands)))) + + def make_global_perm(perm: tuple[int, ...]) -> tuple[int, ...]: + p = list(range(self.num_operands)) + for i, j in enumerate(perm): + p[operands[i]] = operands[j] + return tuple(p) + + if not force: + # check if the tensor product is already symmetric + symmetries: list[tuple[int, ...]] = self.symmetries() + if all(make_global_perm(perm) in symmetries for perm in permutations): + return self + d = self.sort_indices_for_identical_operands(operands) paths = [] @@ -1294,7 +1370,7 @@ def symmetrize_operands(self, operands: Sequence[int]) -> SegmentedTensorProduct for i, oid in enumerate(operands): new_indices[oid] = indices[operands[perm[i]]] paths.append( - stp.Path( + Path( indices=new_indices, coefficients=path.coefficients / len(permutations), ) @@ -1318,25 +1394,28 @@ def empty(D, oid, sid): perm.append(sid) D = D.permute_segments(oid, perm) - D.paths = [ - path - for path in D.paths - if not any(empty(D, oid, sid) for oid, sid in enumerate(path.indices)) - ] + D.set_paths( + [ + path + for 
path in D.paths + if not any(empty(D, oid, sid) for oid, sid in enumerate(path.indices)) + ] + ) for oid in range(D.num_operands): - D.operands[oid] = stp.Operand( - subscripts=D.operands[oid].subscripts, + operand = cue.SegmentedOperand( + ndim=D.operands[oid].ndim, segments=[ segment for sid, segment in enumerate(D.operands[oid].segments) if not empty(D, oid, sid) ], _dims={ - m: {d for d in dd if d > 0} - for m, dd in D.operands[oid]._dims.items() + i: {d for d in dd if d > 0} + for i, dd in D.operands[oid]._dims.items() }, ) + D.set_operand(oid, operand) return D @@ -1360,8 +1439,7 @@ def flatten_modes( if force: extra = "" for m in modes: - for ope in self.operands: - sm = ope.subscripts + for _, sm in self.operands_and_subscripts: if m in sm: extra += sm[: sm.index(m)] modes += extra @@ -1371,16 +1449,16 @@ def flatten_modes( return self pattern = re.compile(rf"^([{modes}]*)([^{modes}]*)$") - new_operands = [] + new_operands_and_subscripts = [] offsets_per_operand = [] rm_shape_per_operand = [] rm_modes_per_operand = [] - for operand in self.operands: - ma = pattern.match(operand.subscripts) + for operand, subscripts in self.operands_and_subscripts: + ma = pattern.match(subscripts) if ma is None: raise ValueError( f"expected modes {modes} to be at the beginning of the segment subscripts." - f" Got {operand.subscripts}." + f" Got {subscripts}." ) rm_modes, new_subscripts = ma.groups() @@ -1402,8 +1480,13 @@ def flatten_modes( offsets_per_operand.append(offsets) rm_shape_per_operand.append(rm_shapes) - new_operands.append( - stp.Operand(segments=new_segments, subscripts=new_subscripts) + new_operands_and_subscripts.append( + ( + cue.SegmentedOperand( + ndim=len(new_subscripts), segments=new_segments + ), + new_subscripts, + ) ) def ravel_multi_index(indices: tuple[int, ...], shape: tuple[int, ...]) -> int: @@ -1416,8 +1499,8 @@ def make_new_path( old_indices: tuple[int, ...], # old segment indices (one per operand) sub_indices: dict[str, int], coefficients: np.ndarray, - ) -> stp.Path: - return stp.Path( + ) -> Path: + return Path( [ offsets[sid] + ravel_multi_index( @@ -1480,7 +1563,7 @@ def make_new_path( ) d = dataclasses.replace( self, - operands=new_operands, + operands_and_subscripts=new_operands_and_subscripts, paths=new_paths, coefficient_subscripts="".join( ch for ch in self.coefficient_subscripts if ch not in modes @@ -1514,11 +1597,11 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu # look for opportunities to consolidate for m in self.subscripts.modes(): neighbors: set[str] = set() - for operand in self.operands: - if m in operand.subscripts: - i = operand.subscripts.index(m) - if i < len(operand.subscripts) - 1: - neighbors.add(operand.subscripts[i + 1]) + for _, subscripts in self.operands_and_subscripts: + if m in subscripts: + i = subscripts.index(m) + if i < len(subscripts) - 1: + neighbors.add(subscripts[i + 1]) else: neighbors.add(".") if len(neighbors) != 1: @@ -1529,13 +1612,13 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu # Zuvw_Ziu_Zjv_Zkw+ijk ok = True - for operand in self.operands: - if m in operand.subscripts: - if n not in operand.subscripts: + for _, subscripts in self.operands_and_subscripts: + if m in subscripts: + if n not in subscripts: ok = False break else: - if n in operand.subscripts: + if n in subscripts: ok = False break @@ -1553,11 +1636,11 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu return self - for operand in self.operands: - if 
modes not in operand.subscripts: - if any(ch in operand.subscripts for ch in modes): + for _, subscripts in self.operands_and_subscripts: + if modes not in subscripts: + if any(ch in subscripts for ch in modes): raise ValueError( - f"expected {modes} to be contiguous in the subscripts {operand.subscripts}." + f"expected {modes} to be contiguous in the subscripts {subscripts}." ) d = self @@ -1582,10 +1665,10 @@ def _consolidate_pair_of_modes(self, m: str, n: str) -> SegmentedTensorProduct: d1 = SegmentedTensorProduct.from_subscripts(d0.subscripts.replace(m + n, m)) - for oid, operand in enumerate(d0.operands): + for oid, (operand, subscripts) in enumerate(d0.operands_and_subscripts): for segment in operand: - if m in operand.subscripts: - i = operand.subscripts.index(m) + if m in subscripts: + i = subscripts.index(m) segment = list(segment) segment[i] *= segment[i + 1] segment.pop(i + 1) @@ -1599,9 +1682,11 @@ def _consolidate_pair_of_modes(self, m: str, n: str) -> SegmentedTensorProduct: c = np.reshape( c, c.shape[:i] + (c.shape[i] * c.shape[i + 1],) + c.shape[i + 2 :] ) - d1.paths.append(stp.Path(indices=path.indices, coefficients=c)) + d1.insert_path_( + len(d1.paths), Path(indices=path.indices, coefficients=c) + ) else: - d1.paths = copy.deepcopy(d0.paths) + d1.set_paths(d0.paths) return d1 @@ -1615,13 +1700,15 @@ def round_coefficients_to_rational( max_denominator (int): The maximum denominator, ``q < max_denominator``. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path( - indices=path.indices, - coefficients=round_to_rational(path.coefficients, max_denominator), - ) - for path in d.paths - ] + d.set_paths( + [ + Path( + indices=path.indices, + coefficients=round_to_rational(path.coefficients, max_denominator), + ) + for path in d.paths + ] + ) return d def round_coefficients_to_sqrt_rational( @@ -1634,13 +1721,17 @@ def round_coefficients_to_sqrt_rational( max_denominator (int): The maximum denominator, ``q < max_denominator``. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path( - indices=path.indices, - coefficients=round_to_sqrt_rational(path.coefficients, max_denominator), - ) - for path in d.paths - ] + d.set_paths( + [ + Path( + indices=path.indices, + coefficients=round_to_sqrt_rational( + path.coefficients, max_denominator + ), + ) + for path in d.paths + ] + ) return d def modify_coefficients( @@ -1653,10 +1744,12 @@ def modify_coefficients( f (callable): The function to apply to the coefficients. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path(indices=path.indices, coefficients=f(path.coefficients)) - for path in d.paths - ] + d.set_paths( + [ + Path(indices=path.indices, coefficients=f(path.coefficients)) + for path in d.paths + ] + ) return d def __mul__(self, factor: float) -> SegmentedTensorProduct: diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/subscripts.py b/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/subscripts.py rename to cuequivariance/cuequivariance/segmented_polynomials/subscripts.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py b/cuequivariance/cuequivariance/segmented_tensor_product/operand.py deleted file mode 100644 index 67e21f1a..00000000 --- a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py +++ /dev/null @@ -1,218 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import dataclasses -import math -from typing import Optional, Sequence, Union - -from cuequivariance import segmented_tensor_product as stp - -from .dimensions_dict import format_dimensions_dict - - -@dataclasses.dataclass(init=False, frozen=True) -class Operand: - """A tensor product operand. It is a list of segments and subscripts.""" - - _segments: list[tuple[int, ...]] - subscripts: stp.Subscripts - _dims: dict[str, set[int]] - - def __init__( - self, - *, - subscripts: stp.Subscripts, - segments: Optional[list[tuple[int, ...]]] = None, - _dims: Optional[dict[str, set[int]]] = None, - ): - object.__setattr__(self, "subscripts", stp.Subscripts(subscripts)) - - if segments is None: - segments = [] - object.__setattr__(self, "_segments", segments) - - if _dims is None: - _dims = dict() - for segment in self.segments: - for m, d in zip(self.subscripts, segment): - _dims.setdefault(m, set()).add(d) - object.__setattr__(self, "_dims", _dims) - - @classmethod - def empty_segments(cls, num_segments: int) -> Operand: - """Create an operand with empty subscripts""" - return cls(subscripts="", segments=[()] * num_segments, _dims=dict()) - - def assert_valid(self): - """Assert that the operand is valid.""" - if not all(m.isalpha() and m.islower() for m in self.subscripts): - raise ValueError(f"subscripts {self.subscripts} is not valid.") - - for segment in self.segments: - if not all(isinstance(dim, int) and dim > 0 for dim in segment): - raise ValueError(f"segment {segment} is not valid.") - - if len(segment) != len(self.subscripts): - raise ValueError( - f"segment {segment} has {len(segment)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}." - ) - - for m, d in zip(self.subscripts, segment): - if d not in self.get_dims(m): - raise ValueError( - f"dimension {d} not in {m} dimensions {self.get_dims(m)}." - ) - - def insert_segment( - self, index: int, segment: Union[tuple[int, ...], dict[str, int]] - ): - """Insert a segment at a given index.""" - if isinstance(segment, dict): - segment = tuple(segment[m] for m in self.subscripts) - - if len(segment) != len(self.subscripts): - raise ValueError( - f"segment has {len(segment)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}." 
- ) - - segment = tuple(int(d) for d in segment) - self._segments.insert(index, segment) - - for m, d in zip(self.subscripts, segment): - self._dims.setdefault(m, set()).add(d) - - def add_segment(self, segment: Union[tuple[int, ...], dict[str, int]]) -> int: - """Add a segment to the operand.""" - self.insert_segment(len(self.segments), segment) - return len(self.segments) - 1 - - def __hash__(self) -> int: - return hash((tuple(self.segments), self.subscripts)) - - def __eq__(self, other: Operand) -> bool: - return self.subscripts == other.subscripts and self.segments == other.segments - - def __repr__(self) -> str: - dims = format_dimensions_dict(self.get_dimensions_dict()) - return f"Operand(subscripts={self.subscripts} num_segments={self.num_segments} {dims})" - - def __getitem__(self, index: int) -> tuple[int, ...]: - return self.segments[index] - - def __len__(self) -> int: - return self.num_segments - - def __iter__(self): - return iter(self.segments) - - @property - def segments(self) -> tuple[tuple[int, ...], ...]: - """The segments of the operand.""" - return tuple(self._segments) - - @property - def num_segments(self) -> int: - """The number of segments in the operand.""" - return len(self.segments) - - @property - def size(self) -> int: - """The total size of the operand.""" - if self.all_same_segment_shape(): - return self.num_segments * self.segment_size - - return sum(math.prod(segment) for segment in self.segments) - - @property - def ndim(self) -> int: - """The number of segment dimensions.""" - return len(self.subscripts) - - def segment_slices(self) -> list[slice]: - """Return slice object for each segment.""" - offset = 0 - slices = [] - for segment in self.segments: - slices.append(slice(offset, offset + math.prod(segment))) - offset += math.prod(segment) - return slices - - def get_dimensions_dict(self) -> dict[str, set[int]]: - """Return a dictionary of dimensions for each channel.""" - return self._dims.copy() - - def get_dims(self, m: str) -> set[int]: - """Return the dimensions for a given channel.""" - return self._dims.get(m, set()).copy() - - def get_segment_shape( - self, dims: dict[str, int], *, default: int = -1 - ) -> tuple[int, ...]: - """Return the shape of a potential segment.""" - return tuple(dims.get(ch, default) for ch in self.subscripts) - - def transpose_modes( - self, subscripts: Union[str, Sequence[str], Sequence[int]] - ) -> Operand: - """Transpose the channels of the operand.""" - if not isinstance(subscripts, Sequence): - raise TypeError("channels must be a sequence.") - - if isinstance(subscripts[0], str): - subscripts = "".join(subscripts) - subscripts = stp.Subscripts.complete_wildcards(subscripts, self.subscripts) - subscripts = [self.subscripts.index(m) for m in subscripts] - - subscripts: list[int] = list(subscripts) - - if len(subscripts) != len(self.subscripts): - raise ValueError( - f"channels has {len(subscripts)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}." - ) - - segments = [tuple(segment[i] for i in subscripts) for segment in self.segments] - return Operand( - subscripts="".join(self.subscripts[i] for i in subscripts), - segments=segments, - _dims=self._dims, - ) - - def all_same_segment_shape(self) -> bool: - """Check if all segments have the same shape. 
Returns False if there are no segments.""" - return all(len(dd) == 1 for dd in self._dims.values()) and self.num_segments > 0 - - @property - def segment_shape(self) -> tuple[int, ...]: - """The shape of the segments if they are all the same.""" - if not self.all_same_segment_shape(): - raise ValueError("Segments do not have the same shape.") - return self.segments[0] - - @property - def segment_size(self) -> int: - """The size of the segments if they are all the same.""" - if not self.all_same_segment_shape(): - raise ValueError("Segments do not have the same shape.") - return math.prod(self.segments[0]) - - def __add__(self, other: Operand) -> Operand: - if self.subscripts != other.subscripts: - raise ValueError("subscripts do not match.") - return Operand( - subscripts=self.subscripts, - segments=self.segments + other.segments, - _dims={m: self.get_dims(m) | other.get_dims(m) for m in self.subscripts}, - ) diff --git a/cuequivariance/tests/equivariant_tensor_products_test.py b/cuequivariance/tests/equivariant_tensor_products_test.py deleted file mode 100644 index 14cd90b0..00000000 --- a/cuequivariance/tests/equivariant_tensor_products_test.py +++ /dev/null @@ -1,96 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
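# A short migration sketch (not part of the patch): the deleted `stp.Operand`
# above is replaced by `cue.SegmentedOperand`, which no longer stores
# subscripts -- those now live on the parent SegmentedTensorProduct. The
# constructor arguments below are inferred from the call sites in this patch.
import cuequivariance as cue

# before (deleted): stp.Operand(subscripts="uv", segments=[(2, 3), (4, 3)])
ope = cue.SegmentedOperand(ndim=2, segments=[(2, 3), (4, 3)])
assert ope.num_segments == 2
assert ope.size == 2 * 3 + 4 * 3  # sum of the segment sizes
assert [s.start for s in ope.segment_slices()] == [0, 6]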
-import numpy as np -import pytest - -import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors - - -def test_commutativity_squeeze_flatten(): - irreps1 = cue.Irreps("O3", "32x0e + 32x1o") - irreps2 = cue.Irreps("O3", "1x0e + 1x1o") - irreps3 = cue.Irreps("O3", "32x0e + 32x1o") - - d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d - assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() - ) - - d = descriptors.full_tensor_product(irreps1, irreps2, irreps3).d - assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() - ) - - d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d - assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() - ) - - d = descriptors.linear(irreps1, irreps2).d - assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() - ) - - -@pytest.mark.parametrize("ell", [1, 2, 3, 4]) -def test_spherical_harmonics(ell: int): - d = descriptors.spherical_harmonics(cue.SO3(1), [ell]).d - - vec = np.random.randn(3) - axis = np.random.randn(3) - angle = np.random.rand() - - yl = stp.compute_last_operand(d, *(vec,) * ell) - - R = cue.SO3(1).rotation(axis, angle) - Rl = cue.SO3(ell).rotation(axis, angle) - - yl1 = stp.compute_last_operand(d, *(R @ vec,) * ell) - yl2 = Rl @ yl - - np.testing.assert_allclose(yl1, yl2) - np.testing.assert_allclose(np.sum(yl**2), (2 * ell + 1) * np.sum(vec**2) ** ell) - - -@pytest.mark.parametrize("ell", [0, 1, 2, 3, 4]) -def test_y_rotation(ell: int): - alpha = 0.3 - beta = 0.4 - gamma = -0.5 - - irrep = cue.SO3(ell) - d = descriptors.yxy_rotation(cue.Irreps("SO3", [irrep])).d - - def enc(th: float): - m = np.arange(1, ell + 1) - c = np.cos(m * th) - s = np.sin(m * th) - return np.concatenate([c[::-1], [1.0], s]) - - x = np.random.randn(irrep.dim) - y1 = stp.compute_last_operand(d, enc(gamma), enc(beta), enc(alpha), x) - - A = irrep.rotation(np.array([0.0, 1.0, 0.0]), alpha) - B = irrep.rotation(np.array([1.0, 0.0, 0.0]), beta) - C = irrep.rotation(np.array([0.0, 1.0, 0.0]), gamma) - y2 = A @ B @ C @ x - - np.testing.assert_allclose(y1, y2) diff --git a/cuequivariance/tests/linalg/round_to_rational_test.py b/cuequivariance/tests/etc/round_to_rational_test.py similarity index 97% rename from cuequivariance/tests/linalg/round_to_rational_test.py rename to cuequivariance/tests/etc/round_to_rational_test.py index 33dbf486..d488ac41 100644 --- a/cuequivariance/tests/linalg/round_to_rational_test.py +++ b/cuequivariance/tests/etc/round_to_rational_test.py @@ -14,7 +14,7 @@ # limitations under the License. import numpy as np -from cuequivariance.misc import linalg +from cuequivariance.etc import linalg def test_as_approx_integer_ratio(): diff --git a/cuequivariance/tests/context_test.py b/cuequivariance/tests/group_theory/context_test.py similarity index 100% rename from cuequivariance/tests/context_test.py rename to cuequivariance/tests/group_theory/context_test.py diff --git a/cuequivariance/tests/group_theory/equivariant_polynomial_test.py b/cuequivariance/tests/group_theory/equivariant_polynomial_test.py new file mode 100644 index 00000000..392b8091 --- /dev/null +++ b/cuequivariance/tests/group_theory/equivariant_polynomial_test.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +import cuequivariance as cue + + +def test_transpose(): + cue.descriptors.transpose(cue.Irreps(cue.O3, "3x1e"), cue.mul_ir, cue.ir_mul) + + +def test_commutativity_squeeze_flatten(): + irreps1 = cue.Irreps("O3", "32x0e + 32x1o") + irreps2 = cue.Irreps("O3", "1x0e + 1x1o") + irreps3 = cue.Irreps("O3", "32x0e + 32x1o") + + poly = cue.descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3) + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) + + poly = cue.descriptors.full_tensor_product(irreps1, irreps2, irreps3) + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) + + poly = cue.descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3) + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) + + poly = cue.descriptors.linear(irreps1, irreps2) + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) + + +@pytest.mark.parametrize("ell", [1, 2, 3, 4]) +def test_spherical_harmonics(ell: int): + poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [ell]) + + vec = np.random.randn(3) + axis = np.random.randn(3) + angle = np.random.rand() + + [yl] = poly(vec) + + R = cue.SO3(1).rotation(axis, angle) + Rl = cue.SO3(ell).rotation(axis, angle) + + [yl1] = poly(R @ vec) + yl2 = Rl @ yl + + np.testing.assert_allclose(yl1, yl2) + np.testing.assert_allclose(np.sum(yl**2), (2 * ell + 1) * np.sum(vec**2) ** ell) + + +@pytest.mark.parametrize("ell", [0, 1, 2, 3, 4]) +def test_y_rotation(ell: int): + alpha = 0.3 + beta = 0.4 + gamma = -0.5 + + irrep = cue.SO3(ell) + poly = cue.descriptors.yxy_rotation(cue.Irreps("SO3", [irrep])) + + def enc(th: float): + m = np.arange(1, ell + 1) + c = np.cos(m * th) + s = np.sin(m * th) + return np.concatenate([c[::-1], [1.0], s]) + + x = np.random.randn(irrep.dim) + [y1] = poly(enc(gamma), enc(beta), enc(alpha), x) + + A = irrep.rotation(np.array([0.0, 1.0, 0.0]), alpha) + B = irrep.rotation(np.array([1.0, 0.0, 0.0]), beta) + C = irrep.rotation(np.array([0.0, 1.0, 0.0]), gamma) + y2 = A @ B @ C @ x + + np.testing.assert_allclose(y1, y2) + + +def test_elementwise_tensor_product(): + irreps1 = cue.Irreps("O3", "8x0e + 7x1o + 4x2e") + irreps2 = cue.Irreps("O3", "8x0e + 7x1o + 4x0e") + irreps3 = cue.Irreps("O3", "0e + 1o + 1e") + poly = cue.descriptors.elementwise_tensor_product(irreps1, irreps2, irreps3) + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 + + +def test_symmetric_contraction(): + irreps_in = 16 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 16 * cue.Irreps("SO3", "0 + 1") + poly = 
cue.descriptors.symmetric_contraction(irreps_in, irreps_out, [0, 1, 2, 3]) + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 + + [_, _, _, (_, d)] = poly.polynomial.operations + assert d.num_paths == 437 + + poly = poly.polynomial.unsymmetrize_for_identical_operands() + [_, _, _, (_, d)] = poly.operations + assert d.num_paths == 105 diff --git a/cuequivariance/tests/experimental/escn_test.py b/cuequivariance/tests/group_theory/experimental/escn_test.py similarity index 92% rename from cuequivariance/tests/experimental/escn_test.py rename to cuequivariance/tests/group_theory/experimental/escn_test.py index 2b770f4f..bb5afb5d 100644 --- a/cuequivariance/tests/experimental/escn_test.py +++ b/cuequivariance/tests/group_theory/experimental/escn_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import cuequivariance as cue -from cuequivariance.experimental.escn import escn_tp, escn_tp_compact +from cuequivariance.group_theory.experimental.escn import escn_tp, escn_tp_compact def test_escn(): diff --git a/cuequivariance/tests/experimental/gatr_test.py b/cuequivariance/tests/group_theory/experimental/gatr_test.py similarity index 93% rename from cuequivariance/tests/experimental/gatr_test.py rename to cuequivariance/tests/group_theory/experimental/gatr_test.py index cef872de..dc55373f 100644 --- a/cuequivariance/tests/experimental/gatr_test.py +++ b/cuequivariance/tests/group_theory/experimental/gatr_test.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from cuequivariance.experimental.gatr import ( +from cuequivariance.group_theory.experimental.gatr import ( gatr_geometric_product, gatr_linear, gatr_outer_product, diff --git a/cuequivariance/tests/experimental/mace_test.py b/cuequivariance/tests/group_theory/experimental/mace_test.py similarity index 91% rename from cuequivariance/tests/experimental/mace_test.py rename to cuequivariance/tests/group_theory/experimental/mace_test.py index fec5d863..943c8059 100644 --- a/cuequivariance/tests/experimental/mace_test.py +++ b/cuequivariance/tests/group_theory/experimental/mace_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
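# A usage sketch of the new descriptor API exercised by the tests above (a
# hedged example, not part of the patch): descriptors now return a
# cue.EquivariantPolynomial instead of an object with a `.d` attribute. The
# polynomial is directly callable, and the underlying SegmentedTensorProduct
# of the first operation is reached as in the updated tests.
import numpy as np
import cuequivariance as cue

poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [2])
[y2] = poly(np.random.randn(3))       # evaluate the polynomial on a vector
d = poly.polynomial.operations[0][1]  # the underlying SegmentedTensorProduct
print(d.subscripts)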
import cuequivariance as cue -from cuequivariance.experimental.mace import symmetric_contraction +from cuequivariance.group_theory.experimental.mace import symmetric_contraction def test_symmetric_contraction(): diff --git a/cuequivariance/tests/irreps/irreps_test.py b/cuequivariance/tests/group_theory/irreps/irreps_test.py similarity index 100% rename from cuequivariance/tests/irreps/irreps_test.py rename to cuequivariance/tests/group_theory/irreps/irreps_test.py diff --git a/cuequivariance/tests/irreps_array_test.py b/cuequivariance/tests/group_theory/irreps_array_test.py similarity index 100% rename from cuequivariance/tests/irreps_array_test.py rename to cuequivariance/tests/group_theory/irreps_array_test.py diff --git a/cuequivariance/tests/reduced_tensor_product_test.py b/cuequivariance/tests/group_theory/reduced_tensor_product_test.py similarity index 100% rename from cuequivariance/tests/reduced_tensor_product_test.py rename to cuequivariance/tests/group_theory/reduced_tensor_product_test.py diff --git a/cuequivariance/tests/segmented_tensor_product/dot_test.py b/cuequivariance/tests/segmented_polynomials/dot_test.py similarity index 78% rename from cuequivariance/tests/segmented_tensor_product/dot_test.py rename to cuequivariance/tests/segmented_polynomials/dot_test.py index 4d8c93c4..4d5a68d1 100644 --- a/cuequivariance/tests/segmented_tensor_product/dot_test.py +++ b/cuequivariance/tests/segmented_polynomials/dot_test.py @@ -15,30 +15,30 @@ import numpy as np import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors +import cuequivariance.segmented_polynomials as sp +from cuequivariance.group_theory import descriptors def test_dot1(): - d1 = stp.SegmentedTensorProduct.from_subscripts("iab_jb_ak+ijk") + d1 = cue.SegmentedTensorProduct.from_subscripts("iab_jb_ak+ijk") d1.add_path(None, None, None, c=np.random.randn(2, 2, 2), dims={"a": 2, "b": 3}) d1.add_path(None, None, None, c=np.random.randn(2, 3, 2), dims={"a": 4, "b": 3}) d1.add_path(0, 1, 0, c=np.random.randn(2, 3, 2)) - d2 = stp.SegmentedTensorProduct.from_subscripts("jb_b_+j") + d2 = cue.SegmentedTensorProduct.from_subscripts("jb_b_+j") d2.add_path(None, None, None, c=np.random.randn(2), dims={"b": 3}) d2.add_path(None, 0, None, c=np.random.randn(3)) - d3 = stp.dot(d1, d2, (1, 0)) + d3 = sp.dot(d1, d2, (1, 0)) assert d3.subscripts == "iab,ak,b,+ijk" x0, x2 = np.random.randn(d1.operands[0].size), np.random.randn(d1.operands[2].size) y1 = np.random.randn(d2.operands[1].size) - tmp = stp.compute_last_operand(d1.move_operand_last(1), x0, x2) - z0 = stp.compute_last_operand(d2, tmp, y1) + tmp = sp.compute_last_operand(d1.move_operand_last(1), x0, x2) + z0 = sp.compute_last_operand(d2, tmp, y1) - z1 = stp.compute_last_operand(d3, x0, x2, y1) + z1 = sp.compute_last_operand(d3, x0, x2, y1) np.testing.assert_allclose(z0, z1) @@ -49,11 +49,11 @@ def make_examples(): cue.Irreps("SO3", "4x0 + 3x1"), cue.Irreps("SO3", "3x0 + 5x1"), irreps_middle, - ).d + ).polynomial.operations[0][1] assert dx.subscripts == "uvw,iu,jv,kw+ijk" dy = descriptors.channelwise_tensor_product( irreps_middle, cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0 + 1") - ).d + ).polynomial.operations[0][1] dy = dy.squeeze_modes("v") assert dy.subscripts == "u,iu,j,ku+ijk" dy = dy.add_or_rename_modes("w_kw_l_mw+klm") @@ -63,16 +63,16 @@ def make_examples(): def test_dot2(): dx, dy = make_examples() - dxy = stp.dot(dx, dy, (3, 1)) + dxy = sp.dot(dx, dy, (3, 1)) assert dxy.subscripts == 
"uvw,iu,jv,w,l,mw+ijklm" x0, x1, x2 = [np.random.randn(dx.operands[i].size) for i in range(3)] y0, y2 = [np.random.randn(dy.operands[i].size) for i in [0, 2]] - tmp = stp.compute_last_operand(dx, x0, x1, x2) - A = stp.compute_last_operand(dy, y0, tmp, y2) + tmp = sp.compute_last_operand(dx, x0, x1, x2) + A = sp.compute_last_operand(dy, y0, tmp, y2) - B = stp.compute_last_operand(dxy, x0, x1, x2, y0, y2) + B = sp.compute_last_operand(dxy, x0, x1, x2, y0, y2) np.testing.assert_allclose(A, B) @@ -80,13 +80,13 @@ def test_dot2(): def test_trace(): dx, dy = make_examples() - d1 = stp.dot(dx, dy, (3, 1)) + d1 = sp.dot(dx, dy, (3, 1)) d1 = d1.canonicalize_subscripts() d1 = d1.sort_paths() assert dy.subscripts == "w,kw,l,mw+klm" dy = dy.add_or_rename_modes("a_xa_y_za+xyz") - d2 = stp.trace(stp.dot(dx, dy), (3, 4 + 1)) + d2 = sp.trace(sp.dot(dx, dy), (3, 4 + 1)) d2 = d2.canonicalize_subscripts() d2 = d2.sort_paths() diff --git a/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py b/cuequivariance/tests/segmented_polynomials/evaluate_test.py similarity index 80% rename from cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py rename to cuequivariance/tests/segmented_polynomials/evaluate_test.py index d57b3390..7f1ced10 100644 --- a/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py +++ b/cuequivariance/tests/segmented_polynomials/evaluate_test.py @@ -14,40 +14,42 @@ # limitations under the License. import numpy as np -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue def test_compute_last_operand_1(): - d = stp.SegmentedTensorProduct.from_subscripts("uv_vw_uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv_vw_uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(2, 3) @ x1.reshape(3, 4)).reshape(-1) np.testing.assert_allclose(x2_, x2) def test_compute_last_operand_2(): - d = stp.SegmentedTensorProduct.from_subscripts("uv,vw,uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(10, d.operands[0].size) x1 = np.random.randn(10, d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(10, 2, 3) @ x1.reshape(10, 3, 4)).reshape(10, -1) np.testing.assert_allclose(x2_, x2) def test_compute_last_operand_3(): - d = stp.SegmentedTensorProduct.from_subscripts("uv,vw,uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(1, d.operands[0].size, 5) x1 = np.random.randn(10, d.operands[1].size, 1) - x2 = stp.compute_last_operand(d, x0, x1, segment_axes=[1, 1, 1]) + x2 = cue.segmented_polynomials.compute_last_operand( + d, x0, x1, segment_axes=[1, 1, 1] + ) x2_ = np.einsum( "ZuvA,ZvwA->ZuwA", x0.reshape(1, 2, 3, 5), x1.reshape(10, 3, 4, 1) @@ -56,13 +58,13 @@ def test_compute_last_operand_3(): def test_compute_last_operand_4(): - d = stp.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") c = np.random.randn(2, 3, 4) d.add_path(None, None, None, c=c, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(d.operands[0].size) x1 = 
np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = np.einsum( "ijk,iuv,jvw->kuw", c, x0.reshape(2, 2, 3), x1.reshape(3, 3, 4) @@ -71,7 +73,7 @@ def test_compute_last_operand_4(): def test_primitive_compute_last_operand(): - d = stp.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") c = np.random.randn(2, 3, 4) d.add_path(None, None, None, c=c, dims={"u": 2, "v": 3, "w": 4}) @@ -84,7 +86,7 @@ def test_primitive_compute_last_operand(): d = d.to_dict(True) - x2 = stp.primitive_compute_last_operand( + x2 = cue.segmented_polynomials.primitive_compute_last_operand( [ope["subscripts"] for ope in d["operands"]], d["coefficient_subscripts"], [ope["segments"] for ope in d["operands"]], diff --git a/cuequivariance/tests/operation_test.py b/cuequivariance/tests/segmented_polynomials/operation_test.py similarity index 100% rename from cuequivariance/tests/operation_test.py rename to cuequivariance/tests/segmented_polynomials/operation_test.py diff --git a/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py new file mode 100644 index 00000000..57ca9d93 --- /dev/null +++ b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py @@ -0,0 +1,424 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
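# A sketch of the renamed cost API used by the tests below (editorial, not
# part of the patch): flop_cost(oid) is now flops(operand, batch_size=1) and
# memory_cost(...) is now memory(batch_sizes), which simply sums
# batch_size * operand.size over all operands.
import cuequivariance as cue

d = cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw")
d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4})
print(d.flops(2, batch_size=100))  # flops needed to produce the last operand
print(d.memory([100, 100, 100]))   # 100 * (2*3 + 3*4 + 2*4) = 2600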
+import numpy as np + +import cuequivariance as cue + + +def make_simple_stp() -> cue.SegmentedTensorProduct: + d = cue.SegmentedTensorProduct.empty_segments([2, 2, 2]) + d.add_path(0, 0, 0, c=1.0) + d.add_path(1, 1, 1, c=-2.0) + return d + + +def make_simple_dot_product_stp() -> cue.SegmentedTensorProduct: + d = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = d.add_segment(0, (3,)) + i1 = d.add_segment(1, (3,)) + i2 = d.add_segment(2, (1,)) + d.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + return d + + +def test_init_segmented_polynomial(): + """Test initialization of SegmentedPolynomial.""" + stp = make_simple_stp() + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 + assert poly.num_operands == 3 + assert len(poly.operations) == 1 + assert poly.operations[0] == (cue.Operation((0, 1, 2)), stp) + + +def test_polynomial_equality(): + """Test equality comparison of polynomials.""" + stp1 = make_simple_stp() + stp2 = make_simple_stp() + + poly1 = cue.SegmentedPolynomial.eval_last_operand(stp1) + poly2 = cue.SegmentedPolynomial.eval_last_operand(stp2) + poly3 = cue.SegmentedPolynomial.eval_last_operand(2 * stp2) + + assert poly1 == poly2 + assert poly1 != poly3 + assert poly1 < poly3 # Test less than operator + + +def test_call_function(): + """Test calling the polynomial as a function.""" + # Create a simple bilinear form: f(a, b) = a^T * b + # For this specific test, we need a particular structure + stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = stp.add_segment(0, (3,)) + i1 = stp.add_segment(1, (3,)) + i2 = stp.add_segment(2, (1,)) + stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + # Test evaluation + a = np.array([1.0, 2.0, 3.0]) + b = np.array([4.0, 5.0, 6.0]) + + [result] = poly(a, b) + expected = np.array([a.dot(b)]) # Dot product + + assert np.allclose(result, expected) + + +def test_buffer_properties(): + """Test properties related to buffer sizes and usage.""" + stp1 = make_simple_stp() + op1 = cue.Operation((0, 1, 2)) + + # Create a second STP with different structure for testing multiple buffers + stp2 = cue.SegmentedTensorProduct.empty_segments([2, 1]) + stp2.add_path(0, 0, c=1.0) + op2 = cue.Operation((0, 3)) + + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(1), + ], + [(op1, stp1), (op2, stp2)], + ) + + # Test buffer properties + assert [ope.size for ope in poly.operands] == [2, 2, 2, 1] + + assert poly.used_buffers() == [True, True, True, True] + + +def test_remove_unused_buffers(): + """Test removing unused buffers from the polynomial.""" + stp = make_simple_stp() + # Use operation that doesn't use buffer 1 + op = cue.Operation((0, 2, 3)) # Note: buffer 1 is not used + + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), # unused + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) + + # Buffer 1 is not used + assert poly.used_buffers() == [True, False, True, True] + + # Remove unused buffer + cleaned_poly = poly.remove_unused_buffers() + + assert cleaned_poly.num_inputs == 2 + assert cleaned_poly.num_outputs == 1 + assert cleaned_poly.used_buffers() == [True, True, True] + + +def test_consolidate(): + """Test 
consolidating tensor products.""" + stp1 = make_simple_stp() + stp2 = make_simple_stp() + + op = cue.Operation((0, 1, 2)) + + # Create a polynomial with duplicate operations + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp1), (op, stp2)], + ) + + # Consolidate the polynomial + consolidated = poly.consolidate() + + # Should have fused the two tensor products + assert len(consolidated.operations) == 1 + # Coefficients should have been combined for each path + assert len(consolidated.operations[0][1].paths) == 2 + # The coefficients should have been added + assert consolidated.operations[0][1].paths[0].coefficients == 2.0 + assert consolidated.operations[0][1].paths[1].coefficients == -4.0 + + +def test_stack(): + """Test stacking polynomials.""" + # Create two simple polynomials using make_simple_stp + stp = make_simple_stp() + op1 = cue.Operation((0, 1, 2)) + poly1 = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op1, stp)], + ) + + stp2 = make_simple_stp() + op2 = cue.Operation((0, 1, 2)) + poly2 = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op2, stp2)], + ) + + # Stack the polynomials with the output being stacked + stacked = cue.SegmentedPolynomial.stack([poly1, poly2], [False, False, True]) + + assert stacked.num_inputs == 2 + assert stacked.num_outputs == 1 + + assert [ope.size for ope in stacked.operands] == [2, 2, 4] + + [(_, stp)] = stacked.operations + assert stp.operands[0].num_segments == 2 + assert stp.operands[1].num_segments == 2 + assert stp.operands[2].num_segments == 4 + assert stp.num_paths == 4 + assert stp.paths[0].indices == (0, 0, 0) + assert stp.paths[1].indices == (0, 0, 2 + 0) + assert stp.paths[2].indices == (1, 1, 1) + assert stp.paths[3].indices == (1, 1, 2 + 1) + + +def test_flops_and_memory(): + """Test computation of FLOPS and memory usage.""" + stp = make_simple_stp() + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) + # Test FLOPS calculation + flops = poly.flops(batch_size=100) + assert flops > 0 + + # Test memory calculation + memory = poly.memory([100, 100, 100]) + assert memory == 100 * (2 + 2 + 2) # All operands have size 2 + + +def test_jvp(): + """Test Jacobian-vector product computation.""" + # Create a simple polynomial for testing: f(x,y) = x^T * y (dot product) + stp = make_simple_dot_product_stp() + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # Tangent vectors (directions for differentiation) + x_tangent = np.array([0.1, 0.2, 0.3]) + y_tangent = np.array([0.4, 0.5, 0.6]) + + # Create the JVP polynomial for both inputs having tangents + jvp_poly = poly.jvp([True, True]) + + # When both inputs have tangents, we need to concatenate inputs and tangents + # The JVP polynomial expects inputs followed by their respective tangents + jvp_result = jvp_poly(x, y, x_tangent, y_tangent) + + # For the dot product function f(x,y) = x^T * y: + # The Jacobian w.r.t x is y^T, and the Jacobian w.r.t y is x^T + # So Jvp = y^T * 
x_tangent + x^T * y_tangent + expected_jvp = np.array([y.dot(x_tangent) + x.dot(y_tangent)]) + + assert np.allclose(jvp_result[0], expected_jvp) + + # Test with only x having a tangent + jvp_x_only = poly.jvp([True, False]) + x_only_result = jvp_x_only(x, y, x_tangent) + expected_x_only = np.array([y.dot(x_tangent)]) + assert np.allclose(x_only_result[0], expected_x_only) + + # Test with only y having a tangent + jvp_y_only = poly.jvp([False, True]) + y_only_result = jvp_y_only(x, y, y_tangent) + expected_y_only = np.array([x.dot(y_tangent)]) + assert np.allclose(y_only_result[0], expected_y_only) + + +def test_transpose_linear(): + """Test transposing a linear polynomial.""" + # Create a linear polynomial f(x, y) = Ax where A is a matrix + # Here we use f(x, y) = x^T * y (dot product) + # This is linear in both x and y + stp = make_simple_dot_product_stp() + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # x dot y = 1*4 + 2*5 + 3*6 = 32 + + # Cotangent for the output + cotangent = np.array([2.0]) + + # Test transposing with respect to x (x is undefined primal) + # is_undefined_primal = [True, False] means x is undefined, y is defined + # has_cotangent = [True] means the output has a cotangent + transpose_x = poly.transpose( + is_undefined_primal=[True, False], has_cotangent=[True] + ) + + # The transpose polynomial should compute the gradient of the output w.r.t x + # For f(x, y) = x^T * y, the gradient w.r.t x is y + # So transpose_x(y, cotangent) should be y * cotangent + x_result = transpose_x(y, cotangent) + expected_x_result = y * cotangent[0] + assert np.allclose(x_result[0], expected_x_result) + + # Test transposing with respect to y (y is undefined primal) + transpose_y = poly.transpose( + is_undefined_primal=[False, True], has_cotangent=[True] + ) + + # For f(x, y) = x^T * y, the gradient w.r.t y is x + # So transpose_y(x, cotangent) should be x * cotangent + y_result = transpose_y(x, cotangent) + expected_y_result = x * cotangent[0] + assert np.allclose(y_result[0], expected_y_result) + + +def test_transpose_nonlinear(): + """Test transposing a non-linear polynomial raises an error.""" + # Create a non-linear polynomial + stp = make_simple_stp() + op = cue.Operation((0, 0, 1)) # Note: using the same buffer twice (x^2) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) + + # Try to transpose the non-linear polynomial + # This should raise a ValueError since there are multiple undefined primals + # (the same input buffer is used twice) + with np.testing.assert_raises(ValueError): + poly.transpose(is_undefined_primal=[True], has_cotangent=[True]) + + +def test_backward(): + """Test the backward method for gradient computation.""" + # Create a linear polynomial for testing: f(x,y) = x^T * y (dot product) + stp = make_simple_dot_product_stp() + poly = cue.SegmentedPolynomial.eval_last_operand(stp) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # Cotangent for the output (upstream gradient) + cotangent = np.array([2.0]) + + # Test backward with respect to both x and y + backward_both = poly.backward(requires_gradient=[True, True], has_cotangent=[True]) + + # The backward polynomial computes gradients for all inputs that require gradients + # For f(x,y) = x^T * y: + # - gradient w.r.t x is y * cotangent + # - gradient w.r.t y is x * cotangent + grad_x, grad_y 
= backward_both(x, y, cotangent) + expected_grad_x = y * cotangent[0] + expected_grad_y = x * cotangent[0] + + assert np.allclose(grad_x, expected_grad_x) + assert np.allclose(grad_y, expected_grad_y) + + # Test backward with respect to only x + backward_x = poly.backward(requires_gradient=[True, False], has_cotangent=[True]) + + # Should only compute gradient for x + [grad_x_only] = backward_x(x, y, cotangent) + assert np.allclose(grad_x_only, expected_grad_x) + + # Test backward with respect to only y + backward_y = poly.backward(requires_gradient=[False, True], has_cotangent=[True]) + + # Should only compute gradient for y + [grad_y_only] = backward_y(x, y, cotangent) + assert np.allclose(grad_y_only, expected_grad_y) + + # Test with zero cotangent + zero_cotangent = np.array([0.0]) + grad_x_zero, grad_y_zero = backward_both(x, y, zero_cotangent) + + # With zero cotangent, gradients should be zero + assert np.allclose(grad_x_zero, np.zeros_like(x)) + assert np.allclose(grad_y_zero, np.zeros_like(y)) + + +def test_symmetrize_identical_operands(): + """Test symmetrization and unsymmetrization of polynomials with identical operands.""" + stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1]) + stp.add_path(0, 1, 0, c=1.0) # x0 * y1 path + + # Create operation that uses the same input buffer twice + op = cue.Operation((0, 0, 1)) # Use buffer 0 twice, write to buffer 1 + poly = cue.SegmentedPolynomial( + [cue.SegmentedOperand.empty_segments(2)], + [cue.SegmentedOperand.empty_segments(1)], + [(op, stp)], + ) + # Symmetrize the polynomial + sym_poly = poly.symmetrize_for_identical_operands() + + # Check that we get 0.5 x0*y1 + 0.5 x1*y0 + # This means we should have two paths with coefficient 0.5 + [(_, sym_stp)] = sym_poly.operations + assert len(sym_stp.paths) == 2 + # Check that we get 0.5 x0*y1 + 0.5 x1*y0 + assert sym_stp.paths[0].coefficients == 0.5 + assert sym_stp.paths[1].coefficients == 0.5 + # Check that the paths have different indices (operands swapped) + assert sym_stp.paths[0].indices == (0, 1, 0) + assert sym_stp.paths[1].indices == (1, 0, 0) + + # Test that unsymmetrize returns to original form + unsym_poly = sym_poly.unsymmetrize_for_identical_operands() + [(_, unsym_stp)] = unsym_poly.operations + assert len(unsym_stp.paths) == 1 + assert unsym_stp.paths[0].coefficients == 1.0 + assert unsym_stp.paths[0].indices == (0, 1, 0) + + # Test evaluation to verify the symmetrization works correctly + x = np.array([1.0, 2.0]) + [result] = poly(x) # Original polynomial + [sym_result] = sym_poly(x) # Symmetrized polynomial + assert np.allclose(result, sym_result) # Results should be identical diff --git a/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py b/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py new file mode 100644 index 00000000..749ea75d --- /dev/null +++ b/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py @@ -0,0 +1,472 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +import cuequivariance as cue + + +def make_coeffs(shape): + n = np.prod(shape) + c = np.arange(n) + 1.0 + return c.reshape(shape) + + +def test_user_friendly(): + d = cue.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk") + assert ( + str(d) + == "ia,jb,kab+ijk sizes=0,0,0 num_segments=0,0,0 num_paths=0 a= b= i= j= k=" + ) + + with pytest.raises(ValueError): + d.add_path(0, 0, 0, c=make_coeffs((2, 2, 3))) # need to add segments first + + with pytest.raises(ValueError): + d.add_segment(0, (2, 2, 2)) # wrong number of dimensions + + d.add_segment(0, (5, 16)) + d.add_segment(1, (5, 32)) + d.add_segment(2, (4, 16, 32)) + d.add_segment(2, (4, 32, 32)) + + with pytest.raises(ValueError): + d.add_path(0, 0, 0, c=make_coeffs((5, 5, 5))) # wrong dimension for k + + assert ( + str(d) + == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=0 a={16, 32} b=32 i=5 j=5 k=4" + ) + + d.add_path(0, 0, 0, c=make_coeffs((5, 5, 4))) + assert ( + str(d) + == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=1 a={16, 32} b=32 i=5 j=5 k=4" + ) + + assert not d.all_segments_are_used() + assert d.subscripts == "ia,jb,kab+ijk" + assert d.subscripts.is_equivalent("ia,jb,kab+ijk") + + d.assert_valid() + + +def test_squeeze(): + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") + d.add_segment(0, (1,)) + d.add_segment(1, (20,)) + d.add_path(0, 0, c=make_coeffs((1, 20))) + d.assert_valid() + + assert d.squeeze_modes().subscripts == ",j+j" + + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") + d.add_segment(0, (1,)) + d.add_segment(0, (2,)) + d.add_segment(1, (20,)) + d.add_path(0, 0, c=make_coeffs((1, 20))) + d.add_path(1, 0, c=make_coeffs((2, 20))) + d.assert_valid() + + assert d.squeeze_modes().subscripts == "i,j+ij" + + with pytest.raises(ValueError): + d.squeeze_modes("i") + + +def test_normalize_paths_for_operand(): + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") + + d.add_segments(0, 2 * [(2,)]) + d.add_segments(1, 2 * [(3,)]) + d.assert_valid() + + d.add_path(0, 0, c=np.array([[2, 0, 0], [0, 2, 0]])) + d.assert_valid() + + d = d.normalize_paths_for_operand(0) + d.assert_valid() + + np.testing.assert_allclose( + d.paths[0].coefficients, + np.array( + [ + [1, 0, 0], + [0, 1, 0], + ] + ), + ) + + +def make_example_descriptor(): + d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_jv+ij") + d.add_path( + None, + None, + None, + c=np.random.randn(2, 3), + dims={"u": 4, "v": 5, "i": 2, "j": 3}, + ) + d.assert_valid() + d.add_path( + None, + None, + None, + c=np.random.randn(4, 5), + dims={"u": 2, "v": 3, "i": 4, "j": 5}, + ) + d.assert_valid() + return d + + +def test_flatten(): + d = make_example_descriptor() + d.assert_valid() + + assert d.flatten_modes("").subscripts == "uv,iu,jv+ij" + assert d.flatten_modes("i").subscripts == "uv,u,jv+j" + assert d.flatten_modes("j").subscripts == "uv,iu,v+i" + assert d.flatten_modes("ij").subscripts == "uv,u,v" + assert d.flatten_modes("ui").subscripts == "v,,jv+j" + + x0 = np.random.randn(d.operands[0].size) + x1 = np.random.randn(d.operands[1].size) + x2 = 
cue.segmented_polynomials.compute_last_operand(d, x0, x1) + + for channels in ["i", "j", "ij", "ui", "iju", "uvij"]: + np.testing.assert_allclose( + x2, + cue.segmented_polynomials.compute_last_operand( + d.flatten_modes(channels), x0, x1 + ), + ) + + +def test_flatten_coefficients(): + d = make_example_descriptor() + + assert d.subscripts == "uv,iu,jv+ij" + assert d.flatten_coefficient_modes().subscripts == "uv,u,v" + + d = d.add_or_transpose_modes("uv,ui,jv+ij") + assert d.subscripts == "uv,ui,jv+ij" + + with pytest.raises(ValueError): + d.flatten_coefficient_modes() + + assert d.flatten_coefficient_modes(force=True).subscripts == "v,,v" + + +def test_consolidate(): + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab") + d.add_segment(0, (2, 3)) + d.add_segment(1, (2, 3)) + d.add_path(0, 0, c=1.0) + d.assert_valid() + + assert d.consolidate_modes().subscripts == "a,a" + + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab_a") + d.add_segment(0, (2, 3)) + d.add_segment(1, (2, 3)) + d.add_segment(2, (2,)) + d.add_path(0, 0, 0, c=1.0) + d.assert_valid() + + assert d.consolidate_modes() == d + + d = cue.SegmentedTensorProduct.from_subscripts("ab_iab+abi") + d.add_segment(0, (2, 3)) + d.add_segment(1, (4, 2, 3)) + d.add_path(0, 0, c=make_coeffs((2, 3, 4))) + d.assert_valid() + + assert d.consolidate_modes().subscripts == "a,ia+ai" + + d = cue.SegmentedTensorProduct.from_subscripts("ab,iab+abi") + d.add_segment(0, (2, 3)) + d.add_segment(1, (4, 2, 3)) + d.add_path(0, 0, c=make_coeffs((2, 3, 4))) + + assert d.consolidate_modes().subscripts == "a,ia+ai" + + +def test_stacked_coefficients(): + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab+ab") + d.add_segment(0, (2, 3)) + d.add_segment(1, (2, 3)) + np.testing.assert_allclose(d.stacked_coefficients, make_coeffs((0, 2, 3))) + + d.add_path(0, 0, c=make_coeffs((2, 3))) + d.add_path(0, 0, c=make_coeffs((2, 3))) + expected = np.stack([make_coeffs((2, 3)), make_coeffs((2, 3))], axis=0) + np.testing.assert_allclose(d.stacked_coefficients, expected) + + d = d.consolidate_paths() + np.testing.assert_allclose(d.stacked_coefficients, 2 * make_coeffs((1, 2, 3))) + + +@pytest.mark.parametrize("extended", [False, True]) +def test_data_transfer(extended: bool): + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) + d.assert_valid() + + dict = d.to_dict(extended) + json = d.to_json(extended) + bin = d.to_bytes(extended) + b64 = d.to_base64(extended) + + assert d == cue.SegmentedTensorProduct.from_dict(dict) + assert d == cue.SegmentedTensorProduct.from_json(json) + assert d == cue.SegmentedTensorProduct.from_bytes(bin) + assert d == cue.SegmentedTensorProduct.from_base64(b64) + + +def test_to_text(): + d = cue.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) + d = d.flatten_modes("ijk") + + text = d.to_text() + assert ( + text + == """u,u,u sizes=50,64,50 num_segments=4,5,4 num_paths=29 u={12, 14} +operand #0 subscripts=u + | u: [12, 12, 12, 14] +operand #1 subscripts=u + | u: [12, 12, 12, 14, 14] +operand #2 subscripts=u + | u: [12, 12, 12, 14] +Flop cost: 0->704 1->704 2->704 +Memory cost: 164 +Path indices: 0 0 0, 0 0 1, 0 0 2, 0 1 0, 0 1 1, 0 1 2, 0 2 0, 0 2 1, 0 2 2, 1 0 0, 1 0 1, 1 0 2, 1 1 0, 1 1 1, 1 1 2, 1 2 0, 1 2 1, 1 2 2, 2 0 0, 2 0 
1, 2 0 2, 2 1 0, 2 1 1, 2 1 2, 2 2 0, 2 2 1, 2 2 2, 3 3 3, 3 4 3 +Path coefficients: [1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0 21.0 22.0 23.0 24.0 25.0 26.0 27.0 1.0 2.0]""" + ) + + +def test_hash(): + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) + assert hash(d) == hash(d) + + d2 = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + assert hash(d) != hash(d2) + d2.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + assert hash(d) != hash(d2) + d2.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) + assert hash(d) == hash(d2) + + +def test_split_mode(): + # Create a descriptor with a mode that has dimensions divisible by the desired split size + d = cue.SegmentedTensorProduct.from_subscripts("ua,ub+ab") + + # Add a segment with u dimension = 6 (divisible by 2 and 3) + d.add_segment(0, (6, 4)) + d.add_segment(1, (6, 5)) + + # Add a path + d.add_path(0, 0, c=make_coeffs((4, 5))) + d.assert_valid() + + # Split mode 'u' with size 2 + d_split = d.split_mode("u", 2) + d_split.assert_valid() + + # Check that the dimensions are correctly split + assert d_split.operands[0].num_segments == 3 # 6/2 = 3 segments + assert d_split.operands[1].num_segments == 3 # 6/2 = 3 segments + + # Check that the subscripts are preserved + assert d_split.subscripts == "ua,ub+ab" + + # Check that the segments have the correct shape + for segment in d_split.operands[0]: + assert segment[0] == 2 # First dimension should be 2 + assert segment[1] == 4 # Second dimension should be 4 + + for segment in d_split.operands[1]: + assert segment[0] == 2 # First dimension should be 2 + assert segment[1] == 5 # Second dimension should be 5 + + # Test with a different split size + d_split_3 = d.split_mode("u", 3) + d_split_3.assert_valid() + + assert d_split_3.operands[0].num_segments == 2 # 6/3 = 2 segments + assert d_split_3.operands[1].num_segments == 2 # 6/3 = 2 segments + + # Test error case: split size not divisible by dimension + with pytest.raises(ValueError): + d.split_mode("u", 5) # 6 is not divisible by 5 + + # Test case where mode is not in descriptor + d_unchanged = d.split_mode("v", 2) # 'v' is not in the descriptor + assert d_unchanged == d + + # Test case where mode is not at the beginning of the operand + d_complex = cue.SegmentedTensorProduct.from_subscripts("au,bu+ab") + d_complex.add_segment(0, (3, 6)) + d_complex.add_segment(1, (4, 6)) + d_complex.add_path(0, 0, c=make_coeffs((3, 4))) + + with pytest.raises(ValueError): + d_complex.split_mode("u", 2) # 'u' is not the first mode in operands + + # Test with coefficient subscripts + d_coeff = cue.SegmentedTensorProduct.from_subscripts("ua,ub,ab+ab") + d_coeff.add_segment(0, (6, 4)) + d_coeff.add_segment(1, (6, 5)) + d_coeff.add_segment(2, (4, 5)) + d_coeff.add_path(0, 0, 0, c=make_coeffs((4, 5))) + + d_coeff_split = d_coeff.split_mode("u", 2) + d_coeff_split.assert_valid() + + assert d_coeff_split.operands[0].num_segments == 3 + assert d_coeff_split.operands[1].num_segments == 3 + assert d_coeff_split.operands[2].num_segments == 1 # Not affected by u split + + # Check that computation results are equivalent + # Create a simple descriptor with just two operands for testing compute_last_operand + d_compute = cue.SegmentedTensorProduct.from_subscripts("a,b+ab") + d_compute.add_segment(0, (4,)) + 
d_compute.add_segment(1, (5,)) + d_compute.add_path(0, 0, c=make_coeffs((4, 5))) + + # Test computation on original descriptor + x_input = np.random.randn(d_compute.operands[0].size) + result_original = cue.segmented_polynomials.compute_last_operand(d_compute, x_input) + + # Verify split_mode works by first flattening the results to remove 'u' mode indices + d_ua = cue.SegmentedTensorProduct.from_subscripts("ua,b+ab") + d_ua.add_segment(0, (6, 4)) + d_ua.add_segment(1, (5,)) + d_ua.add_path(0, 0, c=make_coeffs((4, 5))) + d_ua_split = d_ua.split_mode("u", 2) + + # Input for the split descriptor - we need a tensor with the right shape + x_input_split = np.random.randn(d_ua_split.operands[0].size) + result_split = cue.segmented_polynomials.compute_last_operand( + d_ua_split, x_input_split + ) + + # Verify the shapes are consistent with our expectations + assert result_original.shape == (5,) + assert result_split.shape == (5,) + + +def test_add_or_transpose_modes(): + # Test 1: Simple mode transposition + d = cue.SegmentedTensorProduct.from_subscripts("ia,ja+ij") + d.add_segment(0, (3, 4)) + d.add_segment(1, (5, 4)) + d.add_path(0, 0, c=make_coeffs((3, 5))) + d.assert_valid() + + # Transpose modes in first operand + d_trans = d.add_or_transpose_modes("ai,ja+ij") + d_trans.assert_valid() + assert d_trans.subscripts == "ai,ja+ij" + assert d_trans.operands[0][0] == (4, 3) and d_trans.operands[1][0] == (5, 4) + np.testing.assert_allclose(d_trans.paths[0].coefficients, d.paths[0].coefficients) + + # Test 2: Adding new modes + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=make_coeffs((3, 4))) + d.assert_valid() + + # Add new modes with specified dimensions + d_new = d.add_or_transpose_modes("ia,ja+ij", dims={"a": 5}) + d_new.assert_valid() + assert d_new.subscripts == "ia,ja+ij" + assert d_new.operands[0][0] == (3, 5) and d_new.operands[1][0] == (4, 5) + + # Test 3 & 4: Error cases + with pytest.raises(ValueError): + d.add_or_transpose_modes("ia,ja+ij") # Missing dims for a + with pytest.raises(ValueError): + d.add_or_transpose_modes("i,j+i") # Removing j from coefficients + + # Test 5: Transposing coefficient modes + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=make_coeffs((3, 4))) + d.assert_valid() + + # Transpose coefficient modes + d_coeff = d.add_or_transpose_modes("i,j+ji") + d_coeff.assert_valid() + assert d_coeff.subscripts == "i,j+ji" + np.testing.assert_allclose(d_coeff.paths[0].coefficients, d.paths[0].coefficients.T) + + # Test 6: Adding batch dimensions + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj+ij") + d.add_segment(0, (5, 3)) + d.add_segment(1, (5, 4)) + d.add_path(0, 0, c=make_coeffs((3, 4))) + d.assert_valid() + + # Add batch dimension to both operands + d_batch = d.add_or_transpose_modes("bui,buj+ij", dims={"b": 8}) + d_batch.assert_valid() + assert d_batch.subscripts == "bui,buj+ij" + assert d_batch.operands[0][0] == (8, 5, 3) and d_batch.operands[1][0] == (8, 5, 4) + + +def test_add_or_rename_modes(): + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=make_coeffs((3, 4))) + d.assert_valid() + d_id = d.add_or_rename_modes("i,j+ij") + d_id.assert_valid() + assert d_id.subscripts == "i,j+ij" + x_input = np.random.randn(d.operands[0].size) + res0 = cue.segmented_polynomials.compute_last_operand(d, x_input) + res_id = 
cue.segmented_polynomials.compute_last_operand(d_id, x_input) + np.testing.assert_allclose(res0, res_id) + d_ren = d.add_or_rename_modes("a,b+ab") + d_ren.assert_valid() + assert d_ren.subscripts == "a,b+ab" + np.testing.assert_allclose( + res0, cue.segmented_polynomials.compute_last_operand(d_ren, x_input) + ) + with pytest.raises(ValueError): + d.add_or_rename_modes("i+ij") + d_sup = d.add_or_rename_modes("bi,bj+ij", mapping={"i": "i", "j": "j"}) + d_sup.assert_valid() + assert d_sup.subscripts == "bi,bj+ij" + np.testing.assert_allclose( + res0, cue.segmented_polynomials.compute_last_operand(d_sup, x_input) + ) + + +def test_consolidate_with_optional_argument(): + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab") + d.add_segment(0, (2, 3)) + d.add_segment(1, (2, 3)) + d.add_path(0, 0, c=1.0) + d.assert_valid() + d_consol = d.consolidate_modes("ab") + assert d_consol.subscripts == "a,a" diff --git a/cuequivariance/tests/segmented_tensor_product/subscripts_test.py b/cuequivariance/tests/segmented_polynomials/subscripts_test.py similarity index 74% rename from cuequivariance/tests/segmented_tensor_product/subscripts_test.py rename to cuequivariance/tests/segmented_polynomials/subscripts_test.py index 80598381..572455f6 100644 --- a/cuequivariance/tests/segmented_tensor_product/subscripts_test.py +++ b/cuequivariance/tests/segmented_polynomials/subscripts_test.py @@ -14,23 +14,23 @@ # limitations under the License. import pytest -import cuequivariance.segmented_tensor_product as stp +import cuequivariance.segmented_polynomials as sp def test_subscripts(): with pytest.raises(ValueError): - stp.Subscripts("#$%@") + sp.Subscripts("#$%@") with pytest.raises(ValueError): - stp.Subscripts("Zu") # uppercase not supported anymore + sp.Subscripts("Zu") # uppercase not supported anymore with pytest.raises(ValueError): - stp.Subscripts("uZ") # uppercase after lowercase + sp.Subscripts("uZ") # uppercase after lowercase with pytest.raises(ValueError): - stp.Subscripts("uZ+ij+kl") # multiple + signs + sp.Subscripts("uZ+ij+kl") # multiple + signs - subscripts = stp.Subscripts("ui,vj,uvk+ijk") + subscripts = sp.Subscripts("ui,vj,uvk+ijk") assert subscripts.canonicalize() == "ui,vj,uvk+ijk" assert subscripts.coefficients == "ijk" @@ -45,5 +45,5 @@ def test_subscripts(): def test_canonicalize(): - assert stp.Subscripts("ui").canonicalize() == "uv" - assert stp.Subscripts("ab,ad+bd").canonicalize() == "ui,uj+ij" + assert sp.Subscripts("ui").canonicalize() == "uv" + assert sp.Subscripts("ab,ad+bd").canonicalize() == "ui,uj+ij" diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py deleted file mode 100644 index bfb6773e..00000000 --- a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py +++ /dev/null @@ -1,263 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import numpy as np -import pytest - -import cuequivariance.segmented_tensor_product as stp - - -def test_user_friendly(): - d = stp.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk") - assert ( - str(d) - == "ia,jb,kab+ijk sizes=0,0,0 num_segments=0,0,0 num_paths=0 a= b= i= j= k=" - ) - - with pytest.raises(ValueError): - d.add_path(0, 0, 0, c=np.ones((2, 2, 3))) # need to add segments first - - with pytest.raises(ValueError): - d.add_segment(0, (2, 2, 2)) # wrong number of dimensions - - d.add_segment(0, (5, 16)) - d.add_segment(1, (5, 32)) - d.add_segment(2, (4, 16, 32)) - d.add_segment(2, (4, 32, 32)) - - with pytest.raises(ValueError): - d.add_path(0, 0, 0, c=np.ones((5, 5, 5))) # wrong dimension for k - - assert ( - str(d) - == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=0 a={16, 32} b=32 i=5 j=5 k=4" - ) - - d.add_path(0, 0, 0, c=np.ones((5, 5, 4))) - assert ( - str(d) - == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=1 a={16, 32} b=32 i=5 j=5 k=4" - ) - - assert not d.all_segments_are_used() - assert d.subscripts == "ia,jb,kab+ijk" - assert d.subscripts.is_equivalent("ia,jb,kab+ijk") - - d.assert_valid() - - -def test_squeeze(): - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") - d.add_segment(0, (1,)) - d.add_segment(1, (20,)) - d.add_path(0, 0, c=np.ones((1, 20))) - d.assert_valid() - - assert d.squeeze_modes().subscripts == ",j+j" - - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") - d.add_segment(0, (1,)) - d.add_segment(0, (2,)) - d.add_segment(1, (20,)) - d.add_path(0, 0, c=np.ones((1, 20))) - d.add_path(1, 0, c=np.ones((2, 20))) - d.assert_valid() - - assert d.squeeze_modes().subscripts == "i,j+ij" - - with pytest.raises(ValueError): - d.squeeze_modes("i") - - -def test_normalize_paths_for_operand(): - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") - - d.add_segments(0, 2 * [(2,)]) - d.add_segments(1, 2 * [(3,)]) - d.assert_valid() - - d.add_path(0, 0, c=np.array([[2, 0, 0], [0, 2, 0]])) - d.assert_valid() - - d = d.normalize_paths_for_operand(0) - d.assert_valid() - - np.testing.assert_allclose( - d.paths[0].coefficients, - np.array( - [ - [1, 0, 0], - [0, 1, 0], - ] - ), - ) - - -def make_example_descriptor(): - d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_jv+ij") - d.add_path( - None, - None, - None, - c=np.random.randn(2, 3), - dims={"u": 4, "v": 5, "i": 2, "j": 3}, - ) - d.assert_valid() - d.add_path( - None, - None, - None, - c=np.random.randn(4, 5), - dims={"u": 2, "v": 3, "i": 4, "j": 5}, - ) - d.assert_valid() - return d - - -def test_flatten(): - d = make_example_descriptor() - d.assert_valid() - - assert d.flatten_modes("").subscripts == "uv,iu,jv+ij" - assert d.flatten_modes("i").subscripts == "uv,u,jv+j" - assert d.flatten_modes("j").subscripts == "uv,iu,v+i" - assert d.flatten_modes("ij").subscripts == "uv,u,v" - assert d.flatten_modes("ui").subscripts == "v,,jv+j" - - x0 = np.random.randn(d.operands[0].size) - x1 = np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) - - for channels in ["i", "j", "ij", "ui", "iju", "uvij"]: - np.testing.assert_allclose( - x2, stp.compute_last_operand(d.flatten_modes(channels), x0, x1) - ) - - -def test_flatten_coefficients(): - d = make_example_descriptor() - - assert d.subscripts == "uv,iu,jv+ij" - assert d.flatten_coefficient_modes().subscripts == "uv,u,v" - - d = d.add_or_transpose_modes("uv,ui,jv+ij") - assert d.subscripts == "uv,ui,jv+ij" - - with pytest.raises(ValueError): - d.flatten_coefficient_modes() - - assert 
d.flatten_coefficient_modes(force=True).subscripts == "v,,v" - - -def test_consolidate(): - d = stp.SegmentedTensorProduct.from_subscripts("ab_ab") - d.add_segment(0, (2, 3)) - d.add_segment(1, (2, 3)) - d.add_path(0, 0, c=1.0) - d.assert_valid() - - assert d.consolidate_modes().subscripts == "a,a" - - d = stp.SegmentedTensorProduct.from_subscripts("ab_ab_a") - d.add_segment(0, (2, 3)) - d.add_segment(1, (2, 3)) - d.add_segment(2, (2,)) - d.add_path(0, 0, 0, c=1.0) - d.assert_valid() - - assert d.consolidate_modes() == d - - d = stp.SegmentedTensorProduct.from_subscripts("ab_iab+abi") - d.add_segment(0, (2, 3)) - d.add_segment(1, (4, 2, 3)) - d.add_path(0, 0, c=np.ones((2, 3, 4))) - d.assert_valid() - - assert d.consolidate_modes().subscripts == "a,ia+ai" - - d = stp.SegmentedTensorProduct.from_subscripts("ab,iab+abi") - d.add_segment(0, (2, 3)) - d.add_segment(1, (4, 2, 3)) - d.add_path(0, 0, c=np.ones((2, 3, 4))) - - assert d.consolidate_modes().subscripts == "a,ia+ai" - - -def test_stacked_coefficients(): - d = stp.SegmentedTensorProduct.from_subscripts("ab_ab+ab") - d.add_segment(0, (2, 3)) - d.add_segment(1, (2, 3)) - np.testing.assert_allclose(d.stacked_coefficients, np.ones((0, 2, 3))) - - d.add_path(0, 0, c=np.ones((2, 3))) - d.add_path(0, 0, c=np.ones((2, 3))) - np.testing.assert_allclose(d.stacked_coefficients, np.ones((2, 2, 3))) - - d = d.consolidate_paths() - np.testing.assert_allclose(d.stacked_coefficients, 2 * np.ones((1, 2, 3))) - - -@pytest.mark.parametrize("extended", [False, True]) -def test_data_transfer(extended: bool): - d = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) - d.assert_valid() - - dict = d.to_dict(extended) - json = d.to_json(extended) - bin = d.to_bytes(extended) - b64 = d.to_base64(extended) - - assert d == stp.SegmentedTensorProduct.from_dict(dict) - assert d == stp.SegmentedTensorProduct.from_json(json) - assert d == stp.SegmentedTensorProduct.from_bytes(bin) - assert d == stp.SegmentedTensorProduct.from_base64(b64) - - -def test_to_text(): - d = stp.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) - d = d.flatten_modes("ijk") - - text = d.to_text() - assert ( - text - == """u,u,u sizes=50,64,50 num_segments=4,5,4 num_paths=29 u={12, 14} -operand #0 subscripts=u - | u: [12, 12, 12, 14] -operand #1 subscripts=u - | u: [12, 12, 12, 14, 14] -operand #2 subscripts=u - | u: [12, 12, 12, 14] -Flop cost: 0->704 1->704 2->704 -Memory cost: 164 -Path indices: 0 0 0, 0 0 1, 0 0 2, 0 1 0, 0 1 1, 0 1 2, 0 2 0, 0 2 1, 0 2 2, 1 0 0, 1 0 1, 1 0 2, 1 1 0, 1 1 1, 1 1 2, 1 2 0, 1 2 1, 1 2 2, 2 0 0, 2 0 1, 2 0 2, 2 1 0, 2 1 1, 2 1 2, 2 2 0, 2 2 1, 2 2 2, 3 3 3, 3 4 3 -Path coefficients: [1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0]""" - ) - - -def test_hash(): - d = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) - assert hash(d) == hash(d) - - d2 = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") - assert hash(d) != hash(d2) - d2.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - assert hash(d) != hash(d2) - d2.add_path(None, None, None, 
c=np.ones((1, 2, 1)), dims={"u": 14}) - assert hash(d) == hash(d2) diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 045e5a75..8e577cd0 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -19,34 +19,33 @@ ) -from .rep_array.jax_rep_array import RepArray, from_segments, IrrepsArray +from .rep_array.rep_array_ import RepArray, from_segments from .rep_array.vmap import vmap -from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan +from .rep_array.rep_array_utils import concatenate, randn, as_irreps_array, clebsch_gordan -from .primitives.tensor_product import tensor_product -from .primitives.equivariant_tensor_product import equivariant_tensor_product +from .segmented_polynomials.segmented_polynomial import segmented_polynomial +from .equivariant_polynomial import equivariant_polynomial -from .operations.activation import ( +from .activation import ( normalspace, normalize_function, function_parity, scalar_activation, ) -from .operations.spherical_harmonics import spherical_harmonics, normalize, norm +from .spherical_harmonics import spherical_harmonics, normalize, norm from cuequivariance_jax import flax_linen __all__ = [ "RepArray", "from_segments", - "IrrepsArray", "vmap", "concatenate", "randn", "as_irreps_array", "clebsch_gordan", - "tensor_product", - "equivariant_tensor_product", + "segmented_polynomial", + "equivariant_polynomial", "normalspace", "normalize_function", "function_parity", diff --git a/cuequivariance_jax/cuequivariance_jax/operations/activation.py b/cuequivariance_jax/cuequivariance_jax/activation.py similarity index 97% rename from cuequivariance_jax/cuequivariance_jax/operations/activation.py rename to cuequivariance_jax/cuequivariance_jax/activation.py index a171ce97..d919a1bb 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/activation.py +++ b/cuequivariance_jax/cuequivariance_jax/activation.py @@ -74,6 +74,10 @@ def rho(x): def function_parity(phi: ActFn) -> int: + r"""Determine the parity of a function. + + For example, sin is odd, cos is even, and exp is neither. + """ with jax.ensure_compile_time_eval(): x = jnp.linspace(0.0, 10.0, 256) diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py new file mode 100644 index 00000000..26f72130 --- /dev/null +++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py @@ -0,0 +1,168 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
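The new module below supersedes primitives/equivariant_tensor_product.py, which is deleted later in this diff. As a hedged migration sketch, assembled from the doctests of the old and new docstrings shown in this patch, the main call-site change is that inputs are now passed as a list rather than as variadic arguments:

    import jax.numpy as jnp
    import cuequivariance as cue
    import cuequivariance_jax as cuex

    e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
    with cue.assume(cue.SO3, cue.ir_mul):
        x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))

    # Old API, removed in this diff:
    #   y = cuex.equivariant_tensor_product(e, x)
    # New API:
    y = cuex.equivariant_polynomial(e, [x])  # RepArray with irreps 0+1+2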
+ +import jax +import jax.numpy as jnp + +import cuequivariance as cue +import cuequivariance_jax as cuex + + +def equivariant_polynomial( + poly: cue.EquivariantPolynomial, + inputs: list[cuex.RepArray | jax.Array], + outputs_shape_dtype: list[jax.ShapeDtypeStruct] + | jax.ShapeDtypeStruct + | None = None, + indices: list[jax.Array | None] | None = None, + math_dtype: jnp.dtype | None = None, + name: str | None = None, + impl: str = "auto", +) -> list[cuex.RepArray] | cuex.RepArray: + """Compute an equivariant polynomial. + + Args: + poly: The equivariant polynomial descriptor. + inputs: List of input :class:`cuex.RepArray `. + outputs_shape_dtype: Shape and dtype specifications for outputs. If None, + inferred from inputs when possible. When output indices are provided, this must be specified. + indices: Optional list of indices for inputs and outputs. Length must match + total number of operands (inputs + outputs). Use None for unindexed + operands. Defaults to None. + math_dtype: Data type for computational operations. If None, automatically + determined from input types. Defaults to None. + name: Optional name for the operation. Defaults to None. + impl: Implementation to use, one of ["auto", "cuda", "jax"]. If "auto", + uses CUDA when available, falling back to JAX otherwise. Defaults to "auto". + + Returns: + Single :class:`cuex.RepArray ` if one output, or list of :class:`cuex.RepArray ` for multiple outputs. + + Examples: + Create and compute spherical harmonics of degree 0, 1, and 2: + + >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) + >>> e + ╭ a=1 -> B=0+1+2 + │ []➜B[] ───────── num_paths=1 + │ []·a[]➜B[] ───── num_paths=3 + ╰─ []·a[]·a[]➜B[] ─ num_paths=11 + + Basic usage with single input: + + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) + >>> cuex.equivariant_polynomial(e, [x]) + {0: 0+1+2} + [1. ... ] + + Using indices: + + >>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32) + >>> with cue.assume(cue.SO3, cue.ir_mul): + ... x = cuex.RepArray("1", jnp.array([ + ... [0.0, 1.0, 0.0], + ... [0.0, 0.0, 1.0], + ... [1.0, 0.0, 0.0], + ... ])) + >>> result = cuex.equivariant_polynomial( + ... e, + ... [x], + ... [jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32)], + ... indices=[None, i_out], + ... ) + >>> result + {1: 0+1+2} + [[ 1. ... ] + [ 2. ... ]] + """ + if name is None: + name = "equivariant_polynomial" + + if len(inputs) != poly.num_inputs: + raise ValueError( + f"Unexpected number of inputs. Expected {poly.num_inputs}, got {len(inputs)}." + ) + + for i, (x, rep) in enumerate(zip(inputs, poly.inputs)): + if isinstance(x, cuex.RepArray): + assert x.rep(-1) == rep, ( + f"Input {i} should have representation {rep}, got {x.rep(-1)}." + ) + else: + assert x.ndim >= 1, ( + f"Input {i} should have at least one dimension, got {x.ndim}." + ) + assert x.shape[-1] == rep.dim, ( + f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." + ) + if not rep.is_scalar(): + raise ValueError( + f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}." + ) + + inputs: list[jax.Array] = [getattr(x, "array", x) for x in inputs] + + if indices is None: + indices = [None] * poly.num_operands + + if len(indices) != poly.num_operands: + raise ValueError( + f"Unexpected number of indices. indices should be None or a list of length {poly.num_operands}, got a list of length {len(indices)}."
+ ) + + if outputs_shape_dtype is None: + if not all(i is None for i in indices[poly.num_inputs :]): + raise ValueError( + "When output indices are provided, outputs_shape_dtype must be provided." + ) + if poly.num_inputs == 0: + raise ValueError( + "When no inputs are provided, outputs_shape_dtype must be provided." + ) + inferred_shape = jnp.broadcast_shapes( + *[ + x.shape[:-1] if i is None else i.shape + x.shape[1:-1] + for i, x in zip(indices, inputs) + ] + ) + inferred_dtype = jnp.result_type(*inputs) + outputs_shape_dtype = [ + jax.ShapeDtypeStruct(inferred_shape + (rep.dim,), inferred_dtype) + for rep in poly.outputs + ] + + if hasattr(outputs_shape_dtype, "shape"): + outputs_shape_dtype = [outputs_shape_dtype] + + if len(outputs_shape_dtype) != poly.num_outputs: + raise ValueError( + f"Unexpected number of outputs. Expected {poly.num_outputs}, got {len(outputs_shape_dtype)}." + ) + + outputs = cuex.segmented_polynomial( + poly.polynomial, + inputs, + outputs_shape_dtype, + indices, + math_dtype=math_dtype, + name=name, + impl=impl, + ) + outputs = [cuex.RepArray(rep, x) for rep, x in zip(poly.outputs, outputs)] + + if poly.num_outputs == 1: + return outputs[0] + return outputs diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py deleted file mode 100644 index e40a31d4..00000000 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ /dev/null @@ -1,164 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import jax -import jax.numpy as jnp - -import cuequivariance as cue -import cuequivariance_jax as cuex - - -def equivariant_tensor_product( - e: cue.EquivariantTensorProduct, - *inputs: cuex.RepArray | jax.Array, - indices: list[jax.Array | None] | None = None, - output_batch_shape: tuple[int, ...] | None = None, - output_dtype: jnp.dtype | None = None, - math_dtype: jnp.dtype | None = None, - name: str | None = None, - impl: str = "auto", -) -> cuex.RepArray: - """Compute the equivariant tensor product of the input arrays. - - Args: - e (:class:`cue.EquivariantTensorProduct `): The equivariant tensor product descriptor. - *inputs (RepArray or jax.Array): The input arrays. - indices (list of jax.Array or None, optional): The optional indices of the inputs and output. - output_batch_shape (tuple of int, optional): The batch shape of the output array. - output_dtype (jnp.dtype, optional): The data type for the output array. Defaults to None. - math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. - name (str, optional): The name of the operation. Defaults to None. - - Returns: - RepArray: The result of the equivariant tensor product. - - Examples: - - Let's create a descriptor for the spherical harmonics of degree 0, 1, and 2. 
- - >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) - >>> e - EquivariantTensorProduct((1)^(0..2) -> 0+1+2) - - We need some input data. - - >>> with cue.assume(cue.SO3, cue.ir_mul): - ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) - >>> x - {0: 1} [0. 1. 0.] - - Now we can execute the equivariant tensor product. - - >>> cuex.equivariant_tensor_product(e, x) - {0: 0+1+2} - [1. ... ] - - The `indices` argument allows to specify a list of optional int32 arrays for each input and for the output (`None` means no index and `indices[-1]` is the output index). The indices are used to select the elements of the input arrays and to specify the output index. - In the following example, we will index the output. The input has a batch shape of (3,) and the output has a batch shape of (2,). - - >>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32) - - The `i_out` array is used to map the result to the output indices. - - >>> with cue.assume(cue.SO3, cue.ir_mul): - ... x = cuex.RepArray("1", jnp.array([ - ... [0.0, 1.0, 0.0], - ... [0.0, 0.0, 1.0], - ... [1.0, 0.0, 0.0], - ... ])) - >>> cuex.equivariant_tensor_product( - ... e, - ... x, - ... indices=[None, i_out], - ... output_batch_shape=(2,), - ... ) - {1: 0+1+2} - [[ 1. ... ] - [ 2. ... ]] - """ - assert e.num_inputs > 0 - - if len(inputs) == 0: - return lambda *inputs: equivariant_tensor_product( - e, - *inputs, - indices=indices, - output_batch_shape=output_batch_shape, - output_dtype=output_dtype, - math_dtype=math_dtype, - name=name, - impl=impl, - ) - - if len(inputs) != e.num_inputs: - raise ValueError( - f"Unexpected number of inputs. Expected {e.num_inputs}, got {len(inputs)}." - ) - - for i, (x, rep) in enumerate(zip(inputs, e.inputs)): - if isinstance(x, cuex.RepArray): - assert x.rep(-1) == rep, ( - f"Input {i} should have representation {rep}, got {x.rep(-1)}." - ) - else: - assert x.ndim >= 1, ( - f"Input {i} should have at least one dimension, got {x.ndim}." - ) - assert x.shape[-1] == rep.dim, ( - f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." - ) - if not rep.is_scalar(): - raise ValueError( - f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}." - ) - - inputs: list[jax.Array] = [getattr(x, "array", x) for x in inputs] - - if indices is None: - indices = [None] * e.num_operands - - if len(indices) != e.num_operands: - raise ValueError( - f"Unexpected number of indices. indices should None or a list of length {e.num_operands}, got a list of length {len(indices)}." - ) - - if output_dtype is None: - output_dtype = jnp.result_type(*inputs) - - if output_batch_shape is None: - if indices[-1] is not None: - raise ValueError( - "When output indices are provided, output_batch_shape must be provided." 
- ) - output_batch_shape = jnp.broadcast_shapes( - *[ - x.shape[:-1] if i is None else i.shape + x.shape[1:-1] - for i, x in zip(indices, inputs) - ] - ) - - descriptors = [(cue.Operation(e.map_operands(d.num_operands)), d) for d in e.ds] - - [x] = cuex.tensor_product( - descriptors, - inputs, - [jax.ShapeDtypeStruct(output_batch_shape + (e.output.dim,), output_dtype)], - indices, - math_dtype=math_dtype, - name=name, - impl=impl, - ) - - return cuex.RepArray(e.output, x) diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_.py similarity index 99% rename from cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py rename to cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_.py index 9cee2d62..35f539f1 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_.py @@ -603,8 +603,6 @@ def decode_rep_array(static, data) -> RepArray: jax.tree_util.register_pytree_node(RepArray, encode_rep_array, decode_rep_array) -IrrepsArray = RepArray # TODO: do we deprecate IrrepsArray? - def from_segments( irreps: cue.Irreps | str, diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_utils.py similarity index 98% rename from cuequivariance_jax/cuequivariance_jax/rep_array/utils.py rename to cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_utils.py index f7ce2e84..4c279ab2 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_utils.py @@ -20,7 +20,7 @@ import cuequivariance as cue import cuequivariance_jax as cuex -from cuequivariance.irreps_array.misc_ui import assert_same_group +from cuequivariance.group_theory.irreps_array.misc_ui import assert_same_group def concatenate(arrays: list[cuex.RepArray]) -> cuex.RepArray: diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py similarity index 60% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index edbe8390..93ed9285 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -20,22 +20,22 @@ import jax.extend import jax.lax import jax.numpy as jnp -from jax.interpreters import ad, batching, mlir, xla +from jax.interpreters import ad, batching, mlir, partial_eval, xla import cuequivariance as cue -from cuequivariance_jax.primitives.primitives_utils import reshape -from cuequivariance_jax.primitives.tensor_product_ops_impl import ( - tensor_product_ops_impl, +from cuequivariance_jax.segmented_polynomials.segmented_polynomial_ops_impl import ( + segmented_polynomial_ops_impl, ) -from cuequivariance_jax.primitives.tensor_product_vanilla_impl import ( - tensor_product_vanilla_impl, +from cuequivariance_jax.segmented_polynomials.segmented_polynomial_vanilla_impl import ( + segmented_polynomial_vanilla_impl, ) +from cuequivariance_jax.segmented_polynomials.utils import reshape logger = logging.getLogger(__name__) -def tensor_product( - descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]], +def segmented_polynomial( + polynomial: cue.SegmentedPolynomial, inputs: 
list[jax.Array], outputs_shape_dtype: list[jax.ShapeDtypeStruct], indices: list[jax.Array | None] | None = None, @@ -44,40 +44,62 @@ def tensor_product( name: str | None = None, impl: str = "auto", ) -> list[jax.Array]: - r"""Compute a polynomial described by a list of descriptors. + """Compute a segmented polynomial. - Features: - - Calls a CUDA kernel if: - - STPs have a single mode which is a multiple of 32 (e.g. a channelwise tensor product that has subscripts ``u,u,,u`` with u=128) - - math data type is float32 or float64 - - in/out data type is a mix of float32, float64, float16 and bfloat16 - - indices are int32 - - Supports of infinite derivatives (JVP and tranpose rules maps to a single corresponding primitive) - - Limited support for batching (we cannot batch a buffer that has indices and if the batching is non trivial the performace will be bad) - - Automatic optimizations based on the symmetries of the STPs and on the repetition of the input buffers - - Automatic drop of unused buffers and indices + This function evaluates a segmented polynomial using either CUDA or JAX implementation. + The implementation choice is determined by the input characteristics and availability + of CUDA support. Args: - descriptors (list of pairs): The list of descriptors. - Each descriptor is formed by a pair of :class:`cue.Operation ` and :class:`cue.SegmentedTensorProduct `. - inputs (list of jax.Array): The input buffers. - outputs_shape_dtype (list of jax.ShapeDtypeStruct): The output shapes and dtypes. - indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs. - math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. - name (str, optional): The name of the operation. Defaults to None. - impl (str, optional): The implementation to use. Defaults to "auto". - If "auto", it will use the CUDA implementation if available, otherwise it will use the JAX implementation. - If "cuda", it will use the CUDA implementation. - If "jax", it will use the JAX implementation. + polynomial: The segmented polynomial to compute. + inputs: List of input buffers as JAX arrays. + outputs_shape_dtype: List of output shapes and dtypes specifications. + indices: Optional list of indices for inputs and outputs. If None, no indexing + is applied. Defaults to None. + math_dtype: Data type for computational operations. If None, automatically + determined from input types, defaulting to float32 if no float64 inputs + are present. Defaults to None. + name: Optional name for the operation. Defaults to None. + impl: Implementation to use, one of ["auto", "cuda", "jax"]. If "auto", + uses CUDA when available, falling back to JAX otherwise. Defaults to "auto". Returns: - list of jax.Array: The result of the tensor product. + List of JAX arrays containing the computed tensor product results. + + Features: + - CUDA kernel activation conditions: + - STPs have a single mode which is a multiple of 32 (e.g. 
channelwise + tensor product with subscripts ``u,u,,u`` where u=128) + - Math data type is float32 or float64 + - Input/output data types can be float32, float64, float16, or bfloat16 + - Indices must be int32 + - Supports infinite derivatives through JVP and transpose rules + - Limited batching support: + - Cannot batch buffers with indices + - Non-trivial batching may impact performance + - Automatic optimizations: + - Based on STP symmetries + - Based on input buffer repetition patterns + - Automatic pruning of unused buffers and indices + + Note: + The function automatically determines the best implementation based on the + input characteristics when impl="auto". For maximum performance with CUDA-capable + hardware, ensure inputs match the CUDA kernel activation conditions. """ if name is None: - name = "tensor_product" + name = "segmented_polynomial" + + assert len(inputs) == polynomial.num_inputs + assert len(outputs_shape_dtype) == polynomial.num_outputs + + outputs_shape_dtype = [ + jax.ShapeDtypeStruct(x.shape[:-1] + (ope.size,), x.dtype) + for x, ope in zip(outputs_shape_dtype, polynomial.outputs) + ] - buffers = inputs + outputs_shape_dtype + buffers = list(inputs) + list(outputs_shape_dtype) if indices is None: indices = [None] * len(buffers) @@ -137,15 +159,15 @@ def fn( outputs_shape_dtype=buffers[len(inputs) :], indices=unique_indices, buffer_index=buffer_index, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name, ) if impl == "naive_jax": - outputs = tensor_product_vanilla_impl(**kwargs) + outputs = segmented_polynomial_vanilla_impl(**kwargs) else: - outputs = tensor_product_prim(**kwargs, impl=impl) + outputs = segmented_polynomial_prim(**kwargs, impl=impl) def fn(x: jax.Array, shape: tuple[int, ...]) -> jax.Array: return jnp.reshape(x, shape) @@ -153,16 +175,40 @@ def fn(x: jax.Array, shape: tuple[int, ...]) -> jax.Array: return list(map(fn, outputs, [out.shape for out in outputs_shape_dtype])) -tensor_product_p = jax.extend.core.Primitive("tensor_product") -tensor_product_p.multiple_results = True +segmented_polynomial_p = jax.extend.core.Primitive("segmented_polynomial") +segmented_polynomial_p.multiple_results = True + + +def _dce_helper( + used_inputs: list[bool], + used_outputs: list[bool], + buffer_index: list[int], + num_indices: int, +) -> tuple[list[bool], list[int]]: + used_indices_id: list[int] = sorted( + { + buffer_index[i] + for i, used in enumerate(used_inputs + used_outputs) + if used and buffer_index[i] >= 0 + } + ) + used_indices: list[bool] = [i in used_indices_id for i in range(num_indices)] + + buffer_index = tuple( + used_indices_id.index(buffer_index[i]) if buffer_index[i] >= 0 else -1 + for i, used in enumerate(used_inputs + used_outputs) + if used + ) + + return used_indices, buffer_index -def tensor_product_prim( +def segmented_polynomial_prim( inputs: list[jax.Array], # input buffers outputs_shape_dtype: list[jax.ShapeDtypeStruct], # output shapes and dtypes indices: list[jax.Array], # index buffers buffer_index: list[int], # maps: buffer index -> unique indices index - descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str = "auto", @@ -176,60 +222,43 @@ def tensor_product_prim( assert len(inputs) + len(outputs_shape_dtype) == len(buffer_index) assert max(buffer_index) < len(indices) - descriptors = map( - lambda x: ( - x[0], - x[1].consolidate_modes().remove_empty_segments().consolidate_paths(), - ), - descriptors, 
+ # fuse STPs, consolidate modes, squeeze modes, remove empty segments, consolidate paths, sort paths + polynomial = polynomial.consolidate() + + used_inputs, used_outputs = polynomial.used_inputs(), polynomial.used_outputs() + + used_indices, buffer_index = _dce_helper( + used_inputs, used_outputs, buffer_index, len(indices) ) - descriptors = list(filter(lambda x: x[1].num_paths > 0, descriptors)) - - used_buffers = set() - used_indices = set() - for ope, _ in descriptors: - for i in ope.buffers: - used_buffers.add(i) - if buffer_index[i] >= 0: - used_indices.add(buffer_index[i]) - used_buffers = sorted(used_buffers) # maps: new buffer index -> old buffer index - used_indices = sorted(used_indices) # maps: new index -> old index - - new_num_inputs = sum([i < len(inputs) for i in used_buffers]) - - new_outputs = tensor_product_p.bind( - *[inputs[i] for i in used_buffers[:new_num_inputs]], - *[indices[i] for i in used_indices], - buffer_index=tuple( - used_indices.index(buffer_index[i]) if buffer_index[i] >= 0 else -1 - for i in used_buffers - ), + + new_outputs = segmented_polynomial_p.bind( + *[v for v, used in zip(inputs, used_inputs) if used], + *[v for v, used in zip(indices, used_indices) if used], + buffer_index=buffer_index, outputs_shape_dtype=tuple( - outputs_shape_dtype[i - len(inputs)] for i in used_buffers[new_num_inputs:] - ), - descriptors=frozenset( - [ - (cue.Operation([used_buffers.index(i) for i in ope.buffers]), stp) - for ope, stp in descriptors - ] + x for x, used in zip(outputs_shape_dtype, used_outputs) if used ), + polynomial=polynomial.select_buffers(used_inputs + used_outputs), math_dtype=jnp.dtype(math_dtype), name=str(name), impl=impl, ) if return_none_if_empty: - outputs = [None] * len(outputs_shape_dtype) + old_outputs = [None] * len(outputs_shape_dtype) else: - outputs = [jnp.zeros(out.shape, out.dtype) for out in outputs_shape_dtype] + old_outputs = [jnp.zeros(out.shape, out.dtype) for out in outputs_shape_dtype] - for i, output in zip(used_buffers[new_num_inputs:], new_outputs): - outputs[i - len(inputs)] = output + i_new = 0 + for i_old, used in enumerate(used_outputs): + if used: + old_outputs[i_old] = new_outputs[i_new] + i_new += 1 - return tuple(outputs) + return tuple(old_outputs) -def map_indices( +def _remap_indices_and_buffer_index( old_indices: list[jax.Array], old_buffer_index: list[int], mapping: list[int] ) -> tuple[list[jax.Array], list[int]]: new_indices = [] @@ -252,11 +281,11 @@ def map_indices( return new_indices, new_buffer_index -def tensor_product_abstract_eval( +def segmented_polynomial_abstract_eval( *inputs_and_indices: jax.core.ShapedArray, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -266,12 +295,12 @@ def tensor_product_abstract_eval( ) -def tensor_product_impl( +def segmented_polynomial_impl( platform: str | None, *inputs_and_indices: jax.Array, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -280,49 +309,46 @@ def tensor_product_impl( inputs, indices = inputs_and_indices[:num_inputs], inputs_and_indices[num_inputs:] del inputs_and_indices - def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): - for 
set_of_operands in ope.operands_with_identical_buffers(): - stp = stp.sort_indices_for_identical_operands(set_of_operands) - stp = stp.sort_paths() - return ope, stp + assert all(polynomial.used_buffers()) + polynomial = polynomial.unsymmetrize_for_identical_operands() - descriptors = list(map(optimize_paths, *zip(*descriptors))) - - outputs = None kwargs = dict( inputs=inputs, outputs_shape_dtype=outputs_shape_dtype, indices=indices, buffer_index=buffer_index, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name, ) assert impl in ("auto", "cuda", "jax") - if platform == "cuda" and impl in ("auto", "cuda"): - outputs, msg = tensor_product_ops_impl(**kwargs) + outputs = None + if platform == "cuda": + if impl in ("auto", "cuda"): + outputs = segmented_polynomial_ops_impl(**kwargs) + if impl == "cuda" and not outputs.is_ok(): + raise RuntimeError(f"Failed to use CUDA implementation: {outputs.msg}") + outputs = outputs.unwrap_or(None) else: - msg = f"{platform=}, {impl=}" - - if impl == "cuda" and outputs is None: - raise RuntimeError(f"Failed to use CUDA implementation: {msg}") + if impl == "cuda": + raise RuntimeError(f"{impl=} but platform is {platform}") if outputs is None: - outputs = tensor_product_vanilla_impl(**kwargs) + outputs = segmented_polynomial_vanilla_impl(**kwargs) assert outputs is not None return outputs -def tensor_product_jvp( +def segmented_polynomial_jvp( primals_and_indices: tuple[jax.Array, ...], tangents_and_zeros: tuple[jax.Array | ad.Zero, ...], *, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -337,18 +363,18 @@ def tensor_product_jvp( assert all(isinstance(t, ad.Zero) for t in tangents_and_zeros[num_inputs:]) del primals_and_indices, tangents_and_zeros - out_primals = tensor_product_prim( + out_primals = segmented_polynomial_prim( primals, outputs_shape_dtype, indices, buffer_index, - descriptors, + polynomial, math_dtype, name, impl=impl, ) - jvp_indices, jvp_buffer_index = map_indices( + jvp_indices, jvp_buffer_index = _remap_indices_and_buffer_index( indices, buffer_index, [i for i, x in enumerate(primals)] @@ -356,21 +382,12 @@ def tensor_product_jvp( + [num_inputs + i for i, x in enumerate(outputs_shape_dtype)], ) - jvp_descriptors = [] - for ope, stp in descriptors: - jvps = ope.jvp([not isinstance(t, ad.Zero) for t in tangents]) - permutations: list[tuple[int, ...]] = stp.symmetries() - for multiplicator, ope in cue.Operation.group_by_operational_symmetries( - permutations, jvps - ): - jvp_descriptors.append((ope, multiplicator * stp)) - - out_tangents = tensor_product_prim( + out_tangents = segmented_polynomial_prim( list(primals) + [t for t in tangents if not isinstance(t, ad.Zero)], outputs_shape_dtype, jvp_indices, jvp_buffer_index, - jvp_descriptors, + polynomial.jvp([not isinstance(t, ad.Zero) for t in tangents]), math_dtype, name + "_jvp" @@ -381,12 +398,12 @@ def tensor_product_jvp( return out_primals, out_tangents -def tensor_product_transpose( +def segmented_polynomial_transpose( cotangents: tuple[jax.Array | ad.Zero, ...], *inputs_and_indices: jax.Array | ad.UndefinedPrimal, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, 
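A note between these autodiff hunks: the JVP rule above and the transpose rule below both lower back onto the same segmented_polynomial primitive, which is what gives the "infinite derivatives" property the docstring claims. A minimal sketch of what this enables, reusing the spherical-harmonics descriptor and shapes from test_multiple_operand_shape_bug further down in this diff (the grad call itself is illustrative):

import jax
import jax.numpy as jnp
import cuequivariance as cue
import cuequivariance_jax as cuex

e = cue.descriptors.spherical_harmonics(cue.SO3(1), [2])

def f(x):
    # one SO3 vector (dim 3) in, one degree-2 harmonic (dim 5) out
    [out] = cuex.segmented_polynomial(
        e.polynomial, [x], [jax.ShapeDtypeStruct((5,), jnp.float32)]
    )
    return jnp.sum(out)

g = jax.grad(f)(jnp.array([0.0, 1.0, 0.0]))  # transpose rule reuses the same primitive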
@@ -399,7 +416,7 @@ def tensor_product_transpose( # The cotangents replace the outputs as inputs # The undefined primal inputs become outputs - tr_indices, tr_buffer_index = map_indices( + tr_indices, tr_buffer_index = _remap_indices_and_buffer_index( indices, buffer_index, [i for i, x in enumerate(inputs) if not ad.is_undefined_primal(x)] @@ -411,16 +428,7 @@ def tensor_product_transpose( + [i for i, x in enumerate(inputs) if ad.is_undefined_primal(x)], ) - tr_descriptors = [] - for ope, stp in descriptors: - ope = ope.transpose( - [ad.is_undefined_primal(x) for x in inputs], - [not isinstance(x, ad.Zero) for x in cotangents], - ) - if ope is not None: - tr_descriptors.append((ope, stp)) - - tmp = tensor_product_prim( + tmp = segmented_polynomial_prim( [x for x in inputs if not ad.is_undefined_primal(x)] + [x for x in cotangents if not isinstance(x, ad.Zero)], # inputs [ @@ -430,7 +438,10 @@ def tensor_product_transpose( ], tr_indices, tr_buffer_index, - tr_descriptors, + polynomial.transpose( + [ad.is_undefined_primal(x) for x in inputs], + [not isinstance(x, ad.Zero) for x in cotangents], + ), math_dtype, name + "_transpose", impl=impl, @@ -446,13 +457,13 @@ def tensor_product_transpose( return tuple(outputs) -def tensor_product_batching( +def segmented_polynomial_batching( batched_inputs_and_indices: tuple[jax.Array, ...], batch_axes_of_inputs_and_indices: tuple[int | None, ...], *, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -531,12 +542,12 @@ def flatten_index(x: jax.Array) -> jax.Array: for out in outputs_shape_dtype ) - outputs = tensor_product_p.bind( + outputs = segmented_polynomial_p.bind( *batched_inputs, *batched_indices, buffer_index=buffer_index, outputs_shape_dtype=new_outputs_shape_dtype, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name + "_batching", impl=impl, @@ -551,22 +562,68 @@ def flatten_index(x: jax.Array) -> jax.Array: return outputs, (0,) * len(outputs) -tensor_product_p.def_abstract_eval(tensor_product_abstract_eval) -tensor_product_p.def_impl(partial(xla.apply_primitive, tensor_product_p)) +def segmented_polynomial_dce( + used_outputs: list[bool], + eqn: jax.extend.core.JaxprEqn, +) -> tuple[list[bool], jax.extend.core.JaxprEqn | None]: + assert len(used_outputs) == len(eqn.outvars) + + polynomial: cue.SegmentedPolynomial = eqn.params["polynomial"] + buffer_index = eqn.params["buffer_index"] + outputs_shape_dtype = eqn.params["outputs_shape_dtype"] + + # If no outputs are used, we can eliminate the operation entirely + if not any(used_outputs) and not eqn.effects: + return [False] * len(eqn.invars), None + + num_inputs = polynomial.num_inputs + + polynomial = polynomial.compute_only(used_outputs) + used_inputs: list[bool] = polynomial.used_inputs() + + used_indices, buffer_index = _dce_helper( + used_inputs, used_outputs, buffer_index, len(eqn.invars) - num_inputs + ) + + new_eqn = jax.extend.core.JaxprEqn( + [v for v, used in zip(eqn.invars, used_inputs + used_indices) if used], + [v for v, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, + dict( + eqn.params, + polynomial=polynomial.select_buffers(used_inputs + used_outputs), + buffer_index=buffer_index, + outputs_shape_dtype=tuple( + x for x, used in zip(outputs_shape_dtype, used_outputs) if used + ), + ), + eqn.effects, + eqn.source_info, + eqn.ctx, + ) 
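+
+    # JAX DCE rule contract: return a usage mask over eqn.invars (input buffers,
+    # then index buffers) together with the pruned equation, or None when the
+    # primitive has been eliminated entirely.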
+ + return used_inputs + used_indices, new_eqn + + +segmented_polynomial_p.def_abstract_eval(segmented_polynomial_abstract_eval) +segmented_polynomial_p.def_impl(partial(xla.apply_primitive, segmented_polynomial_p)) mlir.register_lowering( - tensor_product_p, + segmented_polynomial_p, mlir.lower_fun( - partial(tensor_product_impl, "cuda"), tensor_product_p.multiple_results + partial(segmented_polynomial_impl, "cuda"), + segmented_polynomial_p.multiple_results, ), "cuda", ) mlir.register_lowering( - tensor_product_p, + segmented_polynomial_p, mlir.lower_fun( - partial(tensor_product_impl, None), tensor_product_p.multiple_results + partial(segmented_polynomial_impl, None), + segmented_polynomial_p.multiple_results, ), None, ) -ad.primitive_jvps[tensor_product_p] = tensor_product_jvp -ad.primitive_transposes[tensor_product_p] = tensor_product_transpose -batching.primitive_batchers[tensor_product_p] = tensor_product_batching +ad.primitive_jvps[segmented_polynomial_p] = segmented_polynomial_jvp +ad.primitive_transposes[segmented_polynomial_p] = segmented_polynomial_transpose +batching.primitive_batchers[segmented_polynomial_p] = segmented_polynomial_batching +partial_eval.dce_rules[segmented_polynomial_p] = segmented_polynomial_dce diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py similarity index 57% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py index 1ea6d94a..1534c1fc 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py @@ -14,16 +14,48 @@ # limitations under the License. 
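+# Note: kernel-eligibility failures below are reported through the small Ok/Err
+# result type defined next, replacing the previous (outputs | None, message)
+# tuple convention.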
import logging import re +from dataclasses import dataclass import jax import jax.numpy as jnp import cuequivariance as cue -from cuequivariance_jax.primitives.primitives_utils import reshape +from cuequivariance_jax.segmented_polynomials.utils import reshape logger = logging.getLogger(__name__) +@dataclass +class Ok: + outputs: list[jax.Array] + + def unwrap(self) -> list[jax.Array]: + return self.outputs + + def unwrap_or(self, default): + return self.outputs + + def is_ok(self) -> bool: + return True + + +@dataclass +class Err: + msg: str + + def unwrap(self): + raise ValueError(self.msg) + + def unwrap_or(self, default): + return default + + def is_ok(self) -> bool: + return False + + +Result = Ok | Err + + def sanitize_string(s): s = re.sub(r"[^A-Za-z0-9_]", "", s) if s == "" or s[0].isdigit(): @@ -31,31 +63,32 @@ def sanitize_string(s): return s -def tensor_product_ops_impl( +def segmented_polynomial_ops_impl( inputs: list[jax.Array], # shape (batch_size, operand_size) outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], indices: list[jax.Array], buffer_index: list[int], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, -) -> tuple[list[jax.Array] | None, str]: - def log(msg: str): +) -> Result: + def make_error(msg: str) -> Err: logger.info(f"[{name}] {msg}") - return None, name + return Err(msg) - num_inputs = len(buffer_index) - len(outputs_shape_dtype) + assert polynomial.num_inputs == len(buffer_index) - len(outputs_shape_dtype) + assert polynomial.num_outputs == len(outputs_shape_dtype) buffers = list(inputs) + list(outputs_shape_dtype) for b in buffers: assert b.ndim == 2, f"Buffer {b.shape} must be 2D" # Reshape buffers to 3D by using the STP informations - for ope, stp in descriptors: + for ope, stp in polynomial.operations: if len(stp.subscripts.modes()) != 1: - return log(f"Unsupported STP: {stp}") + return make_error(f"Unsupported STP: {stp}") if not stp.all_same_segment_shape(): - return log(f"Unsupported STP: {stp}") + return make_error(f"Unsupported STP: {stp}") for i, operand in zip(ope.buffers, stp.operands): b = buffers[i] @@ -63,28 +96,34 @@ def log(msg: str): if b.ndim == 2: b = buffers[i] = reshape(b, shape) if b.shape != shape: - return log(f"Shape mismatch: {b.shape} != {shape} for {i} {stp} {ope}") + return make_error( + f"Shape mismatch: {b.shape} != {shape} for {i} {stp} {ope}" + ) for b in buffers: if b.dtype.type not in {jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16}: - return log(f"Unsupported buffer type: {b.dtype}") + return make_error(f"Unsupported buffer type: {b.dtype}") for i in indices: + # TODO: this restriction will be removed by MR109 if i.dtype.type != jnp.int32: - return log(f"Unsupported index type: {i.dtype}") + return make_error(f"Unsupported index type: {i.dtype}") if not all(b.ndim == 3 for b in buffers): - return log("All buffers must be used") + return make_error("All buffers must be used") if len({b.shape[2] for b in buffers}.union({1})) != 2: - return log(f"Buffer shapes not compatible {[b.shape for b in buffers]}") + return make_error(f"Buffer shapes not compatible {[b.shape for b in buffers]}") + # TODO: this restriction will be removed by MR109 if max(b.shape[2] for b in buffers) % 32 != 0: - return log(f"Extend must be a multiple of 32, got {[b.shape for b in buffers]}") + return make_error( + f"Extent must be a multiple of 32, got {[b.shape for b in buffers]}" + ) math_dtype = jnp.dtype(math_dtype) if math_dtype.type not in
{jnp.float32, jnp.float64}: - return log(f"Unsupported math_dtype: {math_dtype}") + return make_error(f"Unsupported math_dtype: {math_dtype}") batch_size = 1 for i, b in zip(buffer_index, buffers): @@ -93,11 +132,13 @@ def log(msg: str): elif b.shape[0] != 1: batch_size = b.shape[0] - # TODO: remove if the backend supports atomic operations for float16/bfloat16 - for i, b in zip(buffer_index[num_inputs:], buffers[num_inputs:]): + # TODO: this restriction will be removed by MR109 + for i, b in zip( + buffer_index[polynomial.num_inputs :], buffers[polynomial.num_inputs :] + ): if b.dtype.type not in {jnp.float32, jnp.float64}: if i >= 0 or b.shape[0] != batch_size: - return log( + return make_error( f"Output buffer {b.shape} of type {b.dtype} and buffer index {i} is not supported" ) @@ -107,25 +148,25 @@ def log(msg: str): Path, tensor_product_uniform_1d_jit, ) - except ImportError: - return log("cuequivariance_ops_jax is not installed") + except ImportError as e: + return make_error(f"cuequivariance_ops_jax is not installed: {e}") operations = [] paths = [] - for ope, stp in descriptors: + for ope, stp in polynomial.operations: operations.append(Operation(ope.buffers, len(paths), stp.num_paths)) for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - log("Using the uniform 1d kernel of cuequivariance_ops_jax 🚀") + logger.info(f"Using the uniform 1d kernel for '{name}' 🚀\n" + str(polynomial)) outputs = tensor_product_uniform_1d_jit( - buffers[:num_inputs], - buffers[num_inputs:], - indices, + buffers[: polynomial.num_inputs], + buffers[polynomial.num_inputs :], + list(indices), buffer_index, operations=operations, paths=paths, math_dtype=math_dtype, name=sanitize_string(name), ) - return [jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2])) for x in outputs], "" + return Ok([jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2])) for x in outputs]) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py similarity index 91% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py index 7abbf2c8..26eea25c 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py @@ -26,12 +26,12 @@ logger = logging.getLogger(__name__) -def tensor_product_vanilla_impl( +def segmented_polynomial_vanilla_impl( inputs: list[jax.Array], # shape (batch_size, operand_size) outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], indices: list[jax.Array], buffer_index: list[int], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, ) -> tuple[jax.Array, ...]: # output buffers @@ -52,7 +52,7 @@ def gather(i: int, x: jax.Array) -> jax.Array: return buffer.at[idx].add(x) return buffer + x - for operation, d in descriptors: + for operation, d in polynomial.operations: ope_out, b_out = operation.output_operand_buffer(num_inputs) out = outputs_shape_dtype[b_out - num_inputs] @@ -86,7 +86,7 @@ def flatten(x: jax.Array, axis: int) -> jax.Array: def sum_cat_list_list( - operand: cue.segmented_tensor_product.Operand, + operand: cue.SegmentedOperand, list_list: list[list[jax.Array]] | jax.Array, 
batch_shape: tuple[int, ...], dtype: jnp.dtype, @@ -164,6 +164,7 @@ def tp_list_list( d = d.sort_paths(-1) pids = d.compressed_path_segment(-1) ope_out = d.operands[-1] + ss_out = d.subscripts.operands[-1] def ein( coefficients: jax.Array, segments: list[jax.Array], mode: str = "normal" @@ -181,14 +182,14 @@ def ein( term_out = ( "".join(m for m, s in zip(batch_modes, out_batch_shape) if s > 1) + path_out - + ope_out.subscripts + + ss_out ) terms = [path_in + d.coefficient_subscripts] + terms_in + [term_out] formula = ",".join(terms[:-1]) + "->" + terms[-1] segments = [x.astype(coefficients.dtype) for x in segments] segment = jnp.einsum(formula, coefficients, *segments, precision=precision) - segment_shape = segment.shape[segment.ndim - len(ope_out.subscripts) :] + segment_shape = segment.shape[segment.ndim - len(ss_out) :] if mode == "vectorized": num_paths = coefficients.shape[0] @@ -201,10 +202,7 @@ def ein( def prepare(): if not d.all_same_segment_shape(): - raise ValueError( - "cuex.tensor_product: all operands must have the same segment shape\n" - + str(d) - ) + raise ValueError("all operands must have the same segment shape\n" + str(d)) reshaped_inputs = [ jnp.reshape( input, input.shape[:-1] + (ope.num_segments,) + ope.segment_shape @@ -216,7 +214,7 @@ def prepare(): return reshaped_inputs, indices, coefficients if algorithm == "stacked": - logger.debug(f"cuex.tensor_product: {d} with stacked strategy") + logger.debug(f"{d} with stacked strategy") reshaped_inputs, indices, coefficients = prepare() return [ @@ -234,7 +232,7 @@ def prepare(): ] elif algorithm == "compact_stacked": - logger.debug(f"cuex.tensor_product: {d} with compact_stacked strategy") + logger.debug(f"{d} with compact_stacked strategy") reshaped_inputs, indices, coefficients = prepare() return [ @@ -256,7 +254,7 @@ def prepare(): ] elif algorithm == "indexed_vmap": - logger.debug(f"cuex.tensor_product: {d} with indexed_vmap strategy") + logger.debug(f"{d} with indexed_vmap strategy") reshaped_inputs, indices, coefficients = prepare() return ( @@ -279,7 +277,7 @@ def prepare(): ) elif algorithm == "indexed_compact": - logger.debug(f"cuex.tensor_product: {d} with indexed_compact strategy") + logger.debug(f"{d} with indexed_compact strategy") reshaped_inputs, indices, coefficients = prepare() return ( @@ -303,7 +301,7 @@ def prepare(): ) elif algorithm == "indexed_for_loop": - logger.debug(f"cuex.tensor_product: {d} with indexed_for_loop strategy") + logger.debug(f"{d} with indexed_for_loop strategy") reshaped_inputs, indices, coefficients = prepare() def body(pid: int, output: jax.Array) -> jax.Array: @@ -330,7 +328,7 @@ def body(pid: int, output: jax.Array) -> jax.Array: ) elif algorithm == "sliced": - logger.debug(f"cuex.tensor_product: {d} with sliced strategy") + logger.debug(f"{d} with sliced strategy") slices = [operand.segment_slices() for operand in d.operands] return [ @@ -356,7 +354,7 @@ def body(pid: int, output: jax.Array) -> jax.Array: ] elif algorithm == "no-op": - warnings.warn(f"cuex.tensor_product: {d} skipping computation!!!") + warnings.warn(f"{d} skipping computation!!!") dummy = sum([jnp.sum(input) for input in inputs]) @@ -369,4 +367,4 @@ def body(pid: int, output: jax.Array) -> jax.Array: for pid_start, pid_end in zip(pids[:-1], pids[1:]) ] - raise NotImplementedError(f"cuex.tensor_product: unknown algorithm {algorithm}") + raise NotImplementedError(f"unknown algorithm {algorithm}") diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/primitives_utils.py 
b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/primitives/primitives_utils.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py similarity index 62% rename from cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py rename to cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py index f02f6623..3da4c16f 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py @@ -28,14 +28,20 @@ def spherical_harmonics( ) -> cuex.RepArray: """Compute the spherical harmonics of a vector. + The spherical harmonics are polynomials of an input vector. This function computes the polynomials of the specified degrees. + Args: - ls (list of int): List of spherical harmonic degrees. - vector (RepArray): Input vector(s). - normalize (bool): Whether to normalize the vector before computing the spherical harmonics. - algorithm (str): Algorithm to use for the tensor product. See :class:`cuex.tensor_product ` for more information. + ls (list[int]): List of spherical harmonic degrees. Each degree must be non-negative. + vector (RepArray): Input vector. Must be a single vector (multiplicity 1) with 3 components. + normalize (bool, optional): Whether to normalize the vector before computing the spherical harmonics. Defaults to True. Returns: - RepArray: Spherical harmonics of the vector. + RepArray: The spherical harmonics of the vector, containing the polynomials of each specified degree. + + Example: + >>> import cuequivariance_jax as cuex + >>> vector = cuex.randn(jax.random.key(0), cue.IrrepsAndLayout(cue.Irreps(cue.O3, "1o"), cue.mul_ir)) + >>> harmonics = spherical_harmonics([0, 1, 2], vector) """ ls = list(ls) assert vector.is_irreps_array() @@ -49,9 +55,9 @@ def spherical_harmonics( if normalize: vector = _normalize(vector) - return cuex.equivariant_tensor_product( + return cuex.equivariant_polynomial( descriptors.spherical_harmonics(ir, ls, vector.layout), - vector, + [vector], name="spherical_harmonics", ) @@ -85,7 +91,25 @@ def f(x: jax.Array) -> jax.Array: def norm(array: cuex.RepArray, *, squared: bool = False) -> cuex.RepArray: - """Norm of `RepArray`.""" + """Compute the norm of a `RepArray`. + + This function calculates the norm for each element in the irreps array by summing + the squared magnitudes of the elements along the irrep dimension. By default, + the function returns the square root of this sum (the regular norm), but it can + also return the squared norm if requested. 
+ + When the squared norm is zero, the function handles this special case: + - If squared=True, it returns 0.0 + - If squared=False, it safely computes the square root and returns 0.0 + + Args: + array: The equivariant array (RepArray) whose norm should be calculated + squared: If True, returns the squared norm; if False (default), returns the regular norm + + Returns: + A new RepArray with trivial irreps where each element represents the norm + (or squared norm) of the corresponding element in the input array + """ assert array.is_irreps_array() match array.layout: diff --git a/cuequivariance_jax/tests/activation_test.py b/cuequivariance_jax/tests/activation_test.py new file mode 100644 index 00000000..c3ad9d8a --- /dev/null +++ b/cuequivariance_jax/tests/activation_test.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax +import jax.numpy as jnp +import pytest + +import cuequivariance as cue +import cuequivariance_jax as cuex +from cuequivariance_jax.activation import ( + function_parity, + normalize_function, + normalspace, + scalar_activation, +) + + +def test_normalspace(): + """Test the normalspace function.""" + n = 100 + points = normalspace(n) + + assert points.shape == (n,) + assert jnp.all(jnp.diff(points) > 0) # Points in ascending order + assert jnp.isclose(jnp.mean(points), 0.0, atol=1e-6) + + +def test_normalize_function(): + """Test the normalize_function.""" + # Test with constant and non-constant functions + norm_const = normalize_function(lambda x: jnp.ones_like(x)) + norm_linear = normalize_function(lambda x: x) + + test_points = normalspace(1001) + + # Check normalization + assert jnp.isclose(jnp.mean(norm_const(test_points) ** 2), 1.0, atol=1e-2) + assert jnp.isclose(jnp.mean(norm_linear(test_points) ** 2), 1.0, atol=5e-2) + + # Test zero function (should raise ValueError) + with pytest.raises(ValueError): + normalize_function(lambda x: jnp.zeros_like(x)) + + +def test_function_parity(): + """Test the function_parity function.""" + # Test even, odd, and neither functions + assert function_parity(jnp.cos) == 1 # Even + assert function_parity(jnp.sin) == -1 # Odd + assert function_parity(jnp.exp) == 0 # Neither + + +@cue.assume("SO3", cue.ir_mul) +def test_scalar_activation(): + """Test scalar_activation function.""" + # Create test data + irreps = cue.Irreps("SO3", "2x0 + 0") + x = cuex.randn(jax.random.key(42), irreps, (5,)) + + # Test with a single activation + y = scalar_activation(x, lambda x: 2 * x) + assert y.irreps == x.irreps + assert y.shape == x.shape + + # Test with multiple activations + y = scalar_activation(x, [jnp.sin, jnp.cos]) + assert y.irreps == x.irreps + + # Test with a dict of activations + y = scalar_activation(x, {cue.SO3(0): jnp.sin}) + assert y.shape == x.shape + + # Test with non-scalar irreps + irreps_with_vectors = cue.Irreps("SO3", "0 + 1") + x_with_vectors = 
cuex.randn(jax.random.key(43), irreps_with_vectors, (5,)) + + # Should assert when trying to apply activation to non-scalar + with pytest.raises(AssertionError): + scalar_activation(x_with_vectors, jnp.sin) + + # Should work with None for non-scalar components + y = scalar_activation(x_with_vectors, [jnp.sin, None]) + assert y.irreps == x_with_vectors.irreps diff --git a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_jax/tests/equivariant_polynomial_test.py similarity index 91% rename from cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py rename to cuequivariance_jax/tests/equivariant_polynomial_test.py index 93b6a171..1d95ddd9 100644 --- a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_jax/tests/equivariant_polynomial_test.py @@ -26,7 +26,9 @@ def test_special_double_backward(): 32 * cue.Irreps("O3", "0e + 1o + 2e"), 32 * cue.Irreps("O3", "0e + 1o"), [1, 2] ) rep_w, rep_x = e.inputs - h = cuex.equivariant_tensor_product(e) + + def h(*inputs): + return cuex.equivariant_polynomial(e, inputs) h0 = lambda w, x: h(w, x).array.sum() ** 2 # noqa h1 = lambda w, x: jax.grad(h0, 1)(w, x).array.sum() ** 2 # noqa @@ -56,7 +58,7 @@ def make_uniform1d_descriptors(): @pytest.mark.parametrize("e", make_uniform1d_descriptors()) -def test_custom_kernel(e: cue.EquivariantTensorProduct): +def test_custom_kernel(e: cue.EquivariantPolynomial): if jax.default_backend() != "gpu": pytest.skip("test_custom_kernel requires CUDA") @@ -79,13 +81,14 @@ def test_custom_kernel(e: cue.EquivariantTensorProduct): output_batch_shape = (num_nodes,) def fwd(inputs, indices, impl): - return cuex.equivariant_tensor_product( + return cuex.equivariant_polynomial( e, - indices=indices, - output_batch_shape=output_batch_shape, + inputs, + jax.ShapeDtypeStruct(output_batch_shape + (e.outputs[0].dim,), jnp.float64), + indices, math_dtype=jnp.float64, impl=impl, - )(*inputs).array + ).array out0 = fwd(inputs, indices, impl="jax") out1 = fwd(inputs, indices, impl="cuda") diff --git a/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py b/cuequivariance_jax/tests/jax_rep_array_test.py similarity index 100% rename from cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py rename to cuequivariance_jax/tests/jax_rep_array_test.py diff --git a/cuequivariance_jax/tests/primitives/tensor_product_test.py b/cuequivariance_jax/tests/segmented_polynomial_test.py similarity index 74% rename from cuequivariance_jax/tests/primitives/tensor_product_test.py rename to cuequivariance_jax/tests/segmented_polynomial_test.py index 0f48dafd..a2b05aad 100644 --- a/cuequivariance_jax/tests/primitives/tensor_product_test.py +++ b/cuequivariance_jax/tests/segmented_polynomial_test.py @@ -24,14 +24,22 @@ def test_one_operand(): d = cue.SegmentedTensorProduct.empty_segments([1]) - [out] = cuex.tensor_product( - [(cue.Operation([0]), d)], [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] + [out] = cuex.segmented_polynomial( + cue.SegmentedPolynomial( + [], [cue.SegmentedOperand.empty_segments(1)], [(cue.Operation([0]), d)] + ), + [], + [jax.ShapeDtypeStruct((2, 1), jnp.float32)], ) np.testing.assert_array_equal(out, np.array([[0.0], [0.0]])) d.add_path(0, c=123) - [out] = cuex.tensor_product( - [(cue.Operation([0]), d)], [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] + [out] = cuex.segmented_polynomial( + cue.SegmentedPolynomial( + [], [cue.SegmentedOperand.empty_segments(1)], [(cue.Operation([0]), d)] + ), + [], + [jax.ShapeDtypeStruct((2, 
1), jnp.float32)], ) np.testing.assert_array_equal(out, np.array([[123.0], [123.0]])) @@ -44,10 +52,8 @@ def test_UnshapedArray_bug(): x = jnp.ones((2, 1)) def f(w, x): - [out] = cuex.tensor_product( - [(cue.Operation([0, 2]), e.ds[0]), (cue.Operation([0, 1, 2]), e.ds[1])], - [w, x], - [jax.ShapeDtypeStruct((2, 1), jnp.float32)], + [out] = cuex.segmented_polynomial( + e.polynomial, [w, x], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] ) return jnp.sum(out) @@ -59,9 +65,9 @@ def test_multiple_operand_shape_bug(): # Before, it was not possible to have an input # with a different shape than the output of the same operand. def h(x): - d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d - [out] = cuex.tensor_product( - [(cue.Operation([0, 0, 1]), d)], + e = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]) + [out] = cuex.segmented_polynomial( + e.polynomial, [x], [jax.ShapeDtypeStruct((5,), jnp.float32)], ) @@ -89,13 +95,18 @@ def test_vmap(): e = cue.descriptors.full_tensor_product( cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1") ) + d = e.polynomial.operations[0][1] def f(x1, x2, i1): - return cuex.tensor_product( - [ - (cue.Operation([0, 1, 2]), e.ds[0]), - (cue.Operation([0, 1, 3]), e.ds[0]), - ], + return cuex.segmented_polynomial( + cue.SegmentedPolynomial( + d.operands[:2], + [d.operands[2], d.operands[2]], + [ + (cue.Operation([0, 1, 2]), d), + (cue.Operation([0, 1, 3]), d), + ], + ), [x1, x2], [ jax.ShapeDtypeStruct((2, 3), jnp.float32), diff --git a/cuequivariance_jax/tests/operations/spherical_harmonics_test.py b/cuequivariance_jax/tests/spherical_harmonics_test.py similarity index 53% rename from cuequivariance_jax/tests/operations/spherical_harmonics_test.py rename to cuequivariance_jax/tests/spherical_harmonics_test.py index 68305557..416fe383 100644 --- a/cuequivariance_jax/tests/operations/spherical_harmonics_test.py +++ b/cuequivariance_jax/tests/spherical_harmonics_test.py @@ -36,3 +36,35 @@ def test_spherical_harmonics(shape): y = cuex.spherical_harmonics([0, 1, 2], x) assert y.shape == shape + (9,) assert y.irreps == cue.Irreps(cue.O3, "0e + 1o + 2e") + + +@pytest.mark.parametrize("squared", [True, False]) +def test_norm(squared): + """Test the norm function with a batch of vectors.""" + # Create a RepArray with known values + irreps = cue.Irreps(cue.O3, "1o") + shape = (10,) + + # Create batch data with specific values in first two positions + data = np.zeros(shape + (3,)) + data[0] = [3.0, 4.0, 0.0] # norm = 5.0 + data[1] = [0.0, 0.0, 0.0] # norm = 0.0 + data[2:] = np.random.randn(shape[0] - 2, 3) # random values + + array = cuex.RepArray(irreps, data, cue.ir_mul) + + # Calculate norm + norm_result = cuex.norm(array, squared=squared) + + # Verify basic properties + assert norm_result.irreps.dim == 1 + for _, ir in norm_result.irreps: + assert ir.is_trivial() + assert norm_result.shape == shape + (1,) + + # Verify specific test values + expected_first = 25.0 if squared else 5.0 # norm of [3,4,0] + expected_second = 0.0 # norm of [0,0,0] + + np.testing.assert_allclose(norm_result.array[0, 0], expected_first, rtol=1e-6) + np.testing.assert_allclose(norm_result.array[1, 0], expected_second, rtol=1e-6) diff --git a/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py b/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py index d71fe644..98f0c766 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py @@ -16,7 +16,10 @@ import torch import 
cuequivariance as cue -from cuequivariance.irreps_array.misc_ui import default_irreps, default_layout +from cuequivariance.group_theory.irreps_array.misc_ui import ( + default_irreps, + default_layout, +) # This implementation is an adaptation of https://github.com/e3nn/e3nn/blob/ef93f876c9985b3816aefb2982b3cf4325df6ba4/e3nn/nn/_batchnorm.py diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index a19ba1a4..55e030ea 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -19,7 +19,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.irreps_array.misc_ui import ( +from cuequivariance.group_theory.irreps_array.misc_ui import ( assert_same_group, default_irreps, default_layout, diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 0191c8b5..fc124023 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class Linear(torch.nn.Module): @@ -59,7 +62,7 @@ def __init__( math_dtype = math_dtype or dtype e = descriptors.linear(irreps_in, irreps_out) - assert e.d.subscripts == "uv,iu,iv" + assert e.polynomial.operations[0][1].subscripts == "uv,iu,iv" self.irreps_in = irreps_in self.irreps_out = irreps_out diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index 69a88a07..f457e7b2 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -19,7 +19,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import default_irreps class Rotation(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 2366c9b7..d2e6b4f8 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -18,10 +18,13 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.experimental.mace.symmetric_contractions import ( +from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( symmetric_contraction, ) -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class SymmetricContraction(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index 9bb213f5..eb541bdc 100644 
--- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class ChannelWiseTensorProduct(torch.nn.Module): @@ -68,7 +71,10 @@ def __init__( e = descriptors.channelwise_tensor_product( irreps_in1, irreps_in2, filter_irreps_out ) - descriptor, irreps_out = e.d, e.operands[-1].irreps + descriptor, irreps_out = ( + e.polynomial.operations[0][1], + e.operands[-1].irreps, + ) assert descriptor.subscripts == "uv,iu,jv,kuv+ijk" self.irreps_in1 = irreps_in1 diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 1a003f7c..7a4a724b 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class FullyConnectedTensorProduct(torch.nn.Module): @@ -70,13 +73,13 @@ def __init__( e = descriptors.fully_connected_tensor_product( irreps_in1, irreps_in2, irreps_out ) - assert e.d.subscripts == "uvw,iu,jv,kw+ijk" + assert e.polynomial.operations[0][1].subscripts == "uvw,iu,jv,kw+ijk" self.irreps_in1 = irreps_in1 self.irreps_in2 = irreps_in2 self.irreps_out = irreps_out - self.weight_numel = e.d.operands[0].size + self.weight_numel = e.polynomial.operations[0][1].operands[0].size self.shared_weights = shared_weights self.internal_weights = ( diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cfa76d96..b099c299 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -18,7 +18,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.irreps_array.misc_ui import default_layout +from cuequivariance.group_theory.irreps_array.misc_ui import default_layout class Dispatcher(torch.nn.Module): @@ -156,6 +156,23 @@ def __init__( use_fallback: Optional[bool] = None, ): super().__init__() + + # TODO: remove this when re-design + if isinstance(e, cue.EquivariantPolynomial): + assert e.num_outputs == 1 + for ope, stp in e.polynomial.operations: + inputs = list(range(e.num_inputs)) + output = e.num_inputs + expected = tuple( + inputs[: stp.num_operands - 1] + + [inputs[-1]] * max(0, stp.num_operands - e.num_operands) + + [output] + ) + assert ope.buffers == expected, f"{ope.buffers} != {expected}" + e = cue.EquivariantTensorProduct( + [stp for _, stp in e.polynomial.operations], e.inputs + e.outputs + ) + if not isinstance(layout_in, tuple): layout_in = (layout_in,) * e.num_inputs if len(layout_in) != e.num_inputs: diff --git 
a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 9c7a5739..5da21f2c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -19,7 +19,7 @@ import torch import torch.fx -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue import cuequivariance_torch as cuet logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class SymmetricTensorProduct(torch.nn.Module): def __init__( self, - descriptors: list[stp.SegmentedTensorProduct], + descriptors: list[cue.SegmentedTensorProduct], *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -47,10 +47,14 @@ def __init__( self.descriptors = descriptors descriptors = [ - stp.SegmentedTensorProduct( - operands=[stp.Operand.empty_segments(1)] + d.operands, + cue.SegmentedTensorProduct( + operands_and_subscripts=[(cue.SegmentedOperand.empty_segments(1), "")] + + list(d.operands_and_subscripts), paths=[ - stp.Path((0,) + path.indices, path.coefficients) for path in d.paths + cue.segmented_polynomials.Path( + (0,) + path.indices, path.coefficients + ) + for path in d.paths ], coefficient_subscripts=d.coefficient_subscripts, ) @@ -99,7 +103,7 @@ class IWeightedSymmetricTensorProduct(torch.nn.Module): Parameters ---------- - descriptors : list[stp.SegmentedTensorProduct] + descriptors : list[cue.SegmentedTensorProduct] The list of SegmentedTensorProduct descriptors math_dtype : torch.dtype, optional The data type of the coefficients and calculations @@ -107,7 +111,7 @@ class IWeightedSymmetricTensorProduct(torch.nn.Module): def __init__( self, - descriptors: list[stp.SegmentedTensorProduct], + descriptors: list[cue.SegmentedTensorProduct], *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -197,7 +201,7 @@ def forward( return self.f(x0, i0, x1) -def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): +def _check_descriptors(descriptors: list[cue.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") @@ -218,7 +222,7 @@ def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): class CUDAKernel(torch.nn.Module): def __init__( self, - ds: list[stp.SegmentedTensorProduct], + ds: list[cue.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -239,14 +243,14 @@ def __init__( if len({d.operands[-1].num_segments for d in ds}) != 1: raise ValueError("All STPs must have the same number of segments in x2.") - def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: + def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: d = d.move_operand(0, -2) d = d.flatten_coefficient_modes(force=True) d = d.flatten_modes( [ m for m in d.subscripts.modes() - if not all(m in ope.subscripts for ope in d.operands) + if not all(m in ss for ss in d.subscripts.operands) ] ) d = d.consolidate_modes() @@ -261,7 +265,7 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: m = d.subscripts.modes()[0] - if not all(ope.subscripts == m for ope in d.operands): + if not all(ss == m for ss in d.subscripts.operands): raise NotImplementedError("Different subscripts are not supported.") d = d.split_mode(m, math.gcd(*d.get_dims(m))) @@ -336,7 +340,7 @@ def forward( class FallbackImpl(torch.nn.Module): def 
__init__( self, - stps: list[stp.SegmentedTensorProduct], + stps: list[cue.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: Optional[torch.dtype], ): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index daf668fb..38c33b13 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -21,7 +21,7 @@ import torch import torch.fx -from cuequivariance import segmented_tensor_product as stp +import cuequivariance as cue logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class TensorProduct(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -198,7 +198,7 @@ def disable_type_conv(t): def _tensor_product_fx( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, optimize_einsums: bool, @@ -223,9 +223,7 @@ def _tensor_product_fx( for i in range(num_inputs) ] - operand_subscripts = [ - f"Z{operand.subscripts}" for operand in descriptor.operands - ] + operand_subscripts = [f"Z{ss}" for ss in descriptor.subscripts.operands] formula = ( ",".join([descriptor.coefficient_subscripts] + operand_subscripts[:-1]) @@ -315,7 +313,7 @@ def _sum(tensors, *, shape=None, like=None): elif num_inputs == 0: class _no_input(torch.nn.Module): - def __init__(self, descriptor: stp.SegmentedTensorProduct): + def __init__(self, descriptor: cue.SegmentedTensorProduct): super().__init__() for pid, path in enumerate(descriptor.paths): @@ -410,7 +408,7 @@ def forward(self, args: List[torch.Tensor]): class _Wrapper(torch.nn.Module): - def __init__(self, module: torch.nn.Module, descriptor: stp.SegmentedTensorProduct): + def __init__(self, module: torch.nn.Module, descriptor: cue.SegmentedTensorProduct): super().__init__() self.module = CALL_DISPATCHERS[descriptor.num_operands - 1](module) self.descriptor = descriptor @@ -420,7 +418,7 @@ def forward(self, args: List[torch.Tensor]): def _tensor_product_cuda( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ) -> torch.nn.Module: @@ -465,7 +463,7 @@ def _tensor_product_cuda( return TensorProductUniform4x1d(d, device, math_dtype) supported_targets = [ - stp.Subscripts(subscripts) + cue.segmented_polynomials.Subscripts(subscripts) for subscripts in [ "u__uw_w", "_v_vw_w", @@ -483,7 +481,9 @@ def _tensor_product_cuda( try: descriptor, perm = next( - stp.dispatch(descriptor, supported_targets, "permute_all_but_last") + cue.segmented_polynomials.dispatch( + descriptor, supported_targets, "permute_all_but_last" + ) ) except StopIteration: raise NotImplementedError( @@ -507,7 +507,7 @@ def _permutation_module(permutation: Tuple[int, ...]): class FusedTensorProductOp3(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, perm: Tuple[int, int], device: Optional[torch.device], math_dtype: torch.dtype, @@ -527,7 +527,7 @@ def __init__( import cuequivariance_ops_torch as ops self._f = ops.FusedTensorProductOp3( - operand_segment_modes=[ope.subscripts for ope in descriptor.operands], + operand_segment_modes=descriptor.subscripts.operands, operand_segment_offsets=[ 
[s.start for s in ope.segment_slices()] for ope in descriptor.operands ], @@ -562,7 +562,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: class FusedTensorProductOp4(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, perm: Tuple[int, int, int], device: Optional[torch.device], math_dtype: torch.dtype, @@ -582,7 +582,7 @@ def __init__( import cuequivariance_ops_torch as ops self._f = ops.FusedTensorProductOp4( - operand_segment_modes=[ope.subscripts for ope in descriptor.operands], + operand_segment_modes=descriptor.subscripts.operands, operand_segment_offsets=[ [s.start for s in ope.segment_slices()] for ope in descriptor.operands ], @@ -620,7 +620,7 @@ def forward( class TensorProductUniform1d(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -694,7 +694,7 @@ def forward( class TensorProductUniform3x1dIndexed(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -751,7 +751,7 @@ def forward( class TensorProductUniform4x1dIndexed(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 3bc312d4..632e2092 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -145,7 +145,7 @@ def _transpose_segments_fx(segments: list[tuple[int, int]]) -> torch.nn.Module: x = torch.fx.Proxy(graph.placeholder("input"), tracer) outputs = [] - source = cue.segmented_tensor_product.Operand(subscripts="ij", segments=segments) + source = cue.SegmentedOperand(ndim=2, segments=segments) for sl, (u, v) in zip(source.segment_slices(), source.segments): outputs += [ x[..., sl] diff --git a/cuequivariance_torch/tests/operations/channel_wise_test.py b/cuequivariance_torch/tests/operations/channel_wise_test.py index 18bf586d..f88591aa 100644 --- a/cuequivariance_torch/tests/operations/channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/channel_wise_test.py @@ -69,7 +69,9 @@ def test_channel_wise_fwd( out1 = m1(x1, x2) - d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d + d = descriptors.channelwise_tensor_product( + irreps1, irreps2, irreps3 + ).polynomial.operations[0][1] d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: diff --git a/cuequivariance_torch/tests/operations/fully_connected_test.py b/cuequivariance_torch/tests/operations/fully_connected_test.py index 7d62d6a2..fad1de08 100644 --- a/cuequivariance_torch/tests/operations/fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/fully_connected_test.py @@ -70,7 +70,9 @@ def test_fully_connected( out1 = m1(x1, x2) - d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d + d = descriptors.fully_connected_tensor_product( + irreps1, irreps2, irreps3 + ).polynomial.operations[0][1] if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) 
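The two test fixes above use the same migration recipe as the operations code earlier in this diff: a descriptor's single STP is no longer reachable as e.d and instead lives in the operations list of its SegmentedPolynomial. A compact sketch of the new access pattern (irreps chosen for illustration; the subscripts string matches the assertion in tp_channel_wise.py):

import cuequivariance as cue

e = cue.descriptors.channelwise_tensor_product(
    cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1")
)
# before: d = e.d
ope, d = e.polynomial.operations[0]  # (cue.Operation, cue.SegmentedTensorProduct)
assert d.subscripts == "uv,iu,jv,kuv+ijk"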
diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 8a67026c..55b504bd 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -20,7 +20,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.experimental.e3nn import O3_e3nn +from cuequivariance.group_theory.experimental.e3nn import O3_e3nn from cuequivariance_torch._tests.utils import ( module_with_mode, ) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 99dec626..71f2b34d 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -227,4 +227,4 @@ def test_high_degrees(use_fallback: bool, batch_size: int): for rep in e.inputs ] output = m(*inputs) - assert output.shape == (batch_size, e.output.dim) + assert output.shape == (batch_size, e.outputs[0].dim) diff --git a/cuequivariance_torch/tests/primitives/primitive_export_test.py b/cuequivariance_torch/tests/primitives/primitive_export_test.py index f188434f..9fa32228 100644 --- a/cuequivariance_torch/tests/primitives/primitive_export_test.py +++ b/cuequivariance_torch/tests/primitives/primitive_export_test.py @@ -39,9 +39,10 @@ def test_script_symmetric_contraction(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") - ds = cue.descriptors.symmetric_contraction( + e = cue.descriptors.symmetric_contraction( 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] - ).ds + ) + ds = [stp for _, stp in e.polynomial.operations] batch = 12 x0 = torch.randn(3, ds[0].operands[0].size, device=device, dtype=torch.float32) @@ -65,7 +66,8 @@ def test_script_fused_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .d.flatten_coefficient_modes() + .polynomial.operations[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) @@ -97,7 +99,8 @@ def test_script_fused_tp_4(mode, tmp_path): cue.descriptors.fully_connected_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .d.flatten_coefficient_modes() + .polynomial.operations[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") .permute_operands([1, 2, 0, 3]) ) @@ -133,7 +136,8 @@ def test_script_uniform_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .d.flatten_coefficient_modes() + .polynomial.operations[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) @@ -166,7 +170,8 @@ def test_script_uniform_tp_4(mode, tmp_path): cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .d.flatten_coefficient_modes() + .polynomial.operations[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 37463f85..42af7b66 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -16,7 +16,6 @@ import torch import cuequivariance as cue -import 
cuequivariance.segmented_tensor_product as stp import cuequivariance_torch as cuet from cuequivariance import descriptors from cuequivariance_torch._tests.utils import ( @@ -27,14 +26,15 @@ def make_descriptors(): - yield descriptors.symmetric_contraction( + [(_, d1), (_, d2), (_, d3)] = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] - ).ds + ).polynomial.operations + yield [d1, d2, d3] - d1 = stp.SegmentedTensorProduct.from_subscripts(",,") + d1 = cue.SegmentedTensorProduct.from_subscripts(",,") d1.add_path(None, None, None, c=2.0) - d3 = stp.SegmentedTensorProduct.from_subscripts(",,,,") + d3 = cue.SegmentedTensorProduct.from_subscripts(",,,,") d3.add_path(None, None, None, None, None, c=3.0) yield [d1, d3] @@ -59,7 +59,7 @@ def make_descriptors(): @pytest.mark.parametrize("ds", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", settings1) def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx( - ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float, batch_size: int + ds: list[cue.SegmentedTensorProduct], dtype, math_dtype, tol: float, batch_size: int ): use_fallback = not torch.cuda.is_available() @@ -117,9 +117,10 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b if use_fallback is False and not torch.cuda.is_available(): pytest.skip("CUDA is not available") - ds = descriptors.symmetric_contraction( + e = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] - ).ds + ) + ds = [stp for _, stp in e.polynomial.operations] m = cuet.IWeightedSymmetricTensorProduct( ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) @@ -151,7 +152,7 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b @pytest.mark.parametrize("mode", export_modes) @pytest.mark.parametrize("use_fallback", [True, False]) def test_export( - ds: list[stp.SegmentedTensorProduct], + ds: list[cue.SegmentedTensorProduct], mode: str, use_fallback: bool, tmp_path, diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 6a19072a..3f9bc25f 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -31,23 +31,23 @@ def make_descriptors(): cue.Irreps("O3", "4x0e + 4x1o"), cue.Irreps("O3", "6x0e + 6x1o"), cue.Irreps("O3", "5x0e + 5x1o + 5x2e + 5x1e"), - ).d + ).polynomial.operations[0][1] - yield descriptors.spherical_harmonics(cue.SO3(1), [2]).d - yield descriptors.spherical_harmonics(cue.SO3(1), [3]).d + yield descriptors.spherical_harmonics(cue.SO3(1), [2]).polynomial.operations[0][1] + yield descriptors.spherical_harmonics(cue.SO3(1), [3]).polynomial.operations[0][1] d = descriptors.channelwise_tensor_product( cue.Irreps("SU2", "3x1/2 + 4x1"), cue.Irreps("SU2", "1/2 + 1 + 3/2"), cue.Irreps("SU2", "1/2 + 1"), - ).d + ).polynomial.operations[0][1] yield d d = descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1 + 32x2"), cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), - ).d + ).polynomial.operations[0][1] yield d for subscripts in [ diff --git a/docs/api/cuequivariance.descriptors.rst b/docs/api/cuequivariance.descriptors.rst index aed06c7c..aa90c7d8 100644 --- a/docs/api/cuequivariance.descriptors.rst +++ b/docs/api/cuequivariance.descriptors.rst @@ -13,8 +13,8 @@ See the License for the specific language governing permissions 
and limitations under the License. -.. module:: cuequivariance.descriptors -.. currentmodule:: cuequivariance.descriptors +.. module:: cuequivariance.group_theory.descriptors +.. currentmodule:: cuequivariance.group_theory.descriptors Descriptors =========== diff --git a/docs/api/cuequivariance.rst b/docs/api/cuequivariance.rst index 3cc00fc7..d96f8890 100644 --- a/docs/api/cuequivariance.rst +++ b/docs/api/cuequivariance.rst @@ -41,7 +41,7 @@ Group Representations Equivariant Tensor Products --------------------------- -These classes represent tensor products. +These classes represent tensor products and polynomials. .. autosummary:: :toctree: generated/ @@ -51,8 +51,9 @@ These classes represent tensor products. IrrepsLayout IrrepsAndLayout SegmentedTensorProduct - EquivariantTensorProduct Operation + SegmentedPolynomial + EquivariantPolynomial Descriptors ----------- diff --git a/docs/api/cuequivariance_jax.rst b/docs/api/cuequivariance_jax.rst index 5264330f..791fdb1a 100644 --- a/docs/api/cuequivariance_jax.rst +++ b/docs/api/cuequivariance_jax.rst @@ -44,8 +44,8 @@ Tensor Products :toctree: generated/ :template: function_template.rst - equivariant_tensor_product - tensor_product + equivariant_polynomial + segmented_polynomial Extra Modules ------------- diff --git a/docs/index.rst b/docs/index.rst index 6468a5b4..15e5c5dc 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,13 +40,15 @@ The easiest way to install cuEquivariance is from `PyPi `_ us pip install cuequivariance # Installs only the core non-ML components # CUDA kernels for different CUDA versions + pip install cuequivariance-ops-jax-cu12 pip install cuequivariance-ops-torch-cu11 pip install cuequivariance-ops-torch-cu12 Requirements ------------ -``cuequivariance-ops-torch-*`` packages are only available for Linux x86_64 and require PyTorch 2.4.0 or later. + - ``cuequivariance-ops-torch-*`` packages are only available for Linux x86_64 and require PyTorch 2.4.0 or later. + - ``cuequivariance-ops-jax-cu12`` package is only available for Linux x86_64 and requires JAX 0.5.0 or later. Organization ------------ @@ -70,7 +72,7 @@ cuEquivariance is split into three packages: :align: center Most tensor products are defined using the :class:`cue.EquivariantTensorProduct ` class, which encapsulates the :class:`cue.Irreps ` and :class:`cue.IrrepsLayout ` for each input and the output. It also includes one or more instances of :class:`cue.SegmentedTensorProduct `, which define the tensor product operations. -This descriptor is then used to create a :class:`cuet.EquivariantTensorProduct ` module, which can be used in PyTorch models. Or used to execute the tensor product operations using :class:`cuex.equivariant_tensor_product ` in JAX. +This descriptor is then used to create a :class:`cuet.EquivariantTensorProduct ` module, which can be used in PyTorch models, or used to execute the tensor product operations using :class:`cuex.equivariant_polynomial ` in JAX.
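To make the descriptor-to-module flow described above concrete, a minimal sketch on the PyTorch side, assuming `cuet.EquivariantTensorProduct` accepts the descriptor together with a layout (as in the beta tutorial below); the irreps and batch size are illustrative:

    import torch
    import cuequivariance as cue
    import cuequivariance_torch as cuet
    from cuequivariance import descriptors

    # Build a descriptor, then wrap it in a PyTorch module.
    e = descriptors.full_tensor_product(
        cue.Irreps("SO3", "32x1"),  # illustrative irreps
        cue.Irreps("SO3", "1"),
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device)

    x0 = torch.randn(128, e.inputs[0].dim, device=device)
    x1 = torch.randn(128, e.inputs[1].dim, device=device)
    out = m(x0, x1)  # expected shape: (128, e.outputs[0].dim)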
Tutorials --------- diff --git a/docs/tutorials/beta.rst b/docs/tutorials/beta.rst index e82d76aa..5dc5dcd2 100644 --- a/docs/tutorials/beta.rst +++ b/docs/tutorials/beta.rst @@ -42,7 +42,7 @@ The segmented tensor product with 3 or 4 operands with one mode can be executed .squeeze_modes() .flatten_coefficient_modes() ) - print(e.ds[0]) + print(e) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) @@ -63,7 +63,8 @@ Again for segmented tensor product with 3 or 4 operands with one mode, we can us ) if device.type == "cuda": - m = TensorProductUniform4x1dIndexed(e.ds[0], device, torch.float32) + ((_, d),) = e.polynomial.operations + m = TensorProductUniform4x1dIndexed(d, device, torch.float32) x0 = torch.randn(16, e.inputs[0].dim, device=device) i0 = torch.randint(0, 16, (128,), device=device) diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst index 08c15efb..0ef852a4 100644 --- a/docs/tutorials/etp.rst +++ b/docs/tutorials/etp.rst @@ -69,10 +69,10 @@ Execution on JAX w = cuex.randn(jax.random.key(0), e.inputs[0]) x = cuex.randn(jax.random.key(1), e.inputs[1]) - cuex.equivariant_tensor_product(e, w, x) + cuex.equivariant_polynomial(e, [w, x]) The function :func:`cuex.randn ` generates random :class:`cuex.RepArray ` objects. -The function :func:`cuex.equivariant_tensor_product ` executes the tensor product. +The function :func:`cuex.equivariant_polynomial ` executes the tensor product. The output is a :class:`cuex.RepArray ` object. diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst index 65d63df4..28e35a22 100644 --- a/docs/tutorials/stp.rst +++ b/docs/tutorials/stp.rst @@ -175,15 +175,15 @@ Now we are verifying that the output is well normalized. -In JAX, we can use the :func:`cuex.tensor_product ` function. +In JAX, we can use the :func:`cuex.segmented_polynomial ` function. .. jupyter-execute:: w = jax.random.normal(jax.random.key(0), (d.operands[0].size,)) x1 = jax.random.normal(jax.random.key(1), (3000, irreps1.dim)) - [x2] = cuex.tensor_product( - [(cue.Operation([0, 1, 2]), d)], + [x2] = cuex.segmented_polynomial( + cue.SegmentedPolynomial.eval_last_operand(d), [w, x1], [jax.ShapeDtypeStruct((3000, irreps2.dim), jnp.float32)], )
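Finally, an end-to-end sketch of the renamed JAX entry point exercised in these tutorials, assuming `cuex.equivariant_polynomial` takes the descriptor and a list of inputs and returns a `cuex.RepArray` (as in the etp tutorial above); the choice of descriptor is illustrative:

    import jax
    import cuequivariance as cue
    import cuequivariance_jax as cuex
    from cuequivariance import descriptors

    # Degree-2 spherical harmonics evaluated through the renamed entry point.
    e = descriptors.spherical_harmonics(cue.SO3(1), [2])

    x = cuex.randn(jax.random.key(0), e.inputs[0])  # random RepArray input
    y = cuex.equivariant_polynomial(e, [x])         # returns a RepArray
    print(y)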