From b55b22538179bd96baa0d624ef9c546e083bec96 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 03:26:54 -0800 Subject: [PATCH 001/107] add SegmentedPolynomial --- cuequivariance/cuequivariance/__init__.py | 11 ++- cuequivariance/cuequivariance/operation.py | 2 + .../cuequivariance/segmented_polynomial.py | 72 +++++++++++++++++++ .../segmented_tensor_product/operand.py | 3 + .../segmented_tensor_product/path.py | 7 ++ .../segmented_tensor_product.py | 26 +++++++ 6 files changed, 115 insertions(+), 6 deletions(-) create mode 100644 cuequivariance/cuequivariance/segmented_polynomial.py diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index 812ad350..724800de 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -47,14 +47,12 @@ reduced_antisymmetric_tensor_product_basis, ) +from cuequivariance.operation import Operation from cuequivariance.segmented_tensor_product import SegmentedTensorProduct +from cuequivariance.segmented_polynomial import SegmentedPolynomial from cuequivariance.equivariant_tensor_product import EquivariantTensorProduct -from cuequivariance.operation import Operation -from cuequivariance import ( - segmented_tensor_product, - descriptors, -) +from cuequivariance import segmented_tensor_product, descriptors __all__ = [ "Rep", @@ -80,9 +78,10 @@ "reduced_tensor_product_basis", "reduced_symmetric_tensor_product_basis", "reduced_antisymmetric_tensor_product_basis", + "Operation", "SegmentedTensorProduct", + "SegmentedPolynomial", "EquivariantTensorProduct", - "Operation", "segmented_tensor_product", "descriptors", ] diff --git a/cuequivariance/cuequivariance/operation.py b/cuequivariance/cuequivariance/operation.py index 7e6ea98f..135dc7a4 100644 --- a/cuequivariance/cuequivariance/operation.py +++ b/cuequivariance/cuequivariance/operation.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import dataclasses import itertools from collections import defaultdict @@ -21,6 +22,7 @@ OVARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +@dataclasses.dataclass(init=False, frozen=True) class Operation: """Descriptor mapping input/output buffers to tensor product operands. diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py new file mode 100644 index 00000000..1ec6a307 --- /dev/null +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
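+#
+# A SegmentedPolynomial bundles pairs of (Operation, SegmentedTensorProduct):
+# the Operation maps the polynomial's input/output buffers onto the operands of
+# its SegmentedTensorProduct. A minimal construction sketch, assuming some
+# existing three-operand descriptor ``stp`` (the names below are illustrative):
+#
+#     poly = SegmentedPolynomial(
+#         num_inputs=2,
+#         num_outputs=1,
+#         tensor_products=[(cue.Operation((0, 1, 2)), stp)],
+#     )
+#
+# Here buffers 0 and 1 are inputs and buffer 2 is the output of the operation.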
+from __future__ import annotations + +import dataclasses +from typing import Sequence + +import cuequivariance as cue +from cuequivariance.operation import IVARS, OVARS + + +@dataclasses.dataclass(init=False, frozen=True) +class SegmentedPolynomial: + num_inputs: int + num_outputs: int + tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] + + def __init__( + self, + num_inputs: int, + num_outputs: int, + tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], + ): + object.__setattr__(self, "num_inputs", num_inputs) + object.__setattr__(self, "num_outputs", num_outputs) + object.__setattr__(self, "tensor_products", sorted(tensor_products)) + + def __hash__(self) -> int: + return hash((self.num_inputs, self.num_outputs, tuple(self.tensor_products))) + + def __eq__(self, value): + assert isinstance(value, SegmentedPolynomial) + return ( + self.num_inputs == value.num_inputs + and self.num_outputs == value.num_outputs + and self.tensor_products == value.tensor_products + ) + + def __lt__(self, value): + assert isinstance(value, SegmentedPolynomial) + return ( + self.num_inputs, + self.num_outputs, + self.tensor_products, + ) < ( + value.num_inputs, + value.num_outputs, + value.tensor_products, + ) + + def __repr__(self): + text = "" + text += " ".join(IVARS[self.num_inputs]) + text += " -> " + text += " ".join(OVARS[self.num_inputs : self.num_inputs + self.num_outputs]) + tab = "\n " + for ope, stp in self.tensor_products: + text += tab + f"{stp}" + text += tab + ope.to_string(self.num_inputs) + return text diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py b/cuequivariance/cuequivariance/segmented_tensor_product/operand.py index 67e21f1a..0c6627fe 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/operand.py @@ -105,6 +105,9 @@ def __hash__(self) -> int: def __eq__(self, other: Operand) -> bool: return self.subscripts == other.subscripts and self.segments == other.segments + def __lt__(self, other: Operand) -> bool: + return (self.subscripts, self.segments) < (other.subscripts, other.segments) + def __repr__(self) -> str: dims = format_dimensions_dict(self.get_dimensions_dict()) return f"Operand(subscripts={self.subscripts} num_segments={self.num_segments} {dims})" diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/path.py b/cuequivariance/cuequivariance/segmented_tensor_product/path.py index be980265..8af08de5 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/path.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/path.py @@ -75,6 +75,13 @@ def __eq__(self, other: Path) -> bool: self.coefficients, other.coefficients ) + def __lt__(self, other: Path) -> bool: + k1 = (self.indices, self.coefficients.shape) + k2 = (other.indices, other.coefficients.shape) + if k1 != k2: + return k1 < k2 + return tuple(self.coefficients.flatten()) < tuple(other.coefficients.flatten()) + def permute_operands(self, perm: tuple[int, ...]) -> Path: """ Apply a permutation to the operands. 
diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 2eb0af2c..8c73b66f 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -182,6 +182,32 @@ def __hash__(self) -> int: (tuple(self.operands), tuple(self.paths), self.coefficient_subscripts) ) + def __eq__(self, value: SegmentedTensorProduct) -> bool: + assert isinstance(value, SegmentedTensorProduct) + return ( + self.operands == value.operands + and self.paths == value.paths + and self.coefficient_subscripts == value.coefficient_subscripts + ) + + def __lt__(self, value: SegmentedTensorProduct) -> bool: + assert isinstance(value, SegmentedTensorProduct) + return ( + self.num_operands, + self.num_paths, + self.subscripts, + self.operands, + self.paths, + self.coefficient_subscripts, + ) < ( + value.num_operands, + value.num_paths, + value.subscripts, + value.operands, + value.paths, + value.coefficient_subscripts, + ) + def __repr__(self) -> str: if max(len(operand) for operand in self.operands) == 1 and len(self.paths) == 1: operands = ",".join( From beba10217ef202bc9b160aed176791c7d8cfc7a9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 04:36:46 -0800 Subject: [PATCH 002/107] print --- cuequivariance/cuequivariance/operation.py | 4 ++-- .../cuequivariance/segmented_polynomial.py | 24 ++++++++++++------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/cuequivariance/cuequivariance/operation.py b/cuequivariance/cuequivariance/operation.py index 135dc7a4..8e90cccd 100644 --- a/cuequivariance/cuequivariance/operation.py +++ b/cuequivariance/cuequivariance/operation.py @@ -51,7 +51,7 @@ def __init__(self, buffers: tuple[int, ...]): assert len(buffers) > 0, buffers assert all(isinstance(b, int) for b in buffers), buffers assert all(i >= 0 for i in buffers), buffers - self.buffers = tuple(int(b) for b in buffers) + object.__setattr__(self, "buffers", tuple(int(b) for b in buffers)) def __repr__(self): return f"Operation({self.buffers})" @@ -72,7 +72,7 @@ def list_to_string( def __lt__(self, value): assert isinstance(value, Operation) - return self.buffers < value.buffers + return (len(self.buffers), self.buffers) < (len(value.buffers), value.buffers) def __hash__(self) -> int: return hash(self.buffers) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 1ec6a307..25dca40e 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -61,12 +61,20 @@ def __lt__(self, value): ) def __repr__(self): - text = "" - text += " ".join(IVARS[self.num_inputs]) - text += " -> " - text += " ".join(OVARS[self.num_inputs : self.num_inputs + self.num_outputs]) - tab = "\n " - for ope, stp in self.tensor_products: - text += tab + f"{stp}" - text += tab + ope.to_string(self.num_inputs) + header = ( + " ".join(IVARS[: self.num_inputs]) + + " -> " + + " ".join(OVARS[self.num_inputs : self.num_inputs + self.num_outputs]) + + " " + ) + ope_txts = [ + " " + ope.to_string(self.num_inputs) for ope, _ in self.tensor_products + ] + n = max(len(ope_txt) for ope_txt in ope_txts) + n = max(len(header), n) + + text = header + "═" * (n - len(header)) + "═╗" + + for ope_txt, (_, stp) in zip(ope_txts, self.tensor_products): + text += 
"\n" + ope_txt + " " * (n - len(ope_txt)) + " ║ " + str(stp) return text From 6e080e85611f15c9c9f22c45a8af99aeac633030 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 05:27:44 -0800 Subject: [PATCH 003/107] addind cuex.segmented_polynomial --- .../cuequivariance/segmented_polynomial.py | 67 ++++++++++++++++++- .../cuequivariance_jax/__init__.py | 2 + ...pl.py => segmented_polynomial_ops_impl.py} | 19 +++--- ...y => segmented_polynomial_vanilla_impl.py} | 6 +- 4 files changed, 82 insertions(+), 12 deletions(-) rename cuequivariance_jax/cuequivariance_jax/primitives/{tensor_product_ops_impl.py => segmented_polynomial_ops_impl.py} (89%) rename cuequivariance_jax/cuequivariance_jax/primitives/{tensor_product_vanilla_impl.py => segmented_polynomial_vanilla_impl.py} (98%) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 25dca40e..ecfc7e3d 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -15,7 +15,7 @@ from __future__ import annotations import dataclasses -from typing import Sequence +from typing import Callable, Sequence import cuequivariance as cue from cuequivariance.operation import IVARS, OVARS @@ -78,3 +78,68 @@ def __repr__(self): for ope_txt, (_, stp) in zip(ope_txts, self.tensor_products): text += "\n" + ope_txt + " " * (n - len(ope_txt)) + " ║ " + str(stp) return text + + @property + def buffer_sizes(self) -> list[int | None]: + sizes = [None] * (self.num_inputs + self.num_outputs) + for ope, stp in self.tensor_products: + for buffer, operand in zip(ope.buffers, stp.operands): + if sizes[buffer] is None: + sizes[buffer] = operand.size + if sizes[buffer] != operand.size: + raise ValueError( + f"Buffer {buffer} has inconsistent sizes: {sizes[buffer]} vs {operand.size}" + ) + return sizes + + @property + def input_sizes(self) -> list[int | None]: + return self.buffer_sizes[: self.num_inputs] + + @property + def output_sizes(self) -> list[int | None]: + return self.buffer_sizes[self.num_inputs :] + + def map_tensor_products( + self, + f: Callable[ + [cue.Operation, cue.SegmentedTensorProduct], + tuple[cue.Operation, cue.SegmentedTensorProduct] | None, + ], + ) -> SegmentedPolynomial: + new_tensor_products = [f(ope, stp) for ope, stp in self.tensor_products] + new_tensor_products = [ + ope_stp for ope_stp in new_tensor_products if ope_stp is not None + ] + return SegmentedPolynomial( + self.num_inputs, self.num_outputs, new_tensor_products + ) + + def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: + new_tps = [] + for ope, stp in self.tensor_products: + jvps = ope.jvp(has_tangent) + permutations: list[tuple[int, ...]] = stp.symmetries() + for multiplicator, ope in cue.Operation.group_by_operational_symmetries( + permutations, jvps + ): + new_tps.append((ope, multiplicator * stp)) + return SegmentedPolynomial( + self.num_inputs + sum(has_tangent), self.num_outputs, new_tps + ) + + def transpose( + self, + is_undefined_primal: list[bool], + has_cotangent: list[bool], + ) -> SegmentedPolynomial: + new_tps = [] + for ope, stp in self.tensor_products: + ope = ope.transpose(is_undefined_primal, has_cotangent) + if ope is not None: + new_tps.append((ope, stp)) + return SegmentedPolynomial( + sum(map(lambda u: not u, is_undefined_primal)) + sum(has_cotangent), + sum(is_undefined_primal), + new_tps, + ) diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py 
b/cuequivariance_jax/cuequivariance_jax/__init__.py index ce52d9dd..5126a150 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -23,6 +23,7 @@ from .rep_array.vmap import vmap from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan +from .primitives.segmented_polynomial import segmented_polynomial from .primitives.tensor_product import tensor_product from .primitives.equivariant_tensor_product import equivariant_tensor_product @@ -45,6 +46,7 @@ "randn", "as_irreps_array", "clebsch_gordan", + "segmented_polynomial", "tensor_product", "equivariant_tensor_product", "normalspace", diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py similarity index 89% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py rename to cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py index 2d6cd979..35e2cc3a 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py @@ -28,12 +28,12 @@ def sanitize_string(s): return re.sub(r"[^A-Za-z_]", "", s) -def tensor_product_ops_impl( +def segmented_polynomial_ops_impl( inputs: list[jax.Array], # shape (batch_size, operand_size) outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], indices: list[jax.Array], buffer_index: list[int], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, ) -> tuple[list[jax.Array] | None, str]: @@ -41,14 +41,15 @@ def log(msg: str): logger.info(f"[{name}] {msg}") return None, name - num_inputs = len(buffer_index) - len(outputs_shape_dtype) + assert polynomial.num_inputs == len(buffer_index) - len(outputs_shape_dtype) + assert polynomial.num_outputs == len(outputs_shape_dtype) buffers = list(inputs) + list(outputs_shape_dtype) for b in buffers: assert b.ndim == 2, f"Buffer {b.shape} must be 2D" # Reshape buffers to 3D by using the STP informations - for ope, stp in descriptors: + for ope, stp in polynomial.tensor_products: if len(stp.subscripts.modes()) != 1: return log(f"Unsupported STP: {stp}") if not stp.all_same_segment_shape(): @@ -91,7 +92,9 @@ def log(msg: str): batch_size = b.shape[0] # TODO: remove if the backend supports atomic operations for float16/bfloat16 - for i, b in zip(buffer_index[num_inputs:], buffers[num_inputs:]): + for i, b in zip( + buffer_index[polynomial.num_inputs :], buffers[polynomial.num_inputs :] + ): if b.dtype.type not in {jnp.float32, jnp.float64}: if i >= 0 or b.shape[0] != batch_size: return log( @@ -109,15 +112,15 @@ def log(msg: str): operations = [] paths = [] - for ope, stp in descriptors: + for ope, stp in polynomial.tensor_products: operations.append(Operation(ope.buffers, len(paths), stp.num_paths)) for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) log("Using the uniform 1d kernel of cuequivariance_ops_jax 🚀") outputs = tensor_product_uniform_1d_jit( - buffers[:num_inputs], - buffers[num_inputs:], + buffers[: polynomial.num_inputs], + buffers[polynomial.num_inputs :], indices, buffer_index, operations=operations, diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py 
b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py similarity index 98% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py rename to cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py index ed14d515..4eba9e71 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product_vanilla_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py @@ -26,12 +26,12 @@ logger = logging.getLogger(__name__) -def tensor_product_vanilla_impl( +def segmented_polynomial_vanilla_impl( inputs: list[jax.Array], # shape (batch_size, operand_size) outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], indices: list[jax.Array], buffer_index: list[int], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, ) -> tuple[jax.Array, ...]: # output buffers @@ -52,7 +52,7 @@ def gather(i: int, x: jax.Array) -> jax.Array: return buffer.at[idx].add(x) return buffer + x - for operation, d in descriptors: + for operation, d in polynomial.tensor_products: ope_out, b_out = operation.output_operand_buffer(num_inputs) out = outputs_shape_dtype[b_out - num_inputs] From dcda13f4e7f75a7008622ea3a0feb58de2f1a7ec Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 05:29:40 -0800 Subject: [PATCH 004/107] rename --- .../primitives/{tensor_product.py => segmented_polynomial.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename cuequivariance_jax/cuequivariance_jax/primitives/{tensor_product.py => segmented_polynomial.py} (100%) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py rename to cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py From 5548eeb77dbafccb9b6626fb6a62b4b0bf86101c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 05:30:29 -0800 Subject: [PATCH 005/107] add cuex.segmented_polynomial --- .../primitives/segmented_polynomial.py | 158 +++++++++--------- .../primitives/tensor_product.py | 77 +++++++++ 2 files changed, 153 insertions(+), 82 deletions(-) create mode 100644 cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index d0d770c2..6bbe6b7a 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -24,18 +24,18 @@ import cuequivariance as cue from cuequivariance_jax.primitives.primitives_utils import reshape -from cuequivariance_jax.primitives.tensor_product_ops_impl import ( - tensor_product_ops_impl, +from cuequivariance_jax.primitives.segmented_polynomial_ops_impl import ( + segmented_polynomial_ops_impl, ) -from cuequivariance_jax.primitives.tensor_product_vanilla_impl import ( - tensor_product_vanilla_impl, +from cuequivariance_jax.primitives.segmented_polynomial_vanilla_impl import ( + segmented_polynomial_vanilla_impl, ) logger = logging.getLogger(__name__) -def tensor_product( - descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]], +def segmented_polynomial( + polynomial: 
cue.SegmentedPolynomial, inputs: list[jax.Array], outputs_shape_dtype: list[jax.ShapeDtypeStruct], indices: list[jax.Array | None] | None = None, @@ -44,7 +44,7 @@ def tensor_product( name: str | None = None, impl: str = "auto", ) -> list[jax.Array]: - r"""Compute a polynomial described by a list of descriptors. + r"""Compute a segmented polynomial. Features: - Calls a CUDA kernel if: @@ -58,8 +58,7 @@ def tensor_product( - Automatic drop of unused buffers and indices Args: - descriptors (list of pairs): The list of descriptors. - Each descriptor is formed by a pair of :class:`cue.Operation ` and :class:`cue.SegmentedTensorProduct `. + polynomial (cue.SegmentedPolynomial): The segmented polynomial to compute. inputs (list of jax.Array): The input buffers. outputs_shape_dtype (list of jax.ShapeDtypeStruct): The output shapes and dtypes. indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs. @@ -75,8 +74,10 @@ def tensor_product( """ if name is None: - name = "tensor_product" + name = "segmented_polynomial" + assert len(inputs) == polynomial.num_inputs + assert len(outputs_shape_dtype) == polynomial.num_outputs buffers = inputs + outputs_shape_dtype if indices is None: @@ -137,15 +138,15 @@ def fn( outputs_shape_dtype=buffers[len(inputs) :], indices=unique_indices, buffer_index=buffer_index, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name, ) if impl == "naive_jax": - outputs = tensor_product_vanilla_impl(**kwargs) + outputs = segmented_polynomial_vanilla_impl(**kwargs) else: - outputs = tensor_product_prim(**kwargs, impl=impl) + outputs = segmented_polynomial_prim(**kwargs, impl=impl) def fn(x: jax.Array, shape: tuple[int, ...]) -> jax.Array: return jnp.reshape(x, shape) @@ -153,16 +154,16 @@ def fn(x: jax.Array, shape: tuple[int, ...]) -> jax.Array: return list(map(fn, outputs, [out.shape for out in outputs_shape_dtype])) -tensor_product_p = jax.extend.core.Primitive("tensor_product") -tensor_product_p.multiple_results = True +segmented_polynomial_p = jax.extend.core.Primitive("segmented_polynomial") +segmented_polynomial_p.multiple_results = True -def tensor_product_prim( +def segmented_polynomial_prim( inputs: list[jax.Array], # input buffers outputs_shape_dtype: list[jax.ShapeDtypeStruct], # output shapes and dtypes indices: list[jax.Array], # index buffers buffer_index: list[int], # maps: buffer index -> unique indices index - descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str = "auto", @@ -176,18 +177,21 @@ def tensor_product_prim( assert len(inputs) + len(outputs_shape_dtype) == len(buffer_index) assert max(buffer_index) < len(indices) - descriptors = map( - lambda x: ( - x[0], - x[1].consolidate_modes().remove_empty_segments().consolidate_paths(), - ), - descriptors, - ) - descriptors = list(filter(lambda x: x[1].num_paths > 0, descriptors)) + outputs_shape_dtype = [ + jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_shape_dtype + ] + + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): + stp = stp.consolidate_modes().remove_empty_segments().consolidate_paths() + if stp.num_paths == 0: + return None + return ope, stp + + polynomial = polynomial.map_tensor_products(f) used_buffers = set() used_indices = set() - for ope, _ in descriptors: + for ope, _ in polynomial.tensor_products: for i in ope.buffers: used_buffers.add(i) if buffer_index[i] >= 0: @@ -196,8 +200,16 @@ def tensor_product_prim( 
used_indices = sorted(used_indices) # maps: new index -> old index new_num_inputs = sum([i < len(inputs) for i in used_buffers]) + new_poynomial = cue.SegmentedPolynomial( + new_num_inputs, + len(used_buffers) - new_num_inputs, + [ + (cue.Operation([used_buffers.index(i) for i in ope.buffers]), stp) + for ope, stp in polynomial.tensor_products + ], + ) - new_outputs = tensor_product_p.bind( + new_outputs = segmented_polynomial_p.bind( *[inputs[i] for i in used_buffers[:new_num_inputs]], *[indices[i] for i in used_indices], buffer_index=tuple( @@ -207,12 +219,7 @@ def tensor_product_prim( outputs_shape_dtype=tuple( outputs_shape_dtype[i - len(inputs)] for i in used_buffers[new_num_inputs:] ), - descriptors=frozenset( - [ - (cue.Operation([used_buffers.index(i) for i in ope.buffers]), stp) - for ope, stp in descriptors - ] - ), + polynomial=new_poynomial, math_dtype=jnp.dtype(math_dtype), name=str(name), impl=impl, @@ -252,11 +259,11 @@ def map_indices( return new_indices, new_buffer_index -def tensor_product_abstract_eval( +def segmented_polynomial_abstract_eval( *inputs_and_indices: jax.core.ShapedArray, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -266,12 +273,12 @@ def tensor_product_abstract_eval( ) -def tensor_product_impl( +def segmented_polynomial_impl( platform: str | None, *inputs_and_indices: jax.Array, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -286,7 +293,7 @@ def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): stp = stp.sort_paths() return ope, stp - descriptors = list(map(optimize_paths, *zip(*descriptors))) + polynomial = polynomial.map_tensor_products(optimize_paths) outputs = None kwargs = dict( @@ -294,7 +301,7 @@ def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): outputs_shape_dtype=outputs_shape_dtype, indices=indices, buffer_index=buffer_index, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name, ) @@ -302,7 +309,7 @@ def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): assert impl in ("auto", "cuda", "jax") if platform == "cuda" and impl in ("auto", "cuda"): - outputs, msg = tensor_product_ops_impl(**kwargs) + outputs, msg = segmented_polynomial_ops_impl(**kwargs) else: msg = f"{platform=}, {impl=}" @@ -310,19 +317,19 @@ def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): raise RuntimeError(f"Failed to use CUDA implementation: {msg}") if outputs is None: - outputs = tensor_product_vanilla_impl(**kwargs) + outputs = segmented_polynomial_vanilla_impl(**kwargs) assert outputs is not None return outputs -def tensor_product_jvp( +def segmented_polynomial_jvp( primals_and_indices: tuple[jax.Array, ...], tangents_and_zeros: tuple[jax.Array | ad.Zero, ...], *, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -337,12 +344,12 @@ def tensor_product_jvp( assert all(isinstance(t, ad.Zero) for t in tangents_and_zeros[num_inputs:]) del primals_and_indices, 
tangents_and_zeros - out_primals = tensor_product_prim( + out_primals = segmented_polynomial_prim( primals, outputs_shape_dtype, indices, buffer_index, - descriptors, + polynomial, math_dtype, name, impl=impl, @@ -356,21 +363,12 @@ def tensor_product_jvp( + [num_inputs + i for i, x in enumerate(outputs_shape_dtype)], ) - jvp_descriptors = [] - for ope, stp in descriptors: - jvps = ope.jvp([not isinstance(t, ad.Zero) for t in tangents]) - permutations: list[tuple[int, ...]] = stp.symmetries() - for multiplicator, ope in cue.Operation.group_by_operational_symmetries( - permutations, jvps - ): - jvp_descriptors.append((ope, multiplicator * stp)) - - out_tangents = tensor_product_prim( + out_tangents = segmented_polynomial_prim( list(primals) + [t for t in tangents if not isinstance(t, ad.Zero)], outputs_shape_dtype, jvp_indices, jvp_buffer_index, - jvp_descriptors, + polynomial.jvp([not isinstance(t, ad.Zero) for t in tangents]), math_dtype, name + "_jvp", impl=impl, @@ -379,12 +377,12 @@ def tensor_product_jvp( return out_primals, out_tangents -def tensor_product_transpose( +def segmented_polynomial_transpose( cotangents: tuple[jax.Array | ad.Zero, ...], *inputs_and_indices: jax.Array | ad.UndefinedPrimal, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -409,16 +407,7 @@ def tensor_product_transpose( + [i for i, x in enumerate(inputs) if ad.is_undefined_primal(x)], ) - tr_descriptors = [] - for ope, stp in descriptors: - ope = ope.transpose( - [ad.is_undefined_primal(x) for x in inputs], - [not isinstance(x, ad.Zero) for x in cotangents], - ) - if ope is not None: - tr_descriptors.append((ope, stp)) - - tmp = tensor_product_prim( + tmp = segmented_polynomial_prim( [x for x in inputs if not ad.is_undefined_primal(x)] + [x for x in cotangents if not isinstance(x, ad.Zero)], # inputs [ @@ -428,7 +417,10 @@ def tensor_product_transpose( ], tr_indices, tr_buffer_index, - tr_descriptors, + polynomial.transpose( + [ad.is_undefined_primal(x) for x in inputs], + [not isinstance(x, ad.Zero) for x in cotangents], + ), math_dtype, name + "_transpose", impl=impl, @@ -444,13 +436,13 @@ def tensor_product_transpose( return tuple(outputs) -def tensor_product_batching( +def segmented_polynomial_batching( batched_inputs_and_indices: tuple[jax.Array, ...], batch_axes_of_inputs_and_indices: tuple[int | None, ...], *, buffer_index: tuple[int, ...], outputs_shape_dtype: tuple[jax.ShapeDtypeStruct, ...], - descriptors: frozenset[tuple[cue.Operation, cue.SegmentedTensorProduct]], + polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, impl: str, @@ -529,12 +521,12 @@ def flatten_index(x: jax.Array) -> jax.Array: for out in outputs_shape_dtype ) - outputs = tensor_product_p.bind( + outputs = segmented_polynomial_p.bind( *batched_inputs, *batched_indices, buffer_index=buffer_index, outputs_shape_dtype=new_outputs_shape_dtype, - descriptors=descriptors, + polynomial=polynomial, math_dtype=math_dtype, name=name + "_batching", impl=impl, @@ -549,22 +541,24 @@ def flatten_index(x: jax.Array) -> jax.Array: return outputs, (0,) * len(outputs) -tensor_product_p.def_abstract_eval(tensor_product_abstract_eval) -tensor_product_p.def_impl(partial(xla.apply_primitive, tensor_product_p)) +segmented_polynomial_p.def_abstract_eval(segmented_polynomial_abstract_eval) 
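+# Primitive registration: abstract evaluation, default implementation, MLIR
+# lowerings (CUDA and generic), and the JVP/transpose/batching rules defined
+# above.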
+segmented_polynomial_p.def_impl(partial(xla.apply_primitive, segmented_polynomial_p))
 mlir.register_lowering(
-    tensor_product_p,
+    segmented_polynomial_p,
     mlir.lower_fun(
-        partial(tensor_product_impl, "cuda"), tensor_product_p.multiple_results
+        partial(segmented_polynomial_impl, "cuda"),
+        segmented_polynomial_p.multiple_results,
     ),
     "cuda",
 )
 mlir.register_lowering(
-    tensor_product_p,
+    segmented_polynomial_p,
     mlir.lower_fun(
-        partial(tensor_product_impl, None), tensor_product_p.multiple_results
+        partial(segmented_polynomial_impl, None),
+        segmented_polynomial_p.multiple_results,
     ),
     None,
 )
-ad.primitive_jvps[tensor_product_p] = tensor_product_jvp
-ad.primitive_transposes[tensor_product_p] = tensor_product_transpose
-batching.primitive_batchers[tensor_product_p] = tensor_product_batching
+ad.primitive_jvps[segmented_polynomial_p] = segmented_polynomial_jvp
+ad.primitive_transposes[segmented_polynomial_p] = segmented_polynomial_transpose
+batching.primitive_batchers[segmented_polynomial_p] = segmented_polynomial_batching
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
new file mode 100644
index 00000000..78c7395e
--- /dev/null
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import jax
+import jax.numpy as jnp
+
+import cuequivariance as cue
+import cuequivariance_jax as cuex
+
+logger = logging.getLogger(__name__)
+
+
+def tensor_product(
+    descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]],
+    inputs: list[jax.Array],
+    outputs_shape_dtype: list[jax.ShapeDtypeStruct],
+    indices: list[jax.Array | None] | None = None,
+    *,
+    math_dtype: jnp.dtype | None = None,
+    name: str | None = None,
+    impl: str = "auto",
+) -> list[jax.Array]:
+    r"""Compute a polynomial described by a list of descriptors.
+
+    Features:
+    - Calls a CUDA kernel if:
+        - STPs have a single mode which is a multiple of 32 (e.g. a channelwise tensor product that has subscripts ``u,u,,u`` with u=128)
+        - math data type is float32 or float64
+        - in/out data type is a mix of float32, float64, float16 and bfloat16
+        - indices are int32
+    - Supports infinite derivatives (the JVP and transpose rules map to a single corresponding primitive)
+    - Limited support for batching (we cannot batch a buffer that has indices, and if the batching is non-trivial the performance will be bad)
+    - Automatic optimizations based on the symmetries of the STPs and on the repetition of the input buffers
+    - Automatic drop of unused buffers and indices
+
+    Args:
+        descriptors (list of pairs): The list of descriptors.
+            Each descriptor is formed by a pair of :class:`cue.Operation ` and :class:`cue.SegmentedTensorProduct `.
+        inputs (list of jax.Array): The input buffers.
+ outputs_shape_dtype (list of jax.ShapeDtypeStruct): The output shapes and dtypes. + indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs. + math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. + name (str, optional): The name of the operation. Defaults to None. + impl (str, optional): The implementation to use. Defaults to "auto". + If "auto", it will use the CUDA implementation if available, otherwise it will use the JAX implementation. + If "cuda", it will use the CUDA implementation. + If "jax", it will use the JAX implementation. + + Returns: + list of jax.Array: The result of the tensor product. + """ + + if name is None: + name = "tensor_product" + + return cuex.segmented_polynomial( + cue.SegmentedPolynomial(len(inputs), len(outputs_shape_dtype), descriptors), + inputs, + outputs_shape_dtype, + indices, + math_dtype=math_dtype, + name=name, + impl=impl, + ) From 2eae581a03b5699034e56359c90cbdfe723ed91f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 05:54:52 -0800 Subject: [PATCH 006/107] add EquivariantPolynomial --- cuequivariance/cuequivariance/__init__.py | 2 + .../cuequivariance/equivariant_polynomial.py | 57 +++++++++++++++++++ .../cuequivariance/segmented_polynomial.py | 10 ++++ 3 files changed, 69 insertions(+) create mode 100644 cuequivariance/cuequivariance/equivariant_polynomial.py diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index 724800de..14e64bf3 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -50,6 +50,7 @@ from cuequivariance.operation import Operation from cuequivariance.segmented_tensor_product import SegmentedTensorProduct from cuequivariance.segmented_polynomial import SegmentedPolynomial +from cuequivariance.equivariant_polynomial import EquivariantPolynomial from cuequivariance.equivariant_tensor_product import EquivariantTensorProduct from cuequivariance import segmented_tensor_product, descriptors @@ -81,6 +82,7 @@ "Operation", "SegmentedTensorProduct", "SegmentedPolynomial", + "EquivariantPolynomial", "EquivariantTensorProduct", "segmented_tensor_product", "descriptors", diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py new file mode 100644 index 00000000..a136207b --- /dev/null +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import dataclasses + +import cuequivariance as cue + + +@dataclasses.dataclass(init=False, frozen=True) +class EquivariantPolynomial: + operands: tuple[cue.Rep, ...] 
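+    # operands lists the input representations followed by the output
+    # representations; their dimensions must match the polynomial's buffer
+    # sizes (checked in __init__ below).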
+ polynomial: cue.SegmentedPolynomial + + def __init__(self, operands: list[cue.Rep], polynomial: cue.SegmentedPolynomial): + object.__setattr__(self, "operands", tuple(operands)) + object.__setattr__(self, "polynomial", polynomial) + assert len(self.operands) == len(self.polynomial) + for rep, size in zip(self.operands, self.polynomial.buffer_sizes): + assert size is None or size == rep.dim + + def __hash__(self) -> int: + return hash((self.operands, self.polynomial)) + + def __mul__(self, factor: float) -> EquivariantPolynomial: + return EquivariantPolynomial(self.operands, self.polynomial * factor) + + def __rmul__(self, factor: float) -> EquivariantPolynomial: + return self.__mul__(factor) + + @property + def num_operands(self) -> int: + return len(self.operands) + + @property + def num_inputs(self) -> int: + return self.polynomial.num_inputs + + @property + def inputs(self) -> tuple[cue.Rep, ...]: + return self.operands[: self.num_inputs] + + @property + def outputs(self) -> tuple[cue.Rep, ...]: + return self.operands[self.num_inputs :] diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index ecfc7e3d..fc947816 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -60,6 +60,16 @@ def __lt__(self, value): value.tensor_products, ) + def __mul__(self, factor: float) -> SegmentedPolynomial: + return SegmentedPolynomial( + self.num_inputs, + self.num_outputs, + [(ope, factor * stp) for ope, stp in self.tensor_products], + ) + + def __rmul__(self, factor: float) -> SegmentedPolynomial: + return self.__mul__(factor) + def __repr__(self): header = ( " ".join(IVARS[: self.num_inputs]) From 3fc09714ba6093c6cf7cbb880d7d4860615d0ff3 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 26 Feb 2025 06:43:52 -0800 Subject: [PATCH 007/107] add method stack --- .../descriptors/spherical_harmonics_.py | 12 ++-- .../descriptors/symmetric_contractions.py | 12 ++-- .../cuequivariance/equivariant_polynomial.py | 47 ++++++++++++++- .../cuequivariance/segmented_polynomial.py | 58 ++++++++++++++++++- 4 files changed, 114 insertions(+), 15 deletions(-) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index 9f19ff8a..3d22e525 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -23,7 +23,7 @@ def spherical_harmonics( ir_vec: cue.Irrep, ls: list[int], layout: cue.IrrepsLayout = cue.ir_mul -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``vector[],...,vector[],Yl[]`` @@ -33,14 +33,14 @@ def spherical_harmonics( layout (IrrepsLayout, optional): layout of the output. Defaults to ``cue.ir_mul``. Returns: - :class:`cue.EquivariantTensorProduct `: The descriptor. + :class:`cue.EquivariantPolynomial `: The descriptor. 
Examples: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) - EquivariantTensorProduct((1)^(0..2) -> 0+1+2) + EquivariantPolynomial((1)^(0..2) -> 0+1+2) """ if len(ls) != 1: - return cue.EquivariantTensorProduct.stack( + return cue.EquivariantPolynomial.stack( [spherical_harmonics(ir_vec, [ell], layout) for ell in ls], [False, True] ) @@ -54,12 +54,12 @@ def spherical_harmonics( indices = poly_degrees_to_path_indices(degrees) d.add_path(*indices, i, c=coeff) - return cue.EquivariantTensorProduct( - [d], + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul), cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul), ], + cue.SegmentedPolynomial(1, 1, [(cue.Operation([0] * ell + [1]), d)]), ) diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py index 09b13749..45ff4207 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py @@ -20,7 +20,7 @@ def symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: list[int], -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: r""" subscripts: ``weights[u],input[u],output[u]`` @@ -34,7 +34,7 @@ def symmetric_contraction( degree (int): The degree of the symmetric contraction. Returns: - :class:`cue.EquivariantTensorProduct `: + :class:`cue.EquivariantPolynomial `: The descriptor of the symmetric contraction. The operands are the weights, the input degree times and the output. @@ -44,13 +44,13 @@ def symmetric_contraction( ... 16 * cue.Irreps("SO3", "0 + 1"), ... [1, 2, 3] ... ) - EquivariantTensorProduct(32x0+80x0+176x0 x (16x0+16x1+16x2)^(1..3) -> 16x0+16x1) + EquivariantPolynomial(32x0+80x0+176x0 x (16x0+16x1+16x2)^(1..3) -> 16x0+16x1) Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). 
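 
     A sketch of how the flattened weight buffer could be split back per degree
     (``w`` is a hypothetical array whose last axis has size 32 + 80 + 176 = 288;
     the split points are the cumulative sizes):
 
     >>> # w1, w2, w3 = w[..., :32], w[..., 32:112], w[..., 112:]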
""" degrees = list(degrees) if len(degrees) != 1: - return cue.EquivariantTensorProduct.stack( + return cue.EquivariantPolynomial.stack( [ symmetric_contraction(irreps_in, irreps_out, [degree]) for degree in degrees @@ -102,11 +102,11 @@ def symmetric_contraction( d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) - return cue.EquivariantTensorProduct( - [d], + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial(2, 1, [(cue.Operation([0] + [1] * degree + [2]), d)]), ) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index a136207b..5b905e4e 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -25,12 +25,19 @@ class EquivariantPolynomial: polynomial: cue.SegmentedPolynomial def __init__(self, operands: list[cue.Rep], polynomial: cue.SegmentedPolynomial): + assert isinstance(polynomial, cue.SegmentedPolynomial) object.__setattr__(self, "operands", tuple(operands)) object.__setattr__(self, "polynomial", polynomial) - assert len(self.operands) == len(self.polynomial) + assert ( + len(self.operands) + == self.polynomial.num_inputs + self.polynomial.num_outputs + ) for rep, size in zip(self.operands, self.polynomial.buffer_sizes): assert size is None or size == rep.dim + def __repr__(self): + return self.polynomial.to_string([f"{rep}" for rep in self.operands]) + def __hash__(self) -> int: return hash((self.operands, self.polynomial)) @@ -48,6 +55,10 @@ def num_operands(self) -> int: def num_inputs(self) -> int: return self.polynomial.num_inputs + @property + def num_outputs(self) -> int: + return self.polynomial.num_outputs + @property def inputs(self) -> tuple[cue.Rep, ...]: return self.operands[: self.num_inputs] @@ -55,3 +66,37 @@ def inputs(self) -> tuple[cue.Rep, ...]: @property def outputs(self) -> tuple[cue.Rep, ...]: return self.operands[self.num_inputs :] + + @classmethod + def stack( + cls, polys: list[EquivariantPolynomial], stacked: list[bool] + ) -> EquivariantPolynomial: + assert len(polys) > 0 + num_operands = polys[0].num_operands + + assert all(pol.num_operands == num_operands for pol in polys) + assert len(stacked) == num_operands + + operands = [] + for oid in range(num_operands): + if stacked[oid]: + for pol in polys: + if not isinstance(pol.operands[oid], cue.IrrepsAndLayout): + raise ValueError( + f"Cannot stack operand {oid} of type {type(pol.operands[oid])}" + ) + operands.append(cue.concatenate([pol.operands[oid] for pol in polys])) + else: + ope = polys[0].operands[oid] + for pol in polys: + if pol.operands[oid] != ope: + raise ValueError( + f"Operand {oid} must be the same for all polynomials." + f" Found {ope} and {pol.operands[oid]}" + ) + operands.append(ope) + + return cls( + operands, + cue.SegmentedPolynomial.stack([pol.polynomial for pol in polys], stacked), + ) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index fc947816..7e73e192 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -14,6 +14,7 @@ # limitations under the License. 
from __future__ import annotations +import copy import dataclasses from typing import Callable, Sequence @@ -71,10 +72,24 @@ def __rmul__(self, factor: float) -> SegmentedPolynomial: return self.__mul__(factor) def __repr__(self): + return self.to_string() + + def to_string(self, buffer_names: list[str] | None = None) -> str: + buffer_txts = ( + IVARS[: self.num_inputs] + + OVARS[self.num_inputs : self.num_inputs + self.num_outputs] + ) + if buffer_names is not None: + buffer_txts = [ + f"{symbol}={name}" for symbol, name in zip(buffer_txts, buffer_names) + ] + header = ( - " ".join(IVARS[: self.num_inputs]) + " ".join(buffer_txts[: self.num_inputs]) + " -> " - + " ".join(OVARS[self.num_inputs : self.num_inputs + self.num_outputs]) + + " ".join( + buffer_txts[self.num_inputs : self.num_inputs + self.num_outputs] + ) + " " ) ope_txts = [ @@ -153,3 +168,42 @@ def transpose( sum(is_undefined_primal), new_tps, ) + + @classmethod + def stack( + cls, polys: list[SegmentedPolynomial], stacked: list[bool] + ) -> SegmentedPolynomial: + assert len(polys) > 0 + num_inputs = polys[0].num_inputs + num_outputs = polys[0].num_outputs + assert all(pol.num_inputs == num_inputs for pol in polys) + assert all(pol.num_outputs == num_outputs for pol in polys) + assert len(stacked) == num_inputs + num_outputs + + tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] + for index, pol in enumerate(polys): + for ope, stp in pol.tensor_products: + stp = copy.deepcopy(stp) + for oid, buffer in enumerate(ope.buffers): + if stacked[buffer]: + for p in reversed(polys[:index]): + stp.insert_segments(oid, 0, p.buffer_segments(buffer)) + for p in polys[index + 1 :]: + stp.insert_segments(oid, -1, p.buffer_segments(buffer)) + tensor_products.append((ope, stp)) + return cls(num_inputs, num_outputs, tensor_products) + + def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: + segments = None + for ope, stp in self.tensor_products: + if buffer in ope.buffers: + ope = stp.operands[ope.buffers.index(buffer)] + if segments is None: + segments = ope.segments + elif segments != ope.segments: + raise ValueError( + f"Buffer {buffer} has inconsistent segments: {segments} vs {ope.segments}" + ) + if segments is None: + raise ValueError(f"Buffer {buffer} is not used") + return segments From 90dcb0b20c9e30d63fd827796d9a0fe495b8bfa2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 01:09:28 -0800 Subject: [PATCH 008/107] draft equivariant_polynomial --- .../cuequivariance_jax/__init__.py | 2 + .../primitives/equivariant_polynomial.py | 169 ++++++++++++++++++ 2 files changed, 171 insertions(+) create mode 100644 cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 5126a150..f6a99482 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -25,6 +25,7 @@ from .primitives.segmented_polynomial import segmented_polynomial from .primitives.tensor_product import tensor_product +from .primitives.equivariant_polynomial import equivariant_polynomial from .primitives.equivariant_tensor_product import equivariant_tensor_product from .operations.activation import ( @@ -48,6 +49,7 @@ "clebsch_gordan", "segmented_polynomial", "tensor_product", + "equivariant_polynomial", "equivariant_tensor_product", "normalspace", "normalize_function", diff --git 
a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py
new file mode 100644
index 00000000..e6cbcb39
--- /dev/null
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py
@@ -0,0 +1,169 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+import jax.numpy as jnp
+
+import cuequivariance as cue
+import cuequivariance_jax as cuex
+
+
+def equivariant_polynomial(
+    poly: cue.EquivariantPolynomial,
+    inputs: list[cuex.RepArray | jax.Array],
+    outputs_shape_dtype: list[jax.ShapeDtypeStruct]
+    | jax.ShapeDtypeStruct
+    | None = None,
+    indices: list[jax.Array | None] | None = None,
+    math_dtype: jnp.dtype | None = None,
+    name: str | None = None,
+    impl: str = "auto",
+) -> tuple[cuex.RepArray, ...] | cuex.RepArray:
+    """Compute an equivariant polynomial.
+
+    Args:
+        poly (:class:`cue.EquivariantPolynomial `): The equivariant polynomial descriptor.
+        inputs (list of RepArray or jax.Array): The input arrays.
+        outputs_shape_dtype (list of jax.ShapeDtypeStruct, optional): The shapes and dtypes of the outputs. Inferred from the inputs when possible; required when output indices are given.
+        indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs.
+        math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None.
+        name (str, optional): The name of the operation. Defaults to None.
+        impl (str, optional): The implementation to use ("auto", "cuda" or "jax"). Defaults to "auto".
+
+    Returns:
+        tuple of RepArray or RepArray: The output array(s).
+
+    Examples:
+
+        Let's create a descriptor for the spherical harmonics of degree 0, 1, and 2.
+
+        >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2])
+        >>> e
+        EquivariantPolynomial((1)^(0..2) -> 0+1+2)
+
+        We need some input data.
+
+        >>> with cue.assume(cue.SO3, cue.ir_mul):
+        ...     x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0]))
+        >>> x
+        {0: 1} [0. 1. 0.]
+
+        Now we can execute the equivariant polynomial.
+
+        >>> cuex.equivariant_polynomial(e, [x])
+        {0: 0+1+2}
+        [1. ... ]
+
+        The `indices` argument allows specifying a list of optional int32 arrays for each input and for the output (`None` means no index and `indices[-1]` is the output index). The indices are used to select the elements of the input arrays and to specify the output index.
+        In the following example, we will index the output. The input has a batch shape of (3,) and the output has a batch shape of (2,).
+
+        >>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32)
+
+        The `i_out` array is used to map the result to the output indices.
+
+        >>> with cue.assume(cue.SO3, cue.ir_mul):
+        ...     x = cuex.RepArray("1", jnp.array([
+        ...         [0.0, 1.0, 0.0],
+        ...         [0.0, 0.0, 1.0],
+        ...         [1.0, 0.0, 0.0],
+        ...     ]))
+        >>> cuex.equivariant_polynomial(
+        ...     e,
+        ...     [x],
+        ...     indices=[None, i_out],
+        ...     outputs_shape_dtype=jax.ShapeDtypeStruct((2, 9), jnp.float32),
+        ... )
+        {1: 0+1+2}
+        [[ 1. ... 
] + [ 2. ... ]] + """ + if name is None: + name = "equivariant_polynomial" + + if len(inputs) != poly.num_inputs: + raise ValueError( + f"Unexpected number of inputs. Expected {poly.num_inputs}, got {len(inputs)}." + ) + + for i, (x, rep) in enumerate(zip(inputs, poly.inputs)): + if isinstance(x, cuex.RepArray): + assert x.rep(-1) == rep, ( + f"Input {i} should have representation {rep}, got {x.rep(-1)}." + ) + else: + assert x.ndim >= 1, ( + f"Input {i} should have at least one dimension, got {x.ndim}." + ) + assert x.shape[-1] == rep.dim, ( + f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." + ) + if not rep.is_scalar(): + raise ValueError( + f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}." + ) + + inputs: list[jax.Array] = [getattr(x, "array", x) for x in inputs] + + if indices is None: + indices = [None] * poly.num_operands + + if len(indices) != poly.num_operands: + raise ValueError( + f"Unexpected number of indices. indices should None or a list of length {poly.num_operands}, got a list of length {len(indices)}." + ) + + if outputs_shape_dtype is None: + if not all(i is None for i in indices[poly.num_inputs :]): + raise ValueError( + "When output indices are provided, outputs_shape_dtype must be provided." + ) + if poly.num_inputs == 0: + raise ValueError( + "When no inputs are provided, outputs_shape_dtype must be provided." + ) + inferred_shape = jnp.broadcast_shapes( + *[ + x.shape[:-1] if i is None else i.shape + x.shape[1:-1] + for i, x in zip(indices, inputs) + ] + ) + inferred_dtype = jnp.result_type(*inputs) + outputs_shape_dtype = [ + jax.ShapeDtypeStruct(inferred_shape + (rep.dim,), inferred_dtype) + for rep in poly.outputs + ] + + if hasattr(outputs_shape_dtype, "shape"): + outputs_shape_dtype = [outputs_shape_dtype] + + if len(outputs_shape_dtype) != poly.num_outputs: + raise ValueError( + f"Unexpected number of outputs. Expected {poly.num_outputs}, got {len(outputs_shape_dtype)}." 
+ ) + + outputs = cuex.segmented_polynomial( + poly.polynomial, + inputs, + outputs_shape_dtype, + indices, + math_dtype=math_dtype, + name=name, + impl=impl, + ) + outputs = [cuex.RepArray(rep, x) for rep, x in zip(poly.outputs, outputs)] + + if poly.num_outputs == 1: + return outputs[0] + return tuple(outputs) From c960306dff36fdc8b79fbc3850a1d1082b1bc734 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 01:32:12 -0800 Subject: [PATCH 009/107] cool __repr__ --- .../cuequivariance/segmented_polynomial.py | 35 ++++++++++++++----- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 7e73e192..43da8731 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -92,17 +92,36 @@ def to_string(self, buffer_names: list[str] | None = None) -> str: ) + " " ) - ope_txts = [ - " " + ope.to_string(self.num_inputs) for ope, _ in self.tensor_products + lines = [ + "│ " + ope.to_string(self.num_inputs) for ope, _ in self.tensor_products ] - n = max(len(ope_txt) for ope_txt in ope_txts) - n = max(len(header), n) + if len(lines) > 0: + lines[-1] = "╰─" + lines[-1][2:] + n = max(len(line) for line in lines) - text = header + "═" * (n - len(header)) + "═╗" + lines = [ + line + " " + "─" * (n - len(line)) + "─ " + str(stp) + for line, (_, stp) in zip(lines, self.tensor_products) + ] + + modes = sorted( + {mode for _, stp in self.tensor_products for mode in stp.subscripts.modes()} + ) + if len(modes) > 1: + modes = [] + for a in ["sizes=", "num_segments=", "num_paths="] + [f"{m}=" for m in modes]: + if not all(line.count(a) == 1 for line in lines): + continue + + splits = [line.split(a) for line in lines] + n = max(len(before) for before, _ in splits) + lines = [ + before + " " * (n - len(before)) + a + after for before, after in splits + ] + + lines = ["╭ " + header] + lines - for ope_txt, (_, stp) in zip(ope_txts, self.tensor_products): - text += "\n" + ope_txt + " " * (n - len(ope_txt)) + " ║ " + str(stp) - return text + return "\n".join(lines) @property def buffer_sizes(self) -> list[int | None]: From aea4cdeb331750744a80926aec49c75f6bf8f443 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 07:59:21 -0800 Subject: [PATCH 010/107] jvp, transpose, backward, consolidate, remove unused buffers --- .../cuequivariance/descriptors/irreps_tp.py | 45 +++-- .../cuequivariance/descriptors/rotations.py | 57 ++++--- .../descriptors/spherical_harmonics_.py | 5 +- .../descriptors/symmetric_contractions.py | 5 +- .../descriptors/transposition.py | 7 +- .../cuequivariance/equivariant_polynomial.py | 110 ++++++++++++- .../equivariant_tensor_product.py | 29 ---- .../cuequivariance/experimental/escn.py | 18 +- .../mace/symmetric_contractions.py | 28 ++-- .../cuequivariance/segmented_polynomial.py | 155 +++++++++++++++--- ...test.py => equivariant_polynomial_test.py} | 40 ++--- .../segmented_tensor_product/dot_test.py | 4 +- .../primitives/equivariant_tensor_product.py | 17 ++ 13 files changed, 360 insertions(+), 160 deletions(-) rename cuequivariance/tests/{equivariant_tensor_products_test.py => equivariant_polynomial_test.py} (64%) diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/descriptors/irreps_tp.py index 59b07cd4..6b6c7806 100644 --- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py +++ 
b/cuequivariance/cuequivariance/descriptors/irreps_tp.py @@ -22,7 +22,7 @@ def fully_connected_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3: cue.Irreps -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uvw],lhs[iu],rhs[jv],output[kw]`` @@ -39,7 +39,7 @@ def fully_connected_tensor_product( irreps3 (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the fully connected tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the fully connected tensor product. Examples: >>> cue.descriptors.fully_connected_tensor_product( @@ -47,7 +47,8 @@ def fully_connected_tensor_product( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... ) - EquivariantTensorProduct(61440x0 x 16x0+16x1+16x2 x 16x0+16x1+16x2 -> 16x0+16x1+16x2) + ╭ a=61440x0 b=16x0+16x1+16x2 c=16x0+16x1+16x2 -> D=16x0+16x1+16x2 + ╰─ a b c D ─ uvw,iu,jv,kw+ijk sizes=61440,144,144,144 num_segments=15,3,3,3 num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. """ @@ -73,14 +74,14 @@ def fully_connected_tensor_product( d.add_path((mul1, mul2, mul3), i1, i2, i3, c=cg) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) @@ -88,7 +89,7 @@ def full_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``lhs[iu],rhs[jv],output[kuv]`` @@ -102,7 +103,7 @@ def full_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the full tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the full tensor product. """ G = irreps1.irrep_class @@ -136,13 +137,13 @@ def full_tensor_product( d = d.permute_segments(2, inv) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) @@ -150,7 +151,7 @@ def channelwise_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uv],lhs[iu],rhs[jv],output[kuv]`` @@ -166,7 +167,7 @@ def channelwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the channelwise tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the channelwise tensor product. 
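+
+    Examples:
+        A minimal sketch (the irreps below are illustrative choices):
+
+        >>> e = cue.descriptors.channelwise_tensor_product(
+        ...     16 * cue.Irreps("SO3", "0 + 1"),
+        ...     cue.Irreps("SO3", "0 + 1"),
+        ... )
+        >>> e.num_inputs, e.num_outputs
+        (3, 1)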
""" G = irreps1.irrep_class @@ -201,14 +202,14 @@ def channelwise_tensor_product( d = d.permute_segments(3, inv) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) @@ -245,7 +246,7 @@ def elementwise_tensor_product( irreps1: cue.Irreps, irreps2: cue.Irreps, irreps3_filter: Optional[Sequence[cue.Irrep]] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subscripts: ``lhs[iu],rhs[ju],output[ku]`` @@ -257,7 +258,7 @@ def elementwise_tensor_product( irreps3_filter (sequence of Irrep, optional): Irreps of the output to consider. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the elementwise tensor product. + :class:`cue.EquivariantPolynomial `: Descriptor of the elementwise tensor product. """ G = irreps1.irrep_class @@ -286,19 +287,17 @@ def elementwise_tensor_product( irreps3 = cue.Irreps(G, irreps3) d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps1, cue.ir_mul), cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) -def linear( - irreps_in: cue.Irreps, irreps_out: cue.Irreps -) -> cue.EquivariantTensorProduct: +def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPolynomial: """ subscripts: ``weights[uv],input[iu],output[iv]`` @@ -309,7 +308,7 @@ def linear( irreps_out (Irreps): Irreps of the output. Returns: - :class:`cue.EquivariantTensorProduct `: Descriptor of the linear transformation. + :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. """ d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: @@ -325,11 +324,11 @@ def linear( d = d.normalize_paths_for_operand(-1) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/rotations.py b/cuequivariance/cuequivariance/descriptors/rotations.py index 2257c1b5..f0801384 100644 --- a/cuequivariance/cuequivariance/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/descriptors/rotations.py @@ -22,7 +22,7 @@ def fixed_axis_angle_rotation( irreps: cue.Irreps, axis: np.ndarray, angle: float -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``input[u],output[u]`` @@ -42,18 +42,18 @@ def fixed_axis_angle_rotation( ) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) def yxy_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``gamma[],beta[],alpha[],input[u],output[u]`` @@ -69,12 +69,12 @@ def yxy_rotation( where l is the maximum L in the input and output irreps. 
""" - cbio = xy_rotation(irreps, lmax).d # gamma, beta, input, A - aio = y_rotation(irreps, lmax).d # alpha, A, output + # gamma, beta, input, A + cbio = xy_rotation(irreps, lmax).polynomial.tensor_products[0][1] + aio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # alpha, A, output cbiao = stp.dot(cbio, aio, (3, 1)) # gamma, beta, input, alpha, output cbaio = cbiao.move_operand(2, 3) # gamma, beta, alpha, input, output - return cue.EquivariantTensorProduct( - cbaio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbaio.operands[1].size), cue.ir_mul), @@ -82,35 +82,36 @@ def yxy_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(cbaio), ) def xy_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``gamma[],beta[],input[u],output[u]`` Rotation around the y-axis followed by rotation around the x-axis """ - cio = y_rotation(irreps, lmax).d # gamma, input, A - bio = x_rotation(irreps, lmax).d # beta, A, output + cio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # gamma, input, A + bio = x_rotation(irreps, lmax).polynomial.tensor_products[0][1] # beta, A, output cibo = stp.dot(cio, bio, (2, 1)) # gamma, input, beta, output cbio = cibo.move_operand(1, 2) # gamma, beta, input, output - return cue.EquivariantTensorProduct( - cbio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(cbio), ) def yx_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``phi[],theta[],input[u],output[u]`` @@ -120,20 +121,20 @@ def yx_rotation( bio = y_rotation(irreps, lmax).d cibo = stp.dot(cio, bio, (2, 1)) cbio = cibo.move_operand(1, 2) - return cue.EquivariantTensorProduct( - cbio, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps.new_scalars(cbio.operands[1].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(cbio), ) def y_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``phi[],input[u],output[u]`` @@ -190,19 +191,19 @@ def y_rotation( d.add_path(phc, slo, slo, c=c) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) def x_rotation( irreps: cue.Irreps, lmax: Optional[int] = None -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``phi[],input[u],output[u]`` @@ -216,21 +217,23 @@ def x_rotation( """ assert irreps.irrep_class in [cue.SO3, cue.O3] - dy = y_rotation(irreps, lmax).d - dz90 = fixed_axis_angle_rotation(irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0).d + dy = y_rotation(irreps, 
lmax).polynomial.tensor_products[0][1] + dz90 = fixed_axis_angle_rotation( + irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0 + ).polynomial.tensor_products[0][1] d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1)) - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) -def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct: +def inversion(irreps: cue.Irreps) -> cue.EquivariantPolynomial: """ subsrcipts: ``input[u],output[u]`` """ @@ -241,10 +244,10 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantTensorProduct: assert np.allclose(H @ H, np.eye(ir.dim), atol=1e-6) d.add_path(None, None, c=H, dims={"u": mul}) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index 3d22e525..3093bcc7 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -37,7 +37,10 @@ def spherical_harmonics( Examples: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) - EquivariantPolynomial((1)^(0..2) -> 0+1+2) + ╭ a=1 -> B=0+1+2 + │ B ───── sizes=9 num_segments=9 num_paths=1 + │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 + ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=8 """ if len(ls) != 1: return cue.EquivariantPolynomial.stack( diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py index 45ff4207..bab9a267 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py @@ -44,7 +44,10 @@ def symmetric_contraction( ... 16 * cue.Irreps("SO3", "0 + 1"), ... [1, 2, 3] ... ) - EquivariantPolynomial(32x0+80x0+176x0 x (16x0+16x1+16x2)^(1..3) -> 16x0+16x1) + ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 + │ a b C ───── u,u,u sizes=288,144,64 num_segments=18,9,4 num_paths=4 u=16 + │ a b b C ─── u,u,u,u sizes=288,144,144,64 num_segments=18,9,9,4 num_paths=37 u=16 + ╰─ a b b b C ─ u,u,u,u,u sizes=288,144,144,144,64 num_segments=18,9,9,9,4 num_paths=437 u=16 Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). 
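+
+    As a consistency check, these weight counts sum to 32 + 80 + 176 = 288,
+    the size reported for operand ``a`` in every term above (``sizes=288,...``).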
""" diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/descriptors/transposition.py index ae671164..0cb4925c 100644 --- a/cuequivariance/cuequivariance/descriptors/transposition.py +++ b/cuequivariance/cuequivariance/descriptors/transposition.py @@ -17,7 +17,7 @@ def transpose( irreps: cue.Irreps, source: cue.IrrepsLayout, target: cue.IrrepsLayout -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """Transpose the irreps layout of a tensor.""" d = cue.SegmentedTensorProduct( operands=[ @@ -31,6 +31,7 @@ def transpose( ) for mul, ir in irreps: d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim}) - return cue.EquivariantTensorProduct( - d, [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)] + return cue.EquivariantPolynomial( + [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)], + cue.SegmentedPolynomial.trivial(d), ) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index 5b905e4e..4c3d14ea 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -16,6 +16,8 @@ import dataclasses +import numpy as np + import cuequivariance as cue @@ -28,25 +30,50 @@ def __init__(self, operands: list[cue.Rep], polynomial: cue.SegmentedPolynomial) assert isinstance(polynomial, cue.SegmentedPolynomial) object.__setattr__(self, "operands", tuple(operands)) object.__setattr__(self, "polynomial", polynomial) - assert ( + if ( len(self.operands) - == self.polynomial.num_inputs + self.polynomial.num_outputs - ) + != self.polynomial.num_inputs + self.polynomial.num_outputs + ): + raise ValueError( + f"Number of operands {len(self.operands)} must equal the number of inputs" + f" {self.polynomial.num_inputs} plus the number of outputs {self.polynomial.num_outputs}" + ) for rep, size in zip(self.operands, self.polynomial.buffer_sizes): assert size is None or size == rep.dim - def __repr__(self): - return self.polynomial.to_string([f"{rep}" for rep in self.operands]) - def __hash__(self) -> int: return hash((self.operands, self.polynomial)) + def __eq__(self, value) -> bool: + assert isinstance(value, EquivariantPolynomial) + return self.operands == value.operands and self.polynomial == value.polynomial + + def __lt__(self, value) -> bool: + assert isinstance(value, EquivariantPolynomial) + return ( + self.num_inputs, + self.num_outputs, + self.operands, + self.polynomial, + ) < ( + value.num_inputs, + value.num_outputs, + value.operands, + value.polynomial, + ) + def __mul__(self, factor: float) -> EquivariantPolynomial: return EquivariantPolynomial(self.operands, self.polynomial * factor) def __rmul__(self, factor: float) -> EquivariantPolynomial: return self.__mul__(factor) + def __repr__(self): + return self.polynomial.to_string([f"{rep}" for rep in self.operands]) + + def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: + return self.polynomial(*inputs) + @property def num_operands(self) -> int: return len(self.operands) @@ -67,6 +94,21 @@ def inputs(self) -> tuple[cue.Rep, ...]: def outputs(self) -> tuple[cue.Rep, ...]: return self.operands[self.num_inputs :] + def consolidate(self) -> EquivariantPolynomial: + return EquivariantPolynomial( + self.operands, + self.polynomial.consolidate(), + ) + + def buffer_used(self) -> list[bool]: + return self.polynomial.buffer_used() + + def remove_unused_buffers(self) -> EquivariantPolynomial: + 
return EquivariantPolynomial( + [rep for u, rep in zip(self.buffer_used(), self.operands) if u], + self.polynomial.remove_unused_buffers(), + ) + @classmethod def stack( cls, polys: list[EquivariantPolynomial], stacked: list[bool] @@ -96,7 +138,61 @@ def stack( ) operands.append(ope) - return cls( + poly = cls( operands, cue.SegmentedPolynomial.stack([pol.polynomial for pol in polys], stacked), ) + return poly.consolidate() + + def squeeze_modes(self) -> EquivariantPolynomial: + return EquivariantPolynomial( + self.operands, + self.polynomial.squeeze_modes(), + ) + + def flatten_coefficient_modes(self) -> EquivariantPolynomial: + return EquivariantPolynomial( + self.operands, + self.polynomial.flatten_coefficient_modes(), + ) + + def jvp(self, has_tangent: list[bool]) -> EquivariantPolynomial: + return EquivariantPolynomial( + list(self.inputs) + + [x for has, x in zip(has_tangent, self.inputs) if has] + + list(self.outputs), + self.polynomial.jvp(has_tangent), + ) + + def transpose( + self, + is_undefined_primal: list[bool], + has_cotangent: list[bool], + ) -> EquivariantPolynomial: + return EquivariantPolynomial( + # defined inputs + [ + x + for is_undefined, x in zip(is_undefined_primal, self.inputs) + if not is_undefined + ] + # cotangent outputs + + [x for has, x in zip(has_cotangent, self.outputs) if has] + # undefined inputs + + [ + x + for is_undefined, x in zip(is_undefined_primal, self.inputs) + if is_undefined + ], + self.polynomial.transpose(is_undefined_primal, has_cotangent), + ) + + def backward( + self, requires_gradient: list[bool], has_cotangent: list[bool] + ) -> EquivariantPolynomial: + return EquivariantPolynomial( + list(self.inputs) + + [x for has, x in zip(has_cotangent, self.outputs) if has] + + [x for req, x in zip(requires_gradient, self.inputs) if req], + self.polynomial.backward(requires_gradient, has_cotangent), + ) diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py index fee1e136..ca14896c 100644 --- a/cuequivariance/cuequivariance/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py @@ -326,35 +326,6 @@ def memory_cost( def backward(self, input: int) -> tuple[EquivariantTensorProduct, tuple[int, ...]]: """ The backward pass of the equivariant tensor product. - - Args: - input: The input with respect to which the backward pass is computed. - - Returns: - A tuple containing the ETP representing the backward pass and the permutation of the operands between the original and the backward ETP. - - Examples: - >>> e = cue.descriptors.fully_connected_tensor_product( - ... cue.Irreps("SO3", "4x0+1x1"), - ... cue.Irreps("SO3", "4x0+2x1"), - ... cue.Irreps("SO3", "4x0+3x1") - ... 
) - >>> e - EquivariantTensorProduct(114x0 x 4x0+1 x 4x0+2x1 -> 4x0+3x1) - >>> e_bwd, i = e.backward(0) - >>> e_bwd - EquivariantTensorProduct(4x0+3x1 x 4x0+1 x 4x0+2x1 -> 114x0) - >>> i - (3, 1, 2, 0) - - >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2, 3]) - >>> e - EquivariantTensorProduct((1)^(0..3) -> 0+1+2+3) - >>> e_bwd, i = e.backward(0) - >>> e_bwd - EquivariantTensorProduct(0+1+2+3 x (1)^(0..2) -> 1) - >>> i - (1, 0, 0) """ assert input < self.num_inputs diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/experimental/escn.py index c08b0f61..54ec4e24 100644 --- a/cuequivariance/cuequivariance/experimental/escn.py +++ b/cuequivariance/cuequivariance/experimental/escn.py @@ -17,7 +17,6 @@ import numpy as np import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp # The function escn_iu_ju_ku below is a 1:1 adaptation of https://github.com/e3nn/e3nn-jax/blob/a2a81ab451b9cd597d7be27b3e1faba79457475d/e3nn_jax/experimental/linear_shtp.py#L38-L165 @@ -26,7 +25,7 @@ def escn_tp( irreps_out: cue.Irreps, m_max: Optional[int] = None, l_max: Optional[int] = None, -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: """ subsrcipts: ``weights[uv],input[u],output[v]`` @@ -39,7 +38,7 @@ def escn_tp( l_max (int, optional): Maximum angular resolution along the principal axis. Returns: - EquivariantTensorProduct: + EquivariantPolynomial: Descriptor of the tensor product part of the eSCN convolution. - Operand 0: weights @@ -61,7 +60,7 @@ def pr(mul_ir: cue.MulIrrep) -> bool: irreps_out = irreps_out.filter(keep=pr) - d = stp.SegmentedTensorProduct.from_subscripts("iuv,ju,kv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv,ju,kv+ijk") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) @@ -105,13 +104,13 @@ def pr(mul_ir: cue.MulIrrep) -> bool: d = d.normalize_paths_for_operand(2) d = d.flatten_coefficient_modes() - return cue.EquivariantTensorProduct( - d, + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial.trivial(d), ) @@ -119,7 +118,7 @@ def escn_tp_compact( irreps_in: cue.Irreps, irreps_out: cue.Irreps, m_max: Optional[int] = None, -) -> stp.SegmentedTensorProduct: +) -> cue.SegmentedPolynomial: """ subsrcipts: ``weights[uv],input[u],output[v]`` @@ -146,7 +145,7 @@ def escn_tp_compact( if G not in [cue.SO3]: raise NotImplementedError("Only SO3 is supported") - d = stp.SegmentedTensorProduct.from_subscripts("uv,u,v") + d = cue.SegmentedTensorProduct.from_subscripts("uv,u,v") l_max_in = max(ir.l for _, ir in irreps_in) for m in range(-l_max_in, l_max_in + 1): @@ -178,7 +177,8 @@ def escn_tp_compact( d.add_path(i, l_max_in - m, l_max_out + m, c=-1.0) d = d.normalize_paths_for_operand(2) - return d # TODO: return an EquivariantTensorProduct using SphericalSignal + # TODO: return an EquivariantPolynomial using SphericalSignal + return cue.SegmentedPolynomial.trivial(d) class SphericalSignal(cue.Rep): diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index 78aee0e0..f648e31f 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -18,14 +18,12 @@ import numpy as np import 
cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors from cuequivariance.misc.linalg import round_to_sqrt_rational, triu_array def symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degrees: list[int] -) -> tuple[cue.EquivariantTensorProduct, np.ndarray]: +) -> tuple[cue.EquivariantPolynomial, np.ndarray]: r""" subscripts: ``weights[u],input[u],output[u]`` @@ -45,9 +43,11 @@ def symmetric_contraction( cuex.equivariant_tensor_product(e, w, cuex.randn(jax.random.key(1), e.inputs[1])) """ assert min(degrees) > 0 - e1 = cue.EquivariantTensorProduct.stack( + + # poly1 replicates the behavior of the original MACE implementation + poly1 = cue.EquivariantPolynomial.stack( [ - cue.EquivariantTensorProduct.stack( + cue.EquivariantPolynomial.stack( [ _symmetric_contraction(irreps_in, irreps_out[i : i + 1], deg) for deg in reversed(degrees) @@ -58,7 +58,7 @@ def symmetric_contraction( ], [True, False, True], ) - e2 = descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) + poly2 = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, degrees) a1, a2 = [ np.concatenate( [ @@ -67,11 +67,11 @@ def symmetric_contraction( 1, None, ) - for d in sorted(e.ds, key=lambda d: d.num_operands) + for _, d in pol.polynomial.tensor_products ], axis=1, ) - for e in [e1, e2] + for pol in [poly1, poly2] ] # This nonzeros selection is just for lightening the inversion @@ -83,7 +83,7 @@ def symmetric_contraction( projection = round_to_sqrt_rational(projection) np.testing.assert_allclose(a1, projection @ a2, atol=1e-7) - return e2, projection + return poly2, projection def _flatten( @@ -103,7 +103,7 @@ def _flatten( def _stp_to_matrix( - d: stp.SegmentedTensorProduct, + d: cue.SegmentedTensorProduct, ) -> np.ndarray: m = np.zeros([ope.num_segments for ope in d.operands]) for path in d.paths: @@ -114,7 +114,7 @@ def _stp_to_matrix( # This function is an adaptation of https://github.com/ACEsuit/mace/blob/bd412319b11c5f56c37cec6c4cfae74b2a49ff43/mace/modules/symmetric_contraction.py def _symmetric_contraction( irreps_in: cue.Irreps, irreps_out: cue.Irreps, degree: int -) -> cue.EquivariantTensorProduct: +) -> cue.EquivariantPolynomial: mul = irreps_in.muls[0] assert all(mul == m for m in irreps_in.muls) assert all(mul == m for m in irreps_out.muls) @@ -125,7 +125,7 @@ def _symmetric_contraction( output_operand = degree + 1 abc = "abcdefgh"[:degree] - d = stp.SegmentedTensorProduct.from_subscripts( + d = cue.SegmentedTensorProduct.from_subscripts( f"u_{'_'.join(f'{a}' for a in abc)}_i+{abc}ui" ) @@ -145,13 +145,13 @@ def _symmetric_contraction( d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) - return cue.EquivariantTensorProduct( - [d], + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], + cue.SegmentedPolynomial(2, 1, [(cue.Operation([0] + [1] * degree + [2]), d)]), ) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 43da8731..e2568400 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -16,8 +16,11 @@ import copy import dataclasses +import itertools from typing import Callable, Sequence +import numpy as np + import cuequivariance as cue from cuequivariance.operation 
import IVARS, OVARS @@ -38,10 +41,18 @@ def __init__( object.__setattr__(self, "num_outputs", num_outputs) object.__setattr__(self, "tensor_products", sorted(tensor_products)) + @classmethod + def trivial(cls, stp: cue.SegmentedTensorProduct): + return cls( + stp.num_operands - 1, + 1, + [(cue.Operation(tuple(range(stp.num_operands))), stp)], + ) + def __hash__(self) -> int: return hash((self.num_inputs, self.num_outputs, tuple(self.tensor_products))) - def __eq__(self, value): + def __eq__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) return ( self.num_inputs == value.num_inputs @@ -49,7 +60,7 @@ def __eq__(self, value): and self.tensor_products == value.tensor_products ) - def __lt__(self, value): + def __lt__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) return ( self.num_inputs, @@ -90,7 +101,6 @@ def to_string(self, buffer_names: list[str] | None = None) -> str: + " ".join( buffer_txts[self.num_inputs : self.num_inputs + self.num_outputs] ) - + " " ) lines = [ "│ " + ope.to_string(self.num_inputs) for ope, _ in self.tensor_products @@ -123,6 +133,24 @@ def to_string(self, buffer_names: list[str] | None = None) -> str: return "\n".join(lines) + def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: + inferred_shape = np.broadcast_shapes(*[x.shape[:-1] for x in inputs]) + inferred_dtype = np.result_type(*[x.dtype for x in inputs]) + outputs = [ + np.zeros(inferred_shape + (size,), dtype=inferred_dtype) + for size in self.output_sizes + ] + for ope, stp in self.tensor_products: + oid, bid = ope.output_operand_buffer(self.num_inputs) + outputs[bid - self.num_inputs] += ( + cue.segmented_tensor_product.compute_last_operand( + stp.move_operand_last(oid), + *[inputs[bid] for bid in ope.input_buffers(self.num_inputs)], + dtype=inferred_dtype, + ) + ) + return outputs + @property def buffer_sizes(self) -> list[int | None]: sizes = [None] * (self.num_inputs + self.num_outputs) @@ -159,7 +187,96 @@ def map_tensor_products( self.num_inputs, self.num_outputs, new_tensor_products ) + def consolidate(self) -> SegmentedPolynomial: + groups = itertools.groupby( + self.tensor_products, + key=lambda x: (x[0], x[1].operands, x[1].coefficient_subscripts), + ) + new_tensor_products = [ + ( + ope, + cue.SegmentedTensorProduct( + operands=operands, + coefficient_subscripts=coefficient_subscripts, + paths=[path for _, stp in elements for path in stp.paths], + ).consolidate_paths(), + ) + for (ope, operands, coefficient_subscripts), elements in groups + ] + return SegmentedPolynomial( + self.num_inputs, self.num_outputs, new_tensor_products + ) + + def buffer_used(self) -> list[bool]: + return [ + any(buffer in ope.buffers for ope, _ in self.tensor_products) + for buffer in range(self.num_inputs + self.num_outputs) + ] + + def remove_unused_buffers(self) -> SegmentedPolynomial: + used = self.buffer_used() + new_index = [] + i = 0 + for u in used: + if u: + new_index.append(i) + i += 1 + else: + new_index.append(None) + + return SegmentedPolynomial( + sum(used[: self.num_inputs]), + sum(used[self.num_inputs :]), + [ + (cue.Operation([new_index[buffer] for buffer in ope.buffers]), stp) + for ope, stp in self.tensor_products + ], + ) + + @classmethod + def stack( + cls, polys: list[SegmentedPolynomial], stacked: list[bool] + ) -> SegmentedPolynomial: + assert len(polys) > 0 + num_inputs = polys[0].num_inputs + num_outputs = polys[0].num_outputs + assert all(pol.num_inputs == num_inputs for pol in polys) + assert all(pol.num_outputs == num_outputs for pol in polys) 
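+        # Every polynomial must expose the same buffer layout; the buffers
+        # flagged in `stacked` are concatenated segment by segment below.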
+ assert len(stacked) == num_inputs + num_outputs + + tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] + for index, pol in enumerate(polys): + for ope, stp in pol.tensor_products: + stp = copy.deepcopy(stp) + for oid, buffer in enumerate(ope.buffers): + if stacked[buffer]: + for p in reversed(polys[:index]): + stp.insert_segments(oid, 0, p.buffer_segments(buffer)) + for p in polys[index + 1 :]: + stp.insert_segments(oid, -1, p.buffer_segments(buffer)) + tensor_products.append((ope, stp)) + return cls(num_inputs, num_outputs, tensor_products) + + def squeeze_modes(self) -> SegmentedPolynomial: + return SegmentedPolynomial( + self.num_inputs, + self.num_outputs, + [(ope, stp.squeeze_modes()) for ope, stp in self.tensor_products], + ) + + def flatten_coefficient_modes(self) -> SegmentedPolynomial: + return SegmentedPolynomial( + self.num_inputs, + self.num_outputs, + [ + (ope, stp.flatten_coefficient_modes()) + for ope, stp in self.tensor_products + ], + ) + def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: + assert len(has_tangent) == self.num_inputs + new_tps = [] for ope, stp in self.tensor_products: jvps = ope.jvp(has_tangent) @@ -177,6 +294,9 @@ def transpose( is_undefined_primal: list[bool], has_cotangent: list[bool], ) -> SegmentedPolynomial: + assert len(is_undefined_primal) == self.num_inputs + assert len(has_cotangent) == self.num_outputs + new_tps = [] for ope, stp in self.tensor_products: ope = ope.transpose(is_undefined_primal, has_cotangent) @@ -188,29 +308,14 @@ def transpose( new_tps, ) - @classmethod - def stack( - cls, polys: list[SegmentedPolynomial], stacked: list[bool] + def backward( + self, requires_gradient: list[bool], has_cotangent: list[bool] ) -> SegmentedPolynomial: - assert len(polys) > 0 - num_inputs = polys[0].num_inputs - num_outputs = polys[0].num_outputs - assert all(pol.num_inputs == num_inputs for pol in polys) - assert all(pol.num_outputs == num_outputs for pol in polys) - assert len(stacked) == num_inputs + num_outputs - - tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] - for index, pol in enumerate(polys): - for ope, stp in pol.tensor_products: - stp = copy.deepcopy(stp) - for oid, buffer in enumerate(ope.buffers): - if stacked[buffer]: - for p in reversed(polys[:index]): - stp.insert_segments(oid, 0, p.buffer_segments(buffer)) - for p in polys[index + 1 :]: - stp.insert_segments(oid, -1, p.buffer_segments(buffer)) - tensor_products.append((ope, stp)) - return cls(num_inputs, num_outputs, tensor_products) + return self.jvp(requires_gradient).transpose( + is_undefined_primal=[False] * self.num_inputs + + [True] * sum(requires_gradient), + has_cotangent=has_cotangent, + ) def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: segments = None diff --git a/cuequivariance/tests/equivariant_tensor_products_test.py b/cuequivariance/tests/equivariant_polynomial_test.py similarity index 64% rename from cuequivariance/tests/equivariant_tensor_products_test.py rename to cuequivariance/tests/equivariant_polynomial_test.py index 14cd90b0..19ef6382 100644 --- a/cuequivariance/tests/equivariant_tensor_products_test.py +++ b/cuequivariance/tests/equivariant_polynomial_test.py @@ -16,8 +16,10 @@ import pytest import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors + + +def test_transpose(): + cue.descriptors.transpose(cue.Irreps(cue.O3, "3x1e"), cue.mul_ir, cue.ir_mul) def test_commutativity_squeeze_flatten(): @@ -25,45 
+27,45 @@ def test_commutativity_squeeze_flatten(): irreps2 = cue.Irreps("O3", "1x0e + 1x1o") irreps3 = cue.Irreps("O3", "32x0e + 32x1o") - d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d + poly = cue.descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3) assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() ) - d = descriptors.full_tensor_product(irreps1, irreps2, irreps3).d + poly = cue.descriptors.full_tensor_product(irreps1, irreps2, irreps3) assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() ) - d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d + poly = cue.descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3) assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() ) - d = descriptors.linear(irreps1, irreps2).d + poly = cue.descriptors.linear(irreps1, irreps2) assert ( - d.squeeze_modes().flatten_coefficient_modes() - == d.flatten_coefficient_modes().squeeze_modes() + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() ) @pytest.mark.parametrize("ell", [1, 2, 3, 4]) def test_spherical_harmonics(ell: int): - d = descriptors.spherical_harmonics(cue.SO3(1), [ell]).d + poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [ell]) vec = np.random.randn(3) axis = np.random.randn(3) angle = np.random.rand() - yl = stp.compute_last_operand(d, *(vec,) * ell) + [yl] = poly(vec) R = cue.SO3(1).rotation(axis, angle) Rl = cue.SO3(ell).rotation(axis, angle) - yl1 = stp.compute_last_operand(d, *(R @ vec,) * ell) + [yl1] = poly(R @ vec) yl2 = Rl @ yl np.testing.assert_allclose(yl1, yl2) @@ -77,7 +79,7 @@ def test_y_rotation(ell: int): gamma = -0.5 irrep = cue.SO3(ell) - d = descriptors.yxy_rotation(cue.Irreps("SO3", [irrep])).d + poly = cue.descriptors.yxy_rotation(cue.Irreps("SO3", [irrep])) def enc(th: float): m = np.arange(1, ell + 1) @@ -86,7 +88,7 @@ def enc(th: float): return np.concatenate([c[::-1], [1.0], s]) x = np.random.randn(irrep.dim) - y1 = stp.compute_last_operand(d, enc(gamma), enc(beta), enc(alpha), x) + [y1] = poly(enc(gamma), enc(beta), enc(alpha), x) A = irrep.rotation(np.array([0.0, 1.0, 0.0]), alpha) B = irrep.rotation(np.array([1.0, 0.0, 0.0]), beta) diff --git a/cuequivariance/tests/segmented_tensor_product/dot_test.py b/cuequivariance/tests/segmented_tensor_product/dot_test.py index 4d8c93c4..18d237de 100644 --- a/cuequivariance/tests/segmented_tensor_product/dot_test.py +++ b/cuequivariance/tests/segmented_tensor_product/dot_test.py @@ -49,11 +49,11 @@ def make_examples(): cue.Irreps("SO3", "4x0 + 3x1"), cue.Irreps("SO3", "3x0 + 5x1"), irreps_middle, - ).d + ).polynomial.tensor_products[0][1] assert dx.subscripts == "uvw,iu,jv,kw+ijk" dy = descriptors.channelwise_tensor_product( irreps_middle, cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0 + 1") - ).d + ).polynomial.tensor_products[0][1] dy = dy.squeeze_modes("v") assert dy.subscripts == "u,iu,j,ku+ijk" dy = dy.add_or_rename_modes("w_kw_l_mw+klm") diff --git 
a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index cfa76d96..0ce3c481 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -156,6 +156,23 @@ def __init__( use_fallback: Optional[bool] = None, ): super().__init__() + + # TODO: remove this when re-design + if isinstance(e, cue.EquivariantPolynomial): + assert e.num_outputs == 1 + for ope, stp in e.tensor_products: + inputs = list(range(e.num_inputs)) + output = e.num_inputs + expected = ( + inputs[: stp.num_operands - 1] + + [inputs[-1]] * max(0, stp.num_operands - e.num_operands) + + [output] + ) + assert ope.buffers == expected + e = cue.EquivariantTensorProduct( + [stp for _, stp in e.tensor_products], e.inputs + e.outputs + ) + if not isinstance(layout_in, tuple): layout_in = (layout_in,) * e.num_inputs if len(layout_in) != e.num_inputs: From 7fff505e3a6873e48642592bcd53da03ea536d28 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 08:20:36 -0800 Subject: [PATCH 011/107] add flops and memory methods to EquivariantPolynomial and SegmentedPolynomial classes --- .../cuequivariance/equivariant_polynomial.py | 7 +++++++ .../cuequivariance/segmented_polynomial.py | 15 +++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index 4c3d14ea..d6db7c74 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -196,3 +196,10 @@ def backward( + [x for req, x in zip(requires_gradient, self.inputs) if req], self.polynomial.backward(requires_gradient, has_cotangent), ) + + def flops(self, batch_size: int = 1) -> int: + return self.polynomial.flops(batch_size) + + def memory(self, batch_sizes: list[int]) -> int: + assert len(batch_sizes) == len(self.operands) + return sum(Z * rep.dim for Z, rep in zip(batch_sizes, self.operands)) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index e2568400..c9a401cb 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -151,6 +151,10 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: ) return outputs + @property + def num_operands(self) -> int: + return self.num_inputs + self.num_outputs + @property def buffer_sizes(self) -> list[int | None]: sizes = [None] * (self.num_inputs + self.num_outputs) @@ -317,6 +321,17 @@ def backward( has_cotangent=has_cotangent, ) + def flops(self, batch_size: int = 1) -> int: + n = 0 + for ope, stp in self.tensor_products: + oid, _ = ope.output_operand_buffer(self.num_inputs) + n += stp.flop_cost(oid) + return batch_size * n + + def memory(self, batch_sizes: list[int]) -> int: + assert len(batch_sizes) == self.num_operands + return sum(Z * size for Z, size in zip(batch_sizes, self.buffer_sizes)) + def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: segments = None for ope, stp in self.tensor_products: From 845ea37e19c2324cfd77343a584b6546f9f2a001 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 10:36:12 -0800 Subject: [PATCH 012/107] fix --- .../segmented_tensor_product/segmented_tensor_product.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 8c73b66f..31285c57 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -288,7 +288,7 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str: ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1") - ... ).d + ... ).polynomial.tensor_products[0][1] >>> d = d.flatten_coefficient_modes() >>> print(d.to_text()) uvw,u,v,w sizes=320,16,16,16 num_segments=5,4,4,4 num_paths=16 u=4 v=4 w=4 @@ -405,7 +405,7 @@ def to_base64(self, extended: bool = False) -> str: ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1") - ... ).d + ... ).polynomial.tensor_products[0][1] >>> print(d.to_base64()) eJytkstuwjAQRX/F8r...lTF2zlX91/fHyvj2Z4= """ @@ -429,7 +429,7 @@ def get_dims(self, m: str) -> set[int]: ... cue.Irreps("SO3", "4x0+8x1"), ... cue.Irreps("SO3", "3x0+3x1"), ... cue.Irreps("SO3", "5x0+7x1") - ... ).d + ... ).polynomial.tensor_products[0][1] >>> d.get_dims("u") {8, 4} >>> d.get_dims("v") From e979258805281d9f798fe78de85157a122925447 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 12:45:10 -0800 Subject: [PATCH 013/107] simplify code --- .../cuequivariance/equivariant_polynomial.py | 6 +-- .../cuequivariance/segmented_polynomial.py | 35 ++++++++++++- .../primitives/segmented_polynomial.py | 51 +++++-------------- 3 files changed, 50 insertions(+), 42 deletions(-) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index d6db7c74..7ab330e1 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -94,10 +94,10 @@ def inputs(self) -> tuple[cue.Rep, ...]: def outputs(self) -> tuple[cue.Rep, ...]: return self.operands[self.num_inputs :] - def consolidate(self) -> EquivariantPolynomial: + def fuse_stps(self) -> EquivariantPolynomial: return EquivariantPolynomial( self.operands, - self.polynomial.consolidate(), + self.polynomial.fuse_stps(), ) def buffer_used(self) -> list[bool]: @@ -142,7 +142,7 @@ def stack( operands, cue.SegmentedPolynomial.stack([pol.polynomial for pol in polys], stacked), ) - return poly.consolidate() + return poly.fuse_stps() def squeeze_modes(self) -> EquivariantPolynomial: return EquivariantPolynomial( diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index c9a401cb..f4a10faf 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -191,7 +191,7 @@ def map_tensor_products( self.num_inputs, self.num_outputs, new_tensor_products ) - def consolidate(self) -> SegmentedPolynomial: + def fuse_stps(self) -> SegmentedPolynomial: groups = itertools.groupby( self.tensor_products, key=lambda x: (x[0], x[1].operands, x[1].coefficient_subscripts), @@ -211,6 +211,30 @@ def consolidate(self) -> SegmentedPolynomial: self.num_inputs, self.num_outputs, new_tensor_products ) + def consolidate(self) -> SegmentedPolynomial: + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): + stp = ( + stp.consolidate_modes() + .squeeze_modes() + 
.remove_empty_segments() + .consolidate_paths() + .sort_paths() + ) + if stp.num_paths == 0: + return None + return ope, stp + + return self.fuse_stps().map_tensor_products(f) + + def used_buffers(self) -> list[int]: + return sorted( + set( + itertools.chain.from_iterable( + ope.buffers for ope, _ in self.tensor_products + ) + ) + ) + def buffer_used(self) -> list[bool]: return [ any(buffer in ope.buffers for ope, _ in self.tensor_products) @@ -346,3 +370,12 @@ def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: if segments is None: raise ValueError(f"Buffer {buffer} is not used") return segments + + def sort_indices_for_identical_operands(self) -> SegmentedPolynomial: + def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): + for set_of_operands in ope.operands_with_identical_buffers(): + stp = stp.sort_indices_for_identical_operands(set_of_operands) + stp = stp.sort_paths() + return ope, stp + + return self.map_tensor_products(optimize_paths) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 6bbe6b7a..e0eda59c 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -181,45 +181,26 @@ def segmented_polynomial_prim( jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_shape_dtype ] - def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): - stp = stp.consolidate_modes().remove_empty_segments().consolidate_paths() - if stp.num_paths == 0: - return None - return ope, stp - - polynomial = polynomial.map_tensor_products(f) - - used_buffers = set() - used_indices = set() - for ope, _ in polynomial.tensor_products: - for i in ope.buffers: - used_buffers.add(i) - if buffer_index[i] >= 0: - used_indices.add(buffer_index[i]) - used_buffers = sorted(used_buffers) # maps: new buffer index -> old buffer index - used_indices = sorted(used_indices) # maps: new index -> old index - - new_num_inputs = sum([i < len(inputs) for i in used_buffers]) - new_poynomial = cue.SegmentedPolynomial( - new_num_inputs, - len(used_buffers) - new_num_inputs, - [ - (cue.Operation([used_buffers.index(i) for i in ope.buffers]), stp) - for ope, stp in polynomial.tensor_products - ], + polynomial = polynomial.consolidate() + used_buffers = polynomial.used_buffers() + polynomial = polynomial.remove_unused_buffers() + + used_indices = sorted( + {buffer_index[i] for i in used_buffers if buffer_index[i] >= 0} ) - new_outputs = segmented_polynomial_p.bind( - *[inputs[i] for i in used_buffers[:new_num_inputs]], + used_outputs = segmented_polynomial_p.bind( + *[inputs[i] for i in used_buffers[: polynomial.num_inputs]], *[indices[i] for i in used_indices], buffer_index=tuple( used_indices.index(buffer_index[i]) if buffer_index[i] >= 0 else -1 for i in used_buffers ), outputs_shape_dtype=tuple( - outputs_shape_dtype[i - len(inputs)] for i in used_buffers[new_num_inputs:] + outputs_shape_dtype[i - len(inputs)] + for i in used_buffers[polynomial.num_inputs :] ), - polynomial=new_poynomial, + polynomial=polynomial, math_dtype=jnp.dtype(math_dtype), name=str(name), impl=impl, @@ -230,7 +211,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): else: outputs = [jnp.zeros(out.shape, out.dtype) for out in outputs_shape_dtype] - for i, output in zip(used_buffers[new_num_inputs:], new_outputs): + for i, output in zip(used_buffers[polynomial.num_inputs :], used_outputs): 
outputs[i - len(inputs)] = output return tuple(outputs) @@ -287,13 +268,7 @@ def segmented_polynomial_impl( inputs, indices = inputs_and_indices[:num_inputs], inputs_and_indices[num_inputs:] del inputs_and_indices - def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): - for set_of_operands in ope.operands_with_identical_buffers(): - stp = stp.sort_indices_for_identical_operands(set_of_operands) - stp = stp.sort_paths() - return ope, stp - - polynomial = polynomial.map_tensor_products(optimize_paths) + polynomial = polynomial.sort_indices_for_identical_operands() outputs = None kwargs = dict( From bfa3a44eca9cc0d6614104dca2a7e2eb2f7b46fb Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 12:55:52 -0800 Subject: [PATCH 014/107] add warnings --- .../cuequivariance/equivariant_tensor_product.py | 7 +++++++ .../primitives/equivariant_tensor_product.py | 8 ++++++++ .../cuequivariance_jax/primitives/tensor_product.py | 8 +++++++- 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py index ca14896c..3d0601ae 100644 --- a/cuequivariance/cuequivariance/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py @@ -16,6 +16,7 @@ import copy import dataclasses +import warnings from typing import Optional, Sequence, Union import cuequivariance as cue @@ -62,6 +63,12 @@ def __init__( operands: list[cue.Rep], symmetrize: bool = True, ): + warnings.warn( + "EquivariantTensorProduct is deprecated and will be removed in a future version. " + "Please use EquivariantPolynomial instead.", + DeprecationWarning, + stacklevel=2, + ) operands = tuple(operands) if isinstance(d, stp.SegmentedTensorProduct): assert len(operands) == d.num_operands diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py index e40a31d4..7df5a661 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings + import jax import jax.numpy as jnp @@ -88,6 +90,12 @@ def equivariant_tensor_product( [[ 1. ... ] [ 2. ... ]] """ + warnings.warn( + "equivariant_tensor_product is deprecated and will be removed in a future version. " + "Please use cuex.equivariant_polynomial instead.", + DeprecationWarning, + stacklevel=2, + ) assert e.num_inputs > 0 if len(inputs) == 0: diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py index 78c7395e..152e4eee 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import warnings import jax import jax.numpy as jnp @@ -62,7 +63,12 @@ def tensor_product( Returns: list of jax.Array: The result of the tensor product. """ - + warnings.warn( + "tensor_product is deprecated and will be removed in a future version. 
" + "Please use cuex.segmented_polynomial instead.", + DeprecationWarning, + stacklevel=2, + ) if name is None: name = "tensor_product" From 0ee5a6a906689290714342b0cd9f655aa6285aac Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 13:32:33 -0800 Subject: [PATCH 015/107] fix tests --- .../operations/spherical_harmonics.py | 4 +- .../primitives/segmented_polynomial.py | 2 +- ...test.py => equivariant_polynomial_test.py} | 15 ++++--- ...t_test.py => segmented_polynomial_test.py} | 39 +++++++++++-------- 4 files changed, 35 insertions(+), 25 deletions(-) rename cuequivariance_jax/tests/primitives/{equivariant_tensor_product_test.py => equivariant_polynomial_test.py} (91%) rename cuequivariance_jax/tests/primitives/{tensor_product_test.py => segmented_polynomial_test.py} (77%) diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py index 2373f3ad..8749abe7 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py @@ -49,9 +49,9 @@ def spherical_harmonics( if normalize: vector = _normalize(vector) - return cuex.equivariant_tensor_product( + return cuex.equivariant_polynomial( descriptors.spherical_harmonics(ir, ls, vector.layout), - vector, + [vector], name="spherical_harmonics", ) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index e0eda59c..e54be3c6 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -78,7 +78,7 @@ def segmented_polynomial( assert len(inputs) == polynomial.num_inputs assert len(outputs_shape_dtype) == polynomial.num_outputs - buffers = inputs + outputs_shape_dtype + buffers = list(inputs) + list(outputs_shape_dtype) if indices is None: indices = [None] * len(buffers) diff --git a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_jax/tests/primitives/equivariant_polynomial_test.py similarity index 91% rename from cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py rename to cuequivariance_jax/tests/primitives/equivariant_polynomial_test.py index 8579ab57..cc2f408a 100644 --- a/cuequivariance_jax/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_jax/tests/primitives/equivariant_polynomial_test.py @@ -26,7 +26,9 @@ def test_special_double_backward(): 32 * cue.Irreps("O3", "0e + 1o + 2e"), 32 * cue.Irreps("O3", "0e + 1o"), [1, 2] ) rep_w, rep_x = e.inputs - h = cuex.equivariant_tensor_product(e) + + def h(*inputs): + return cuex.equivariant_polynomial(e, inputs) h0 = lambda w, x: h(w, x).array.sum() ** 2 # noqa h1 = lambda w, x: jax.grad(h0, 1)(w, x).array.sum() ** 2 # noqa @@ -56,7 +58,7 @@ def make_uniform1d_descriptors(): @pytest.mark.parametrize("e", make_uniform1d_descriptors()) -def test_custom_kernel(e: cue.EquivariantTensorProduct): +def test_custom_kernel(e: cue.EquivariantPolynomial): if jax.default_backend() != "gpu": pytest.skip("test_custom_kernel requires CUDA") @@ -79,13 +81,14 @@ def test_custom_kernel(e: cue.EquivariantTensorProduct): output_batch_shape = (num_nodes,) def fwd(inputs, indices, impl): - return cuex.equivariant_tensor_product( + return cuex.equivariant_polynomial( e, - indices=indices, - 
output_batch_shape=output_batch_shape, + inputs, + jax.ShapeDtypeStruct(output_batch_shape + (e.outputs[0].dim,), jnp.float64), + indices, math_dtype=jnp.float64, impl=impl, - )(*inputs).array + ).array out0 = fwd(inputs, indices, impl="jax") out1 = fwd(inputs, indices, impl="cuda") diff --git a/cuequivariance_jax/tests/primitives/tensor_product_test.py b/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py similarity index 77% rename from cuequivariance_jax/tests/primitives/tensor_product_test.py rename to cuequivariance_jax/tests/primitives/segmented_polynomial_test.py index 1109101f..7aedeaf6 100644 --- a/cuequivariance_jax/tests/primitives/tensor_product_test.py +++ b/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py @@ -24,14 +24,18 @@ def test_one_operand(): d = cue.SegmentedTensorProduct.empty_segments([1]) - [out] = cuex.tensor_product( - [(cue.Operation([0]), d)], [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] + [out] = cuex.segmented_polynomial( + cue.SegmentedPolynomial(0, 1, [(cue.Operation([0]), d)]), + [], + [jax.ShapeDtypeStruct((2, 1), jnp.float32)], ) np.testing.assert_array_equal(out, np.array([[0.0], [0.0]])) d.add_path(0, c=123) - [out] = cuex.tensor_product( - [(cue.Operation([0]), d)], [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] + [out] = cuex.segmented_polynomial( + cue.SegmentedPolynomial(0, 1, [(cue.Operation([0]), d)]), + [], + [jax.ShapeDtypeStruct((2, 1), jnp.float32)], ) np.testing.assert_array_equal(out, np.array([[123.0], [123.0]])) @@ -44,10 +48,8 @@ def test_UnshapedArray_bug(): x = jnp.ones((2, 1)) def f(w, x): - [out] = cuex.tensor_product( - [(cue.Operation([0, 2]), e.ds[0]), (cue.Operation([0, 1, 2]), e.ds[1])], - [w, x], - [jax.ShapeDtypeStruct((2, 1), jnp.float32)], + [out] = cuex.segmented_polynomial( + e.polynomial, [w, x], [jax.ShapeDtypeStruct((2, 1), jnp.float32)] ) return jnp.sum(out) @@ -59,9 +61,9 @@ def test_multiple_operand_shape_bug(): # Before, it was not possible to have an input # with a different shape than the output of the same operand. 
def h(x): - d = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]).d - [out] = cuex.tensor_product( - [(cue.Operation([0, 0, 1]), d)], + e = cue.descriptors.spherical_harmonics(cue.SO3(1), [2]) + [out] = cuex.segmented_polynomial( + e.polynomial, [x], [jax.ShapeDtypeStruct((5,), jnp.float32)], ) @@ -89,13 +91,18 @@ def test_vmap(): e = cue.descriptors.full_tensor_product( cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1") ) + d = e.polynomial.tensor_products[0][1] def f(x1, x2, i1): - return cuex.tensor_product( - [ - (cue.Operation([0, 1, 2]), e.ds[0]), - (cue.Operation([0, 1, 3]), e.ds[0]), - ], + return cuex.segmented_polynomial( + cue.SegmentedPolynomial( + 2, + 2, + [ + (cue.Operation([0, 1, 2]), d), + (cue.Operation([0, 1, 3]), d), + ], + ), [x1, x2], [ jax.ShapeDtypeStruct((2, 3), jnp.float32), From 252a2054c8b94628021739204db1f51f22d1b8cd Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 27 Feb 2025 14:15:51 -0800 Subject: [PATCH 016/107] remove old functions --- .../cuequivariance_jax/__init__.py | 4 - .../primitives/equivariant_polynomial.py | 13 +- .../primitives/equivariant_tensor_product.py | 172 ------------------ .../primitives/tensor_product.py | 83 --------- 4 files changed, 8 insertions(+), 264 deletions(-) delete mode 100644 cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py delete mode 100644 cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index f6a99482..bec163cb 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -24,9 +24,7 @@ from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan from .primitives.segmented_polynomial import segmented_polynomial -from .primitives.tensor_product import tensor_product from .primitives.equivariant_polynomial import equivariant_polynomial -from .primitives.equivariant_tensor_product import equivariant_tensor_product from .operations.activation import ( normalspace, @@ -48,9 +46,7 @@ "as_irreps_array", "clebsch_gordan", "segmented_polynomial", - "tensor_product", "equivariant_polynomial", - "equivariant_tensor_product", "normalspace", "normalize_function", "function_parity", diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py index e6cbcb39..c206f372 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py @@ -51,7 +51,10 @@ def equivariant_polynomial( >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) >>> e - EquivariantPolynomial((1)^(0..2) -> 0+1+2) + ╭ a=1 -> B=0+1+2 + │ B ───── sizes=9 num_segments=9 num_paths=1 + │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 + ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=8 We need some input data. @@ -62,7 +65,7 @@ def equivariant_polynomial( Now we can execute the equivariant tensor product. - >>> cuex.equivariant_tensor_product(e, x) + >>> cuex.equivariant_polynomial(e, [x]) {0: 0+1+2} [1. ... ] @@ -79,11 +82,11 @@ def equivariant_polynomial( ... [0.0, 0.0, 1.0], ... [1.0, 0.0, 0.0], ... ])) - >>> cuex.equivariant_tensor_product( + >>> cuex.equivariant_polynomial( ... e, - ... x, + ... [x], + ... [jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32)], ... 
indices=[None, i_out], - ... output_batch_shape=(2,), ... ) {1: 0+1+2} [[ 1. ... ] diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py deleted file mode 100644 index 7df5a661..00000000 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_tensor_product.py +++ /dev/null @@ -1,172 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings - -import jax -import jax.numpy as jnp - -import cuequivariance as cue -import cuequivariance_jax as cuex - - -def equivariant_tensor_product( - e: cue.EquivariantTensorProduct, - *inputs: cuex.RepArray | jax.Array, - indices: list[jax.Array | None] | None = None, - output_batch_shape: tuple[int, ...] | None = None, - output_dtype: jnp.dtype | None = None, - math_dtype: jnp.dtype | None = None, - name: str | None = None, - impl: str = "auto", -) -> cuex.RepArray: - """Compute the equivariant tensor product of the input arrays. - - Args: - e (:class:`cue.EquivariantTensorProduct `): The equivariant tensor product descriptor. - *inputs (RepArray or jax.Array): The input arrays. - indices (list of jax.Array or None, optional): The optional indices of the inputs and output. - output_batch_shape (tuple of int, optional): The batch shape of the output array. - output_dtype (jnp.dtype, optional): The data type for the output array. Defaults to None. - math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. - name (str, optional): The name of the operation. Defaults to None. - - Returns: - RepArray: The result of the equivariant tensor product. - - Examples: - - Let's create a descriptor for the spherical harmonics of degree 0, 1, and 2. - - >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) - >>> e - EquivariantTensorProduct((1)^(0..2) -> 0+1+2) - - We need some input data. - - >>> with cue.assume(cue.SO3, cue.ir_mul): - ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) - >>> x - {0: 1} [0. 1. 0.] - - Now we can execute the equivariant tensor product. - - >>> cuex.equivariant_tensor_product(e, x) - {0: 0+1+2} - [1. ... ] - - The `indices` argument allows to specify a list of optional int32 arrays for each input and for the output (`None` means no index and `indices[-1]` is the output index). The indices are used to select the elements of the input arrays and to specify the output index. - In the following example, we will index the output. The input has a batch shape of (3,) and the output has a batch shape of (2,). - - >>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32) - - The `i_out` array is used to map the result to the output indices. - - >>> with cue.assume(cue.SO3, cue.ir_mul): - ... x = cuex.RepArray("1", jnp.array([ - ... [0.0, 1.0, 0.0], - ... [0.0, 0.0, 1.0], - ... [1.0, 0.0, 0.0], - ... 
])) - >>> cuex.equivariant_tensor_product( - ... e, - ... x, - ... indices=[None, i_out], - ... output_batch_shape=(2,), - ... ) - {1: 0+1+2} - [[ 1. ... ] - [ 2. ... ]] - """ - warnings.warn( - "equivariant_tensor_product is deprecated and will be removed in a future version. " - "Please use cuex.equivariant_polynomial instead.", - DeprecationWarning, - stacklevel=2, - ) - assert e.num_inputs > 0 - - if len(inputs) == 0: - return lambda *inputs: equivariant_tensor_product( - e, - *inputs, - indices=indices, - output_batch_shape=output_batch_shape, - output_dtype=output_dtype, - math_dtype=math_dtype, - name=name, - impl=impl, - ) - - if len(inputs) != e.num_inputs: - raise ValueError( - f"Unexpected number of inputs. Expected {e.num_inputs}, got {len(inputs)}." - ) - - for i, (x, rep) in enumerate(zip(inputs, e.inputs)): - if isinstance(x, cuex.RepArray): - assert x.rep(-1) == rep, ( - f"Input {i} should have representation {rep}, got {x.rep(-1)}." - ) - else: - assert x.ndim >= 1, ( - f"Input {i} should have at least one dimension, got {x.ndim}." - ) - assert x.shape[-1] == rep.dim, ( - f"Input {i} should have dimension {rep.dim}, got {x.shape[-1]}." - ) - if not rep.is_scalar(): - raise ValueError( - f"Input {i} should be a RepArray unless the input is scalar. Got {type(x)} for {rep}." - ) - - inputs: list[jax.Array] = [getattr(x, "array", x) for x in inputs] - - if indices is None: - indices = [None] * e.num_operands - - if len(indices) != e.num_operands: - raise ValueError( - f"Unexpected number of indices. indices should None or a list of length {e.num_operands}, got a list of length {len(indices)}." - ) - - if output_dtype is None: - output_dtype = jnp.result_type(*inputs) - - if output_batch_shape is None: - if indices[-1] is not None: - raise ValueError( - "When output indices are provided, output_batch_shape must be provided." - ) - output_batch_shape = jnp.broadcast_shapes( - *[ - x.shape[:-1] if i is None else i.shape + x.shape[1:-1] - for i, x in zip(indices, inputs) - ] - ) - - descriptors = [(cue.Operation(e.map_operands(d.num_operands)), d) for d in e.ds] - - [x] = cuex.tensor_product( - descriptors, - inputs, - [jax.ShapeDtypeStruct(output_batch_shape + (e.output.dim,), output_dtype)], - indices, - math_dtype=math_dtype, - name=name, - impl=impl, - ) - - return cuex.RepArray(e.output, x) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py b/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py deleted file mode 100644 index 152e4eee..00000000 --- a/cuequivariance_jax/cuequivariance_jax/primitives/tensor_product.py +++ /dev/null @@ -1,83 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging -import warnings - -import jax -import jax.numpy as jnp - -import cuequivariance as cue -import cuequivariance_jax as cuex - -logger = logging.getLogger(__name__) - - -def tensor_product( - descriptors: list[tuple[cue.Operation, cue.SegmentedTensorProduct]], - inputs: list[jax.Array], - outputs_shape_dtype: list[jax.ShapeDtypeStruct], - indices: list[jax.Array | None] | None = None, - *, - math_dtype: jnp.dtype | None = None, - name: str | None = None, - impl: str = "auto", -) -> list[jax.Array]: - r"""Compute a polynomial described by a list of descriptors. - - Features: - - Calls a CUDA kernel if: - - STPs have a single mode which is a multiple of 32 (e.g. a channelwise tensor product that has subscripts ``u,u,,u`` with u=128) - - math data type is float32 or float64 - - in/out data type is a mix of float32, float64, float16 and bfloat16 - - indices are int32 - - Supports of infinite derivatives (JVP and tranpose rules maps to a single corresponding primitive) - - Limited support for batching (we cannot batch a buffer that has indices and if the batching is non trivial the performace will be bad) - - Automatic optimizations based on the symmetries of the STPs and on the repetition of the input buffers - - Automatic drop of unused buffers and indices - - Args: - descriptors (list of pairs): The list of descriptors. - Each descriptor is formed by a pair of :class:`cue.Operation ` and :class:`cue.SegmentedTensorProduct `. - inputs (list of jax.Array): The input buffers. - outputs_shape_dtype (list of jax.ShapeDtypeStruct): The output shapes and dtypes. - indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs. - math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. - name (str, optional): The name of the operation. Defaults to None. - impl (str, optional): The implementation to use. Defaults to "auto". - If "auto", it will use the CUDA implementation if available, otherwise it will use the JAX implementation. - If "cuda", it will use the CUDA implementation. - If "jax", it will use the JAX implementation. - - Returns: - list of jax.Array: The result of the tensor product. - """ - warnings.warn( - "tensor_product is deprecated and will be removed in a future version. " - "Please use cuex.segmented_polynomial instead.", - DeprecationWarning, - stacklevel=2, - ) - if name is None: - name = "tensor_product" - - return cuex.segmented_polynomial( - cue.SegmentedPolynomial(len(inputs), len(outputs_shape_dtype), descriptors), - inputs, - outputs_shape_dtype, - indices, - math_dtype=math_dtype, - name=name, - impl=impl, - ) From 3fbdcb3b180b516c25d7ee950fbcc0c6802d2742 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 00:01:10 -0800 Subject: [PATCH 017/107] docstrings --- .../primitives/equivariant_polynomial.py | 44 +++++++------- .../primitives/segmented_polynomial.py | 59 ++++++++++++------- 2 files changed, 57 insertions(+), 46 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py index c206f372..791ba306 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py @@ -30,24 +30,28 @@ def equivariant_polynomial( math_dtype: jnp.dtype | None = None, name: str | None = None, impl: str = "auto", -) -> tuple[cuex.RepArray, ...] 
| cuex.RepArray: +) -> list[cuex.RepArray] | cuex.RepArray: """Compute an equivariant polynomial. Args: - poly (:class:`cue.EquivariantPolynomial `): The equivariant tensor product descriptor. - *inputs (RepArray or jax.Array): The input arrays. - indices (list of jax.Array or None, optional): The optional indices of the inputs and output. - output_batch_shape (tuple of int, optional): The batch shape of the output array. - output_dtype (jnp.dtype, optional): The data type for the output array. Defaults to None. - math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None. - name (str, optional): The name of the operation. Defaults to None. + poly: The equivariant polynomial descriptor. + inputs: List of input :class:`cuex.RepArray `. + outputs_shape_dtype: Shape and dtype specifications for outputs. If None, + inferred from inputs when possible. When output indices are provided, this must be specified. + indices: Optional list of indices for inputs and outputs. Length must match + total number of operands (inputs + outputs). Use None for unindexed + operands. Defaults to None. + math_dtype: Data type for computational operations. If None, automatically + determined from input types. Defaults to None. + name: Optional name for the operation. Defaults to None. + impl: Implementation to use, one of ["auto", "cuda", "jax"]. If "auto", + uses CUDA when available, falling back to JAX otherwise. Defaults to "auto". Returns: - tuple of RepArray or RepArray: The output array(s). + Single :class:`cuex.RepArray ` if one output, or list of :class:`cuex.RepArray ` for multiple outputs. Examples: - - Let's create a descriptor for the spherical harmonics of degree 0, 1, and 2. + Create and compute spherical harmonics of degree 0, 1, and 2: >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) >>> e @@ -56,38 +60,30 @@ def equivariant_polynomial( │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=8 - We need some input data. + Basic usage with single input: >>> with cue.assume(cue.SO3, cue.ir_mul): ... x = cuex.RepArray("1", jnp.array([0.0, 1.0, 0.0])) - >>> x - {0: 1} [0. 1. 0.] - - Now we can execute the equivariant tensor product. - >>> cuex.equivariant_polynomial(e, [x]) {0: 0+1+2} [1. ... ] - The `indices` argument allows to specify a list of optional int32 arrays for each input and for the output (`None` means no index and `indices[-1]` is the output index). The indices are used to select the elements of the input arrays and to specify the output index. - In the following example, we will index the output. The input has a batch shape of (3,) and the output has a batch shape of (2,). + Using indices: >>> i_out = jnp.array([0, 1, 1], dtype=jnp.int32) - - The `i_out` array is used to map the result to the output indices. - >>> with cue.assume(cue.SO3, cue.ir_mul): ... x = cuex.RepArray("1", jnp.array([ ... [0.0, 1.0, 0.0], ... [0.0, 0.0, 1.0], ... [1.0, 0.0, 0.0], ... ])) - >>> cuex.equivariant_polynomial( + >>> result = cuex.equivariant_polynomial( ... e, ... [x], ... [jax.ShapeDtypeStruct((2, e.outputs[0].dim), jnp.float32)], ... indices=[None, i_out], ... ) + >>> result {1: 0+1+2} [[ 1. ... ] [ 2. ... 
]]
 
@@ -169,4 +165,4 @@ def equivariant_polynomial(
 
     if poly.num_outputs == 1:
         return outputs[0]
 
-    return tuple(outputs)
+    return outputs
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py
index e54be3c6..5601f427 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py
@@ -44,33 +44,48 @@ def segmented_polynomial(
     name: str | None = None,
     impl: str = "auto",
 ) -> list[jax.Array]:
-    r"""Compute a segmented polynomial.
+    """Compute a segmented polynomial.
 
-    Features:
-    - Calls a CUDA kernel if:
-        - STPs have a single mode which is a multiple of 32 (e.g. a channelwise tensor product that has subscripts ``u,u,,u`` with u=128)
-        - math data type is float32 or float64
-        - in/out data type is a mix of float32, float64, float16 and bfloat16
-        - indices are int32
-    - Supports of infinite derivatives (JVP and tranpose rules maps to a single corresponding primitive)
-    - Limited support for batching (we cannot batch a buffer that has indices and if the batching is non trivial the performace will be bad)
-    - Automatic optimizations based on the symmetries of the STPs and on the repetition of the input buffers
-    - Automatic drop of unused buffers and indices
+    This function evaluates a segmented polynomial using either a CUDA or a JAX
+    implementation. The implementation choice is determined by the input characteristics
+    and availability of CUDA support.
 
     Args:
-        polynomial (cue.SegmentedPolynomial): The segmented polynomial to compute.
-        inputs (list of jax.Array): The input buffers.
-        outputs_shape_dtype (list of jax.ShapeDtypeStruct): The output shapes and dtypes.
-        indices (list of jax.Array or None, optional): The optional indices of the inputs and outputs.
-        math_dtype (jnp.dtype, optional): The data type for computational operations. Defaults to None.
-        name (str, optional): The name of the operation. Defaults to None.
-        impl (str, optional): The implementation to use. Defaults to "auto".
-            If "auto", it will use the CUDA implementation if available, otherwise it will use the JAX implementation.
-            If "cuda", it will use the CUDA implementation.
-            If "jax", it will use the JAX implementation.
+        polynomial: The segmented polynomial to compute.
+        inputs: List of input buffers as JAX arrays.
+        outputs_shape_dtype: List of output shape and dtype specifications.
+        indices: Optional list of indices for inputs and outputs. If None, no indexing
+            is applied. Defaults to None.
+        math_dtype: Data type for computational operations. If None, automatically
+            determined from input types, defaulting to float32 if no float64 inputs
+            are present. Defaults to None.
+        name: Optional name for the operation. Defaults to None.
+        impl: Implementation to use, one of ["auto", "cuda", "jax"]. If "auto",
+            uses CUDA when available, falling back to JAX otherwise. Defaults to "auto".
 
     Returns:
-        list of jax.Array: The result of the tensor product.
+        List of JAX arrays containing the computed polynomial outputs.
+
+    Features:
+        - CUDA kernel activation conditions:
+            - STPs have a single mode which is a multiple of 32 (e.g. 
channelwise + tensor product with subscripts ``u,u,,u`` where u=128) + - Math data type is float32 or float64 + - Input/output data types can be float32, float64, float16, or bfloat16 + - Indices must be int32 + - Supports infinite derivatives through JVP and transpose rules + - Limited batching support: + - Cannot batch buffers with indices + - Non-trivial batching may impact performance + - Automatic optimizations: + - Based on STP symmetries + - Based on input buffer repetition patterns + - Automatic pruning of unused buffers and indices + + Note: + The function automatically determines the best implementation based on the + input characteristics when impl="auto". For maximum performance with CUDA-capable + hardware, ensure inputs match the CUDA kernel activation conditions. """ if name is None: From 6c3b7d6269fc2a1e55dced128cb89f36d09f886a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 01:04:32 -0800 Subject: [PATCH 018/107] Rename trivial method to eval_last_operand and improve docstrings This commit introduces several improvements: - Renamed the `trivial` method in `SegmentedPolynomial` to `eval_last_operand` for clarity - Added comprehensive docstrings to the `SegmentedPolynomial` class and its methods - Updated method signatures and documentation across multiple files to reflect the method rename - Improved documentation for various descriptors and operations - Enhanced type hints and example documentation --- .../cuequivariance/descriptors/irreps_tp.py | 10 ++--- .../cuequivariance/descriptors/rotations.py | 14 +++---- .../descriptors/spherical_harmonics_.py | 7 ++-- .../descriptors/symmetric_contractions.py | 14 +++---- .../descriptors/transposition.py | 2 +- .../cuequivariance/experimental/escn.py | 4 +- .../mace/symmetric_contractions.py | 3 +- .../cuequivariance/segmented_polynomial.py | 40 ++++++++++++++++++- .../operations/spherical_harmonics.py | 16 +++++--- .../segmented_polynomial_vanilla_impl.py | 21 +++++----- .../primitives/equivariant_tensor_product.py | 8 ++-- docs/api/cuequivariance.rst | 5 ++- docs/api/cuequivariance_jax.rst | 4 +- docs/index.rst | 6 ++- docs/tutorials/beta.rst | 5 ++- docs/tutorials/etp.rst | 4 +- docs/tutorials/stp.rst | 6 +-- 17 files changed, 107 insertions(+), 62 deletions(-) diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/descriptors/irreps_tp.py index 6b6c7806..1b7a42f6 100644 --- a/cuequivariance/cuequivariance/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/descriptors/irreps_tp.py @@ -81,7 +81,7 @@ def fully_connected_tensor_product( cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -143,7 +143,7 @@ def full_tensor_product( cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -209,7 +209,7 @@ def channelwise_tensor_product( cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -293,7 +293,7 @@ def elementwise_tensor_product( cue.IrrepsAndLayout(irreps2, cue.ir_mul), cue.IrrepsAndLayout(irreps3, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -330,5 +330,5 @@ def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> 
cue.EquivariantPoly cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/rotations.py b/cuequivariance/cuequivariance/descriptors/rotations.py index f0801384..f42804fb 100644 --- a/cuequivariance/cuequivariance/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/descriptors/rotations.py @@ -47,7 +47,7 @@ def fixed_axis_angle_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -82,7 +82,7 @@ def yxy_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(cbaio), + cue.SegmentedPolynomial.eval_last_operand(cbaio), ) @@ -105,7 +105,7 @@ def xy_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(cbio), + cue.SegmentedPolynomial.eval_last_operand(cbio), ) @@ -128,7 +128,7 @@ def yx_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(cbio), + cue.SegmentedPolynomial.eval_last_operand(cbio), ) @@ -197,7 +197,7 @@ def y_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -229,7 +229,7 @@ def x_rotation( cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -249,5 +249,5 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantPolynomial: cue.IrrepsAndLayout(irreps, cue.ir_mul), cue.IrrepsAndLayout(irreps, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index 3093bcc7..bd4b7070 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -24,8 +24,9 @@ def spherical_harmonics( ir_vec: cue.Irrep, ls: list[int], layout: cue.IrrepsLayout = cue.ir_mul ) -> cue.EquivariantPolynomial: - """ - subscripts: ``vector[],...,vector[],Yl[]`` + """Polynomial descriptor for the spherical harmonics. + + Subscripts: ``vector[],...,vector[],Yl[]`` Args: ir_vec (Irrep): irrep of the input vector, for example ``cue.SO3(1)``. @@ -35,7 +36,7 @@ def spherical_harmonics( Returns: :class:`cue.EquivariantPolynomial `: The descriptor. - Examples: + Example: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) ╭ a=1 -> B=0+1+2 │ B ───── sizes=9 num_segments=9 num_paths=1 diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py index bab9a267..da1d0a87 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py @@ -21,24 +21,22 @@ def symmetric_contraction( irreps_out: cue.Irreps, degrees: list[int], ) -> cue.EquivariantPolynomial: - r""" - subscripts: ``weights[u],input[u],output[u]`` - - Construct the descriptor for a symmetric contraction. 
+ """Construct the descriptor for a symmetric contraction. The symmetric contraction is a weighted sum of the input contracted with itself degree times. + Subscripts: ``weights[u],input[u],output[u]`` + Args: irreps_in (Irreps): The input irreps, the multiplicity are treated in parallel. irreps_out (Irreps): The output irreps. - degree (int): The degree of the symmetric contraction. + degrees (list[int]): List of degrees for the symmetric contractions. Returns: - :class:`cue.EquivariantPolynomial `: - The descriptor of the symmetric contraction. + EquivariantPolynomial: The descriptor of the symmetric contraction. The operands are the weights, the input degree times and the output. - Examples: + Example: >>> cue.descriptors.symmetric_contraction( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... 16 * cue.Irreps("SO3", "0 + 1"), diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/descriptors/transposition.py index 0cb4925c..e7191b65 100644 --- a/cuequivariance/cuequivariance/descriptors/transposition.py +++ b/cuequivariance/cuequivariance/descriptors/transposition.py @@ -33,5 +33,5 @@ def transpose( d.add_path(None, None, c=1, dims={"u": mul, "i": ir.dim}) return cue.EquivariantPolynomial( [cue.IrrepsAndLayout(irreps, source), cue.IrrepsAndLayout(irreps, target)], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/experimental/escn.py index 54ec4e24..1e6f43c3 100644 --- a/cuequivariance/cuequivariance/experimental/escn.py +++ b/cuequivariance/cuequivariance/experimental/escn.py @@ -110,7 +110,7 @@ def pr(mul_ir: cue.MulIrrep) -> bool: cue.IrrepsAndLayout(irreps_in, cue.ir_mul), cue.IrrepsAndLayout(irreps_out, cue.ir_mul), ], - cue.SegmentedPolynomial.trivial(d), + cue.SegmentedPolynomial.eval_last_operand(d), ) @@ -178,7 +178,7 @@ def escn_tp_compact( d = d.normalize_paths_for_operand(2) # TODO: return an EquivariantPolynomial using SphericalSignal - return cue.SegmentedPolynomial.trivial(d) + return cue.SegmentedPolynomial.eval_last_operand(d) class SphericalSignal(cue.Rep): diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index f648e31f..c01597f6 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -40,7 +40,8 @@ def symmetric_contraction( w = jax.random.normal(jax.random.key(0), (p.shape[0], mul)) w = jnp.einsum("au,ab->bu", w, p).flatten() - cuex.equivariant_tensor_product(e, w, cuex.randn(jax.random.key(1), e.inputs[1])) + x = cuex.randn(jax.random.key(1), e.inputs[1]) + y = cuex.equivariant_polynomial(e, [w, x]) """ assert min(degrees) > 0 diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index f4a10faf..b82fca20 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -27,6 +27,24 @@ @dataclasses.dataclass(init=False, frozen=True) class SegmentedPolynomial: + """A polynomial representation using segmented tensor products. + + This class represents a polynomial using a collection of segmented tensor products, where each product + is associated with an operation that specifies how inputs are combined. 
The polynomial maps a set of + input tensors to output tensors through these tensor products. + + Args: + num_inputs (int): Number of input tensors. + num_outputs (int): Number of output tensors. + tensor_products (list of tuple of Operation and SegmentedTensorProduct): List of operation and tensor product pairs + that define the polynomial transformation. + + Example: + >>> # Create a polynomial with 2 inputs and 1 output + >>> poly = SegmentedPolynomial(2, 1, [(op1, stp1), (op2, stp2)]) + >>> outputs = poly(input1, input2) # Evaluate polynomial on inputs (numpy reference implementation) + """ + num_inputs: int num_outputs: int tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] @@ -42,7 +60,7 @@ def __init__( object.__setattr__(self, "tensor_products", sorted(tensor_products)) @classmethod - def trivial(cls, stp: cue.SegmentedTensorProduct): + def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): return cls( stp.num_operands - 1, 1, @@ -153,10 +171,12 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: @property def num_operands(self) -> int: + """Number of operands in the polynomial.""" return self.num_inputs + self.num_outputs @property def buffer_sizes(self) -> list[int | None]: + """Sizes of the buffers in the polynomial.""" sizes = [None] * (self.num_inputs + self.num_outputs) for ope, stp in self.tensor_products: for buffer, operand in zip(ope.buffers, stp.operands): @@ -170,10 +190,12 @@ def buffer_sizes(self) -> list[int | None]: @property def input_sizes(self) -> list[int | None]: + """Sizes of the input buffers in the polynomial.""" return self.buffer_sizes[: self.num_inputs] @property def output_sizes(self) -> list[int | None]: + """Sizes of the output buffers in the polynomial.""" return self.buffer_sizes[self.num_inputs :] def map_tensor_products( @@ -192,6 +214,7 @@ def map_tensor_products( ) def fuse_stps(self) -> SegmentedPolynomial: + """Fuse segmented tensor products with identical operations and operands.""" groups = itertools.groupby( self.tensor_products, key=lambda x: (x[0], x[1].operands, x[1].coefficient_subscripts), @@ -212,6 +235,8 @@ def fuse_stps(self) -> SegmentedPolynomial: ) def consolidate(self) -> SegmentedPolynomial: + """Consolidate the segmented tensor products.""" + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): stp = ( stp.consolidate_modes() @@ -227,6 +252,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): return self.fuse_stps().map_tensor_products(f) def used_buffers(self) -> list[int]: + """Buffers used in the polynomial. (List of integers)""" return sorted( set( itertools.chain.from_iterable( @@ -236,12 +262,14 @@ def used_buffers(self) -> list[int]: ) def buffer_used(self) -> list[bool]: + """Buffers used in the polynomial. 
(List of boolean values)""" return [ any(buffer in ope.buffers for ope, _ in self.tensor_products) for buffer in range(self.num_inputs + self.num_outputs) ] def remove_unused_buffers(self) -> SegmentedPolynomial: + """Remove unused buffers from the polynomial.""" used = self.buffer_used() new_index = [] i = 0 @@ -265,6 +293,7 @@ def remove_unused_buffers(self) -> SegmentedPolynomial: def stack( cls, polys: list[SegmentedPolynomial], stacked: list[bool] ) -> SegmentedPolynomial: + """Stack segmented polynomials together.""" assert len(polys) > 0 num_inputs = polys[0].num_inputs num_outputs = polys[0].num_outputs @@ -286,6 +315,7 @@ def stack( return cls(num_inputs, num_outputs, tensor_products) def squeeze_modes(self) -> SegmentedPolynomial: + """Squeeze the modes of the segmented tensor products.""" return SegmentedPolynomial( self.num_inputs, self.num_outputs, @@ -293,6 +323,7 @@ def squeeze_modes(self) -> SegmentedPolynomial: ) def flatten_coefficient_modes(self) -> SegmentedPolynomial: + """Flatten the coefficient modes of the segmented tensor products.""" return SegmentedPolynomial( self.num_inputs, self.num_outputs, @@ -303,6 +334,7 @@ def flatten_coefficient_modes(self) -> SegmentedPolynomial: ) def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: + """Compute the Jacobian-vector product of the polynomial.""" assert len(has_tangent) == self.num_inputs new_tps = [] @@ -322,6 +354,7 @@ def transpose( is_undefined_primal: list[bool], has_cotangent: list[bool], ) -> SegmentedPolynomial: + """Transpose the polynomial.""" assert len(is_undefined_primal) == self.num_inputs assert len(has_cotangent) == self.num_outputs @@ -339,6 +372,7 @@ def transpose( def backward( self, requires_gradient: list[bool], has_cotangent: list[bool] ) -> SegmentedPolynomial: + """Compute the backward pass of the polynomial.""" return self.jvp(requires_gradient).transpose( is_undefined_primal=[False] * self.num_inputs + [True] * sum(requires_gradient), @@ -346,6 +380,7 @@ def backward( ) def flops(self, batch_size: int = 1) -> int: + """Compute the number of floating point operations in the polynomial.""" n = 0 for ope, stp in self.tensor_products: oid, _ = ope.output_operand_buffer(self.num_inputs) @@ -353,6 +388,7 @@ def flops(self, batch_size: int = 1) -> int: return batch_size * n def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the polynomial.""" assert len(batch_sizes) == self.num_operands return sum(Z * size for Z, size in zip(batch_sizes, self.buffer_sizes)) @@ -372,6 +408,8 @@ def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: return segments def sort_indices_for_identical_operands(self) -> SegmentedPolynomial: + """Sort the indices of the segmented tensor products for identical operands.""" + def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct): for set_of_operands in ope.operands_with_identical_buffers(): stp = stp.sort_indices_for_identical_operands(set_of_operands) diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py index 8749abe7..3de6fadd 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py @@ -28,14 +28,20 @@ def spherical_harmonics( ) -> cuex.RepArray: """Compute the spherical harmonics of a vector. + The spherical harmonics are polynomials of an input vector. 
This function computes the polynomials of the specified degrees.
+
     Args:
-        ls (list of int): List of spherical harmonic degrees.
-        vector (RepArray): Input vector(s).
-        normalize (bool): Whether to normalize the vector before computing the spherical harmonics.
-        algorithm (str): Algorithm to use for the tensor product. See :class:`cuex.tensor_product <cuequivariance_jax.tensor_product>` for more information.
+        ls (list[int]): List of spherical harmonic degrees. Each degree must be non-negative.
+        vector (RepArray): Input vector. Must be a single vector (multiplicity 1) with 3 components.
+        normalize (bool, optional): Whether to normalize the vector before computing the spherical harmonics. Defaults to True.
 
     Returns:
-        RepArray: Spherical harmonics of the vector.
+        RepArray: The spherical harmonics of the vector, containing the polynomials of each specified degree.
+
+    Example:
+        >>> import jax
+        >>> import cuequivariance as cue
+        >>> import cuequivariance_jax as cuex
+        >>> rep = cue.IrrepsAndLayout(cue.Irreps("SO3", "1"), cue.ir_mul)
+        >>> vector = cuex.randn(jax.random.key(0), rep)
+        >>> harmonics = spherical_harmonics([0, 1, 2], vector)
     """
     ls = list(ls)
     assert vector.is_irreps_array()
diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py
index 4eba9e71..4622cacd 100644
--- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py
+++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py
@@ -201,10 +201,7 @@ def ein(
 
 def prepare():
     if not d.all_same_segment_shape():
-        raise ValueError(
-            "cuex.tensor_product: all operands must have the same segment shape\n"
-            + str(d)
-        )
+        raise ValueError("all operands must have the same segment shape\n" + str(d))
     reshaped_inputs = [
         jnp.reshape(
             input, input.shape[:-1] + (ope.num_segments,) + ope.segment_shape
@@ -216,7 +213,7 @@ def prepare():
     return reshaped_inputs, indices, coefficients
 
     if algorithm == "stacked":
-        logger.debug(f"cuex.tensor_product: {d} with stacked strategy")
+        logger.debug(f"{d} with stacked strategy")
 
         reshaped_inputs, indices, coefficients = prepare()
         return [
@@ -234,7 +231,7 @@ def prepare():
         ]
 
     elif algorithm == "compact_stacked":
-        logger.debug(f"cuex.tensor_product: {d} with compact_stacked strategy")
+        logger.debug(f"{d} with compact_stacked strategy")
 
         reshaped_inputs, indices, coefficients = prepare()
         return [
@@ -256,7 +253,7 @@ def prepare():
         ]
 
     elif algorithm == "indexed_vmap":
-        logger.debug(f"cuex.tensor_product: {d} with indexed_vmap strategy")
+        logger.debug(f"{d} with indexed_vmap strategy")
 
         reshaped_inputs, indices, coefficients = prepare()
         return (
@@ -279,7 +276,7 @@ def prepare():
         )
 
     elif algorithm == "indexed_compact":
-        logger.debug(f"cuex.tensor_product: {d} with indexed_compact strategy")
+        logger.debug(f"{d} with indexed_compact strategy")
 
         reshaped_inputs, indices, coefficients = prepare()
         return (
@@ -303,7 +300,7 @@ def prepare():
         )
 
     elif algorithm == "indexed_for_loop":
-        logger.debug(f"cuex.tensor_product: {d} with indexed_for_loop strategy")
+        logger.debug(f"{d} with indexed_for_loop strategy")
 
         reshaped_inputs, indices, coefficients = prepare()
 
         def body(pid: int, output: jax.Array) -> jax.Array:
@@ -330,7 +327,7 @@ def body(pid: int, output: jax.Array) -> 
jax.Array: ] elif algorithm == "no-op": - warnings.warn(f"cuex.tensor_product: {d} skipping computation!!!") + warnings.warn(f"{d} skipping computation!!!") dummy = sum([jnp.sum(input) for input in inputs]) @@ -369,4 +366,4 @@ def body(pid: int, output: jax.Array) -> jax.Array: for pid_start, pid_end in zip(pids[:-1], pids[1:]) ] - raise NotImplementedError(f"cuex.tensor_product: unknown algorithm {algorithm}") + raise NotImplementedError(f"unknown algorithm {algorithm}") diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 0ce3c481..51e50bc7 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -160,17 +160,17 @@ def __init__( # TODO: remove this when re-design if isinstance(e, cue.EquivariantPolynomial): assert e.num_outputs == 1 - for ope, stp in e.tensor_products: + for ope, stp in e.polynomial.tensor_products: inputs = list(range(e.num_inputs)) output = e.num_inputs - expected = ( + expected = tuple( inputs[: stp.num_operands - 1] + [inputs[-1]] * max(0, stp.num_operands - e.num_operands) + [output] ) - assert ope.buffers == expected + assert ope.buffers == expected, f"{ope.buffers} != {expected}" e = cue.EquivariantTensorProduct( - [stp for _, stp in e.tensor_products], e.inputs + e.outputs + [stp for _, stp in e.polynomial.tensor_products], e.inputs + e.outputs ) if not isinstance(layout_in, tuple): diff --git a/docs/api/cuequivariance.rst b/docs/api/cuequivariance.rst index e374962c..7d6f1af9 100644 --- a/docs/api/cuequivariance.rst +++ b/docs/api/cuequivariance.rst @@ -41,7 +41,7 @@ Group Representations Equivariant Tensor Products --------------------------- -These classes represent tensor products. +These classes represent tensor products and polynomials. .. autosummary:: :toctree: generated/ @@ -51,8 +51,9 @@ These classes represent tensor products. IrrepsLayout IrrepsAndLayout SegmentedTensorProduct - EquivariantTensorProduct Operation + SegmentedPolynomial + EquivariantPolynomial Descriptors ----------- diff --git a/docs/api/cuequivariance_jax.rst b/docs/api/cuequivariance_jax.rst index 84db94a3..590bb47a 100644 --- a/docs/api/cuequivariance_jax.rst +++ b/docs/api/cuequivariance_jax.rst @@ -44,8 +44,8 @@ Tensor Products :toctree: generated/ :template: function_template.rst - equivariant_tensor_product - tensor_product + equivariant_polynomial + segmented_polynomial Extra Modules ------------- diff --git a/docs/index.rst b/docs/index.rst index ad7baab1..39ef2d0f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,13 +40,15 @@ The easiest way to install cuEquivariance is from `PyPi `_ us pip install cuequivariance # Installs only the core non-ML components # CUDA kernels for different CUDA versions + pip install cuequivariance-ops-jax-cu12 pip install cuequivariance-ops-torch-cu11 pip install cuequivariance-ops-torch-cu12 Requirements ------------ -``cuequivariance-ops-torch-*`` packages are only available for Linux x86_64 and require PyTorch 2.4.0 or later. + - ``cuequivariance-ops-torch-*`` packages are only available for Linux x86_64 and require PyTorch 2.4.0 or later. + - ``cuequivariance-ops-jax-cu12`` package is only available for Linux x86_64 and requires JAX 0.5.0 or later. 
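+
+ For example, a minimal JAX-enabled setup on a CUDA 12 machine (assuming the JAX frontend package is named ``cuequivariance-jax``, matching the repository layout) would be::
+
+    pip install cuequivariance-jax cuequivariance-ops-jax-cu12
+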
Organization ------------ @@ -70,7 +72,7 @@ cuEquivariance is split into three packages: :align: center Most tensor products are defined using the :class:`cue.EquivariantTensorProduct ` class, which encapsulates the :class:`cue.Irreps ` and :class:`cue.IrrepsLayout ` for each input and the output. It also includes one or more instances of :class:`cue.SegmentedTensorProduct `, which define the tensor product operations. -This descriptor is then used to create a :class:`cuet.EquivariantTensorProduct ` module, which can be used in PyTorch models. Or used to execute the tensor product operations using :class:`cuex.equivariant_tensor_product ` in JAX. +This descriptor is then used to create a :class:`cuet.EquivariantTensorProduct ` module, which can be used in PyTorch models. Or used to execute the tensor product operations using :class:`cuex.equivariant_polynomial ` in JAX. Tutorials --------- diff --git a/docs/tutorials/beta.rst b/docs/tutorials/beta.rst index 3f6722d6..3e473523 100644 --- a/docs/tutorials/beta.rst +++ b/docs/tutorials/beta.rst @@ -42,7 +42,7 @@ The segmented tensor product with 3 or 4 operands with one mode can be executed .squeeze_modes() .flatten_coefficient_modes() ) - print(e.ds[0]) + print(e) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") m = cuet.EquivariantTensorProduct(e, layout=cue.ir_mul, device=device) @@ -63,7 +63,8 @@ Again for segmented tensor product with 3 or 4 operands with one mode, we can us ) if device.type == "cuda": - m = TensorProductUniform4x1dIndexed(e.ds[0], device, torch.float32) + d = e.polynomial.tensor_products[0][1] + m = TensorProductUniform4x1dIndexed(d, device, torch.float32) x0 = torch.randn(16, e.inputs[0].dim, device=device) i0 = torch.randint(0, 16, (128,), device=device) diff --git a/docs/tutorials/etp.rst b/docs/tutorials/etp.rst index 08c15efb..0ef852a4 100644 --- a/docs/tutorials/etp.rst +++ b/docs/tutorials/etp.rst @@ -69,10 +69,10 @@ Execution on JAX w = cuex.randn(jax.random.key(0), e.inputs[0]) x = cuex.randn(jax.random.key(1), e.inputs[1]) - cuex.equivariant_tensor_product(e, w, x) + cuex.equivariant_polynomial(e, [w, x]) The function :func:`cuex.randn ` generates random :class:`cuex.RepArray ` objects. -The function :func:`cuex.equivariant_tensor_product ` executes the tensor product. +The function :func:`cuex.equivariant_polynomial ` executes the tensor product. The output is a :class:`cuex.RepArray ` object. diff --git a/docs/tutorials/stp.rst b/docs/tutorials/stp.rst index 65d63df4..28e35a22 100644 --- a/docs/tutorials/stp.rst +++ b/docs/tutorials/stp.rst @@ -175,15 +175,15 @@ Now we are verifying that the output is well normalized. -In JAX, we can use the :func:`cuex.tensor_product ` function. +In JAX, we can use the :func:`cuex.segmented_polynomial ` function. .. 
jupyter-execute:: w = jax.random.normal(jax.random.key(0), (d.operands[0].size,)) x1 = jax.random.normal(jax.random.key(1), (3000, irreps1.dim)) - [x2] = cuex.tensor_product( - [(cue.Operation([0, 1, 2]), d)], + [x2] = cuex.segmented_polynomial( + cue.SegmentedPolynomial.eval_last_operand(d), [w, x1], [jax.ShapeDtypeStruct((3000, irreps2.dim), jnp.float32)], ) From d06d0d130e90c618f57e11af10468f69125c7a11 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 03:50:44 -0800 Subject: [PATCH 019/107] Add comprehensive tests for SegmentedPolynomial --- .../cuequivariance/segmented_polynomial.py | 2 +- .../tests/segmented_polynomial_test.py | 183 ++++++++++++++++++ 2 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 cuequivariance/tests/segmented_polynomial_test.py diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index b82fca20..25d1896d 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -312,7 +312,7 @@ def stack( for p in polys[index + 1 :]: stp.insert_segments(oid, -1, p.buffer_segments(buffer)) tensor_products.append((ope, stp)) - return cls(num_inputs, num_outputs, tensor_products) + return cls(num_inputs, num_outputs, tensor_products).consolidate() def squeeze_modes(self) -> SegmentedPolynomial: """Squeeze the modes of the segmented tensor products.""" diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py new file mode 100644 index 00000000..4360f7d9 --- /dev/null +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -0,0 +1,183 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
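+"""Tests for cue.SegmentedPolynomial.
+
+The tests below construct small SegmentedTensorProduct instances by hand and
+exercise the numpy reference implementation.  A minimal sketch of the
+recurring pattern (a bilinear dot product wrapped in a polynomial):
+
+    stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk")
+    i0 = stp.add_segment(0, (3,))
+    i1 = stp.add_segment(1, (3,))
+    i2 = stp.add_segment(2, (1,))
+    stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1))
+    poly = cue.SegmentedPolynomial(2, 1, [(cue.Operation((0, 1, 2)), stp)])
+    [z] = poly(np.ones(3), np.ones(3))  # z == [3.0]
+"""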
+import numpy as np + +import cuequivariance as cue + + +def make_simple_stp() -> cue.SegmentedTensorProduct: + d = cue.SegmentedTensorProduct.empty_segments([2, 2, 2]) + d.add_path(0, 0, 0, c=1.0) + d.add_path(1, 1, 1, c=-2.0) + return d + + +def test_init_segmented_polynomial(): + """Test initialization of SegmentedPolynomial.""" + stp = make_simple_stp() + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 + assert poly.num_operands == 3 + assert len(poly.tensor_products) == 1 + assert poly.tensor_products[0] == (op, stp) + + +def test_polynomial_equality(): + """Test equality comparison of polynomials.""" + stp1 = make_simple_stp() + stp2 = make_simple_stp() + op1 = cue.Operation((0, 1, 2)) + op2 = cue.Operation((0, 1, 2)) + + poly1 = cue.SegmentedPolynomial(2, 1, [(op1, stp1)]) + poly2 = cue.SegmentedPolynomial(2, 1, [(op2, stp2)]) + poly3 = cue.SegmentedPolynomial(2, 1, [(op2, 2 * stp2)]) + + assert poly1 == poly2 + assert poly1 != poly3 + assert poly1 < poly3 # Test less than operator + + +def test_call_function(): + """Test calling the polynomial as a function.""" + # Create a simple bilinear form: f(a, b) = a^T * b + # For this specific test, we need a particular structure + stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = stp.add_segment(0, (3,)) + i1 = stp.add_segment(1, (3,)) + i2 = stp.add_segment(2, (1,)) + stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + # Test evaluation + a = np.array([1.0, 2.0, 3.0]) + b = np.array([4.0, 5.0, 6.0]) + + [result] = poly(a, b) + expected = np.array([a.dot(b)]) # Dot product + + assert np.allclose(result, expected) + + +def test_buffer_properties(): + """Test properties related to buffer sizes and usage.""" + stp1 = make_simple_stp() + op1 = cue.Operation((0, 1, 2)) + + # Create a second STP with different structure for testing multiple buffers + stp2 = cue.SegmentedTensorProduct.empty_segments([2, 1]) + stp2.add_path(0, 0, c=1.0) + op2 = cue.Operation((0, 3)) + + poly = cue.SegmentedPolynomial(2, 2, [(op1, stp1), (op2, stp2)]) + + # Test buffer properties + assert poly.buffer_sizes == [2, 2, 2, 1] + assert poly.input_sizes == [2, 2] + assert poly.output_sizes == [2, 1] + + assert poly.used_buffers() == [0, 1, 2, 3] + assert poly.buffer_used() == [True, True, True, True] + + +def test_remove_unused_buffers(): + """Test removing unused buffers from the polynomial.""" + stp = make_simple_stp() + # Use operation that doesn't use buffer 1 + op = cue.Operation((0, 2, 3)) # Note: buffer 1 is not used + + poly = cue.SegmentedPolynomial(3, 1, [(op, stp)]) + + # Buffer 1 is not used + assert poly.buffer_used() == [True, False, True, True] + + # Remove unused buffer + cleaned_poly = poly.remove_unused_buffers() + + assert cleaned_poly.num_inputs == 2 + assert cleaned_poly.num_outputs == 1 + assert cleaned_poly.buffer_used() == [True, True, True] + + +def test_consolidate(): + """Test consolidating tensor products.""" + stp1 = make_simple_stp() + stp2 = make_simple_stp() + + op = cue.Operation((0, 1, 2)) + + # Create a polynomial with duplicate operations + poly = cue.SegmentedPolynomial(2, 1, [(op, stp1), (op, stp2)]) + + # Consolidate the polynomial + consolidated = poly.consolidate() + + # Should have fused the two tensor products + assert len(consolidated.tensor_products) == 1 + # Coefficients should have been combined for each path + assert 
len(consolidated.tensor_products[0][1].paths) == 2 + # The coefficients should have been added + assert consolidated.tensor_products[0][1].paths[0].coefficients == 2.0 + assert consolidated.tensor_products[0][1].paths[1].coefficients == -4.0 + + +def test_stack(): + """Test stacking polynomials.""" + # Create two simple polynomials using make_simple_stp + stp = make_simple_stp() + op1 = cue.Operation((0, 1, 2)) + poly1 = cue.SegmentedPolynomial(2, 1, [(op1, stp)]) + + stp2 = make_simple_stp() + op2 = cue.Operation((0, 1, 2)) + poly2 = cue.SegmentedPolynomial(2, 1, [(op2, stp2)]) + + # Stack the polynomials with the output being stacked + stacked = cue.SegmentedPolynomial.stack([poly1, poly2], [False, False, True]) + + assert stacked.num_inputs == 2 + assert stacked.num_outputs == 1 + + assert stacked.buffer_sizes == [2, 2, 4] + + [(_, stp)] = stacked.tensor_products + assert stp.operands[0].num_segments == 2 + assert stp.operands[1].num_segments == 2 + assert stp.operands[2].num_segments == 4 + assert stp.num_paths == 4 + assert stp.paths[0].indices == (0, 0, 0) + assert stp.paths[1].indices == (0, 0, 2 + 0) + assert stp.paths[2].indices == (1, 1, 1) + assert stp.paths[3].indices == (1, 1, 2 + 1) + + +def test_flops_and_memory(): + """Test computation of FLOPS and memory usage.""" + stp = make_simple_stp() + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + # Test FLOPS calculation + flops = poly.flops(batch_size=100) + assert flops > 0 + + # Test memory calculation + memory = poly.memory([100, 100, 100]) + assert memory == 100 * (2 + 2 + 2) # All operands have size 2 From 5da9160015505b6e7b4743daad97a6ae9646d5c7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 03:55:08 -0800 Subject: [PATCH 020/107] Add JVP (Jacobian-vector product) tests for SegmentedPolynomial --- .../tests/segmented_polynomial_test.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py index 4360f7d9..ae67f4e4 100644 --- a/cuequivariance/tests/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -181,3 +181,50 @@ def test_flops_and_memory(): # Test memory calculation memory = poly.memory([100, 100, 100]) assert memory == 100 * (2 + 2 + 2) # All operands have size 2 + + +def test_jvp(): + """Test Jacobian-vector product computation.""" + # Create a simple polynomial for testing: f(x,y) = x^T * y (dot product) + stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = stp.add_segment(0, (3,)) + i1 = stp.add_segment(1, (3,)) + i2 = stp.add_segment(2, (1,)) + stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # Tangent vectors (directions for differentiation) + x_tangent = np.array([0.1, 0.2, 0.3]) + y_tangent = np.array([0.4, 0.5, 0.6]) + + # Create the JVP polynomial for both inputs having tangents + jvp_poly = poly.jvp([True, True]) + + # When both inputs have tangents, we need to concatenate inputs and tangents + # The JVP polynomial expects inputs followed by their respective tangents + jvp_result = jvp_poly(x, y, x_tangent, y_tangent) + + # For the dot product function f(x,y) = x^T * y: + # The Jacobian w.r.t x is y^T, and the Jacobian w.r.t y is x^T + # So Jvp = y^T * x_tangent + x^T * y_tangent + expected_jvp = 
np.array([y.dot(x_tangent) + x.dot(y_tangent)]) + + assert np.allclose(jvp_result[0], expected_jvp) + + # Test with only x having a tangent + jvp_x_only = poly.jvp([True, False]) + x_only_result = jvp_x_only(x, y, x_tangent) + expected_x_only = np.array([y.dot(x_tangent)]) + assert np.allclose(x_only_result[0], expected_x_only) + + # Test with only y having a tangent + jvp_y_only = poly.jvp([False, True]) + y_only_result = jvp_y_only(x, y, y_tangent) + expected_y_only = np.array([x.dot(y_tangent)]) + assert np.allclose(y_only_result[0], expected_y_only) From 05b9c251e34c8181d82318f75731e9cb6d990267 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:05:00 -0800 Subject: [PATCH 021/107] add test --- .../tests/segmented_polynomial_test.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py index ae67f4e4..32244de4 100644 --- a/cuequivariance/tests/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -228,3 +228,66 @@ def test_jvp(): y_only_result = jvp_y_only(x, y, y_tangent) expected_y_only = np.array([x.dot(y_tangent)]) assert np.allclose(y_only_result[0], expected_y_only) + + +def test_transpose_linear(): + """Test transposing a linear polynomial.""" + # Create a linear polynomial f(x, y) = Ax where A is a matrix + # Here we use f(x, y) = x^T * y (dot product) + # This is linear in both x and y + stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = stp.add_segment(0, (3,)) + i1 = stp.add_segment(1, (3,)) + i2 = stp.add_segment(2, (1,)) + stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # x dot y = 1*4 + 2*5 + 3*6 = 32 + + # Cotangent for the output + cotangent = np.array([2.0]) + + # Test transposing with respect to x (x is undefined primal) + # is_undefined_primal = [True, False] means x is undefined, y is defined + # has_cotangent = [True] means the output has a cotangent + transpose_x = poly.transpose( + is_undefined_primal=[True, False], has_cotangent=[True] + ) + + # The transpose polynomial should compute the gradient of the output w.r.t x + # For f(x, y) = x^T * y, the gradient w.r.t x is y + # So transpose_x(y, cotangent) should be y * cotangent + x_result = transpose_x(y, cotangent) + expected_x_result = y * cotangent[0] + assert np.allclose(x_result[0], expected_x_result) + + # Test transposing with respect to y (y is undefined primal) + transpose_y = poly.transpose( + is_undefined_primal=[False, True], has_cotangent=[True] + ) + + # For f(x, y) = x^T * y, the gradient w.r.t y is x + # So transpose_y(x, cotangent) should be x * cotangent + y_result = transpose_y(x, cotangent) + expected_y_result = x * cotangent[0] + assert np.allclose(y_result[0], expected_y_result) + + +def test_transpose_nonlinear(): + """Test transposing a non-linear polynomial raises an error.""" + # Create a non-linear polynomial + stp = make_simple_stp() + op = cue.Operation((0, 0, 1)) # Note: using the same buffer twice (x^2) + poly = cue.SegmentedPolynomial(1, 1, [(op, stp)]) + + # Try to transpose the non-linear polynomial + # This should raise a ValueError since there are multiple undefined primals + # (the same input buffer is used twice) + with np.testing.assert_raises(ValueError): + poly.transpose(is_undefined_primal=[True], 
has_cotangent=[True]) From 2bc988ac39d06dbec53c65389e8774c4e85dd900 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:09:08 -0800 Subject: [PATCH 022/107] add test for backward --- .../tests/segmented_polynomial_test.py | 72 ++++++++++++++++--- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py index 32244de4..c5e60f02 100644 --- a/cuequivariance/tests/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -24,6 +24,15 @@ def make_simple_stp() -> cue.SegmentedTensorProduct: return d +def make_simple_dot_product_stp() -> cue.SegmentedTensorProduct: + d = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") + i0 = d.add_segment(0, (3,)) + i1 = d.add_segment(1, (3,)) + i2 = d.add_segment(2, (1,)) + d.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + return d + + def test_init_segmented_polynomial(): """Test initialization of SegmentedPolynomial.""" stp = make_simple_stp() @@ -186,11 +195,7 @@ def test_flops_and_memory(): def test_jvp(): """Test Jacobian-vector product computation.""" # Create a simple polynomial for testing: f(x,y) = x^T * y (dot product) - stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") - i0 = stp.add_segment(0, (3,)) - i1 = stp.add_segment(1, (3,)) - i2 = stp.add_segment(2, (1,)) - stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + stp = make_simple_dot_product_stp() op = cue.Operation((0, 1, 2)) poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) @@ -235,11 +240,7 @@ def test_transpose_linear(): # Create a linear polynomial f(x, y) = Ax where A is a matrix # Here we use f(x, y) = x^T * y (dot product) # This is linear in both x and y - stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk") - i0 = stp.add_segment(0, (3,)) - i1 = stp.add_segment(1, (3,)) - i2 = stp.add_segment(2, (1,)) - stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) + stp = make_simple_dot_product_stp() op = cue.Operation((0, 1, 2)) poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) @@ -291,3 +292,54 @@ def test_transpose_nonlinear(): # (the same input buffer is used twice) with np.testing.assert_raises(ValueError): poly.transpose(is_undefined_primal=[True], has_cotangent=[True]) + + +def test_backward(): + """Test the backward method for gradient computation.""" + # Create a linear polynomial for testing: f(x,y) = x^T * y (dot product) + stp = make_simple_dot_product_stp() + op = cue.Operation((0, 1, 2)) + poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + + # Input values + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + + # Cotangent for the output (upstream gradient) + cotangent = np.array([2.0]) + + # Test backward with respect to both x and y + backward_both = poly.backward(requires_gradient=[True, True], has_cotangent=[True]) + + # The backward polynomial computes gradients for all inputs that require gradients + # For f(x,y) = x^T * y: + # - gradient w.r.t x is y * cotangent + # - gradient w.r.t y is x * cotangent + grad_x, grad_y = backward_both(x, y, cotangent) + expected_grad_x = y * cotangent[0] + expected_grad_y = x * cotangent[0] + + assert np.allclose(grad_x, expected_grad_x) + assert np.allclose(grad_y, expected_grad_y) + + # Test backward with respect to only x + backward_x = poly.backward(requires_gradient=[True, False], has_cotangent=[True]) + + # Should only compute gradient for x + [grad_x_only] = backward_x(x, y, cotangent) + assert np.allclose(grad_x_only, 
expected_grad_x) + + # Test backward with respect to only y + backward_y = poly.backward(requires_gradient=[False, True], has_cotangent=[True]) + + # Should only compute gradient for y + [grad_y_only] = backward_y(x, y, cotangent) + assert np.allclose(grad_y_only, expected_grad_y) + + # Test with zero cotangent + zero_cotangent = np.array([0.0]) + grad_x_zero, grad_y_zero = backward_both(x, y, zero_cotangent) + + # With zero cotangent, gradients should be zero + assert np.allclose(grad_x_zero, np.zeros_like(x)) + assert np.allclose(grad_y_zero, np.zeros_like(y)) From 7b547c9b538fba39a17a46e5f05eacf8f9615378 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:19:52 -0800 Subject: [PATCH 023/107] docstrings --- .../cuequivariance/equivariant_polynomial.py | 124 ++++++++++++++++++ .../cuequivariance/segmented_polynomial.py | 5 - 2 files changed, 124 insertions(+), 5 deletions(-) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index 7ab330e1..390eff7b 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -23,6 +23,18 @@ @dataclasses.dataclass(init=False, frozen=True) class EquivariantPolynomial: + """A polynomial representation with equivariance constraints. + + This class extends SegmentedPolynomial by incorporating information about the group representations + of each input and output tensor. It ensures that operations performed by the polynomial respect + the equivariance constraints defined by these representations, making it suitable for building + equivariant neural networks. + + Args: + operands (list[cue.Rep]): Group representations for all operands (inputs and outputs). + polynomial (cue.SegmentedPolynomial): The underlying polynomial transformation. + """ + operands: tuple[cue.Rep, ...] polynomial: cue.SegmentedPolynomial @@ -72,38 +84,69 @@ def __repr__(self): return self.polynomial.to_string([f"{rep}" for rep in self.operands]) def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: + """Evaluate the polynomial on the given inputs. + + Args: + *inputs (np.ndarray): Input tensors to evaluate the polynomial on. + + Returns: + list[np.ndarray]: Output tensors resulting from the polynomial evaluation. + """ return self.polynomial(*inputs) @property def num_operands(self) -> int: + """The total number of operands (inputs and outputs) in the polynomial.""" return len(self.operands) @property def num_inputs(self) -> int: + """The number of input tensors expected by the polynomial.""" return self.polynomial.num_inputs @property def num_outputs(self) -> int: + """The number of output tensors produced by the polynomial.""" return self.polynomial.num_outputs @property def inputs(self) -> tuple[cue.Rep, ...]: + """The group representations of the input tensors.""" return self.operands[: self.num_inputs] @property def outputs(self) -> tuple[cue.Rep, ...]: + """The group representations of the output tensors.""" return self.operands[self.num_inputs :] def fuse_stps(self) -> EquivariantPolynomial: + """Fuse segmented tensor products with identical operations and operands. + + Returns: + EquivariantPolynomial: A new polynomial with fused tensor products. + """ return EquivariantPolynomial( self.operands, self.polynomial.fuse_stps(), ) def buffer_used(self) -> list[bool]: + """Check which buffers are used in the polynomial. + + Returns: + list[bool]: Boolean flags indicating which buffers are used. 
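To make the buffer bookkeeping concrete, here is a minimal sketch of how `buffer_used` and `remove_unused_buffers` behave on the underlying SegmentedPolynomial. This is not part of the patch; it reuses the `empty_segments` construction that the test suite uses, and mirrors the unused-buffer test shown later in this series:

    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1])
    stp.add_path(0, 1, 0, c=1.0)

    # Three inputs, one output; input buffer 1 is never referenced.
    op = cue.Operation((0, 2, 3))
    poly = cue.SegmentedPolynomial(3, 1, [(op, stp)])

    assert poly.buffer_used() == [True, False, True, True]
    assert poly.remove_unused_buffers().num_inputs == 2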
+ """ return self.polynomial.buffer_used() def remove_unused_buffers(self) -> EquivariantPolynomial: + """Remove unused buffers from the polynomial. + + This method creates a new polynomial with unused buffers removed, + which can make operations more efficient. + + Returns: + EquivariantPolynomial: A new polynomial with unused buffers removed. + """ return EquivariantPolynomial( [rep for u, rep in zip(self.buffer_used(), self.operands) if u], self.polynomial.remove_unused_buffers(), @@ -113,6 +156,22 @@ def remove_unused_buffers(self) -> EquivariantPolynomial: def stack( cls, polys: list[EquivariantPolynomial], stacked: list[bool] ) -> EquivariantPolynomial: + """Stack multiple equivariant polynomials together. + + This method combines multiple polynomials by stacking their operands according to the + stacked parameter. Operands with the same index that are not stacked must be identical + across all polynomials. + + Args: + polys (list[EquivariantPolynomial]): List of polynomials to stack. + stacked (list[bool]): Boolean flags indicating which operands should be stacked. + + Returns: + EquivariantPolynomial: A new polynomial combining the stacked polynomials. + + Raises: + ValueError: If operands that are not stacked differ across polynomials. + """ assert len(polys) > 0 num_operands = polys[0].num_operands @@ -145,18 +204,39 @@ def stack( return poly.fuse_stps() def squeeze_modes(self) -> EquivariantPolynomial: + """Squeeze the modes of the segmented tensor products. + + Returns: + EquivariantPolynomial: A new polynomial with squeezed modes. + """ return EquivariantPolynomial( self.operands, self.polynomial.squeeze_modes(), ) def flatten_coefficient_modes(self) -> EquivariantPolynomial: + """Flatten the coefficient modes of the segmented tensor products. + + Returns: + EquivariantPolynomial: A new polynomial with flattened coefficient modes. + """ return EquivariantPolynomial( self.operands, self.polynomial.flatten_coefficient_modes(), ) def jvp(self, has_tangent: list[bool]) -> EquivariantPolynomial: + """Compute the Jacobian-vector product of the polynomial. + + This method creates a new polynomial that, when evaluated, computes the Jacobian-vector + product of the original polynomial. This is used for forward-mode automatic differentiation. + + Args: + has_tangent (list[bool]): Boolean flags indicating which inputs have tangent vectors. + + Returns: + EquivariantPolynomial: A new polynomial representing the JVP operation. + """ return EquivariantPolynomial( list(self.inputs) + [x for has, x in zip(has_tangent, self.inputs) if has] @@ -169,6 +249,21 @@ def transpose( is_undefined_primal: list[bool], has_cotangent: list[bool], ) -> EquivariantPolynomial: + """Transpose the polynomial operation. + + This method creates a new polynomial that represents the transpose of the original operation. + The transpose is essential for reverse-mode automatic differentiation. + + Args: + is_undefined_primal (list[bool]): Boolean flags indicating which inputs are undefined primals. + has_cotangent (list[bool]): Boolean flags indicating which outputs have cotangents. + + Returns: + EquivariantPolynomial: A new polynomial representing the transposed operation. + + Raises: + ValueError: If the polynomial is non-linear and cannot be transposed. 
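For the linear case, the transpose is exactly the map that reverse-mode AD needs. A minimal sketch on the underlying SegmentedPolynomial, mirroring the dot-product test added in PATCH 021 (illustration only, not patch content):

    import numpy as np
    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.from_subscripts("i,j,k+ijk")
    i0 = stp.add_segment(0, (3,))
    i1 = stp.add_segment(1, (3,))
    i2 = stp.add_segment(2, (1,))
    stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1))

    # f(x, y) = x . y is linear in each input
    poly = cue.SegmentedPolynomial(2, 1, [(cue.Operation((0, 1, 2)), stp)])

    y = np.array([4.0, 5.0, 6.0])
    cotangent = np.array([2.0])
    t = poly.transpose(is_undefined_primal=[True, False], has_cotangent=[True])
    [grad_x] = t(y, cotangent)  # equals y * cotangent, since df/dx = y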
+ """ return EquivariantPolynomial( # defined inputs [ @@ -190,6 +285,19 @@ def transpose( def backward( self, requires_gradient: list[bool], has_cotangent: list[bool] ) -> EquivariantPolynomial: + """Compute the backward pass of the polynomial for gradient computation. + + This method combines the JVP and transpose operations to create a new polynomial that, + when evaluated, computes gradients of outputs with respect to inputs. This is the + core operation in reverse-mode automatic differentiation. + + Args: + requires_gradient (list[bool]): Boolean flags indicating which inputs require gradients. + has_cotangent (list[bool]): Boolean flags indicating which outputs have cotangents. + + Returns: + EquivariantPolynomial: A new polynomial for gradient computation. + """ return EquivariantPolynomial( list(self.inputs) + [x for has, x in zip(has_cotangent, self.outputs) if has] @@ -198,8 +306,24 @@ def backward( ) def flops(self, batch_size: int = 1) -> int: + """Compute the number of floating point operations in the polynomial. + + Args: + batch_size (int, optional): The batch size for the computation. Defaults to 1. + + Returns: + int: The estimated number of floating-point operations. + """ return self.polynomial.flops(batch_size) def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the polynomial. + + Args: + batch_sizes (list[int]): The batch sizes for each operand. + + Returns: + int: The estimated memory usage in number of scalar elements. + """ assert len(batch_sizes) == len(self.operands) return sum(Z * rep.dim for Z, rep in zip(batch_sizes, self.operands)) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 25d1896d..2d836190 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -38,11 +38,6 @@ class SegmentedPolynomial: num_outputs (int): Number of output tensors. tensor_products (list of tuple of Operation and SegmentedTensorProduct): List of operation and tensor product pairs that define the polynomial transformation. 
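For reference, a minimal construction-and-evaluation sketch of a SegmentedPolynomial (illustration only, not part of the docstring; `__call__` is the numpy reference implementation exercised by the tests):

    import numpy as np
    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1])
    stp.add_path(0, 1, 0, c=1.0)  # out[0] += 1.0 * x[0] * y[1]
    poly = cue.SegmentedPolynomial(2, 1, [(cue.Operation((0, 1, 2)), stp)])

    [out] = poly(np.array([1.0, 2.0]), np.array([3.0, 4.0]))
    assert np.allclose(out, [4.0])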
- - Example: - >>> # Create a polynomial with 2 inputs and 1 output - >>> poly = SegmentedPolynomial(2, 1, [(op1, stp1), (op2, stp2)]) - >>> outputs = poly(input1, input2) # Evaluate polynomial on inputs (numpy reference implementation) """ num_inputs: int From 5e7b22ed5982d86c1d5a2d9042e0bc5899f9c98b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:22:45 -0800 Subject: [PATCH 024/107] fix --- .../cuequivariance_jax/operations/spherical_harmonics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py index 3de6fadd..75686f23 100644 --- a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py @@ -40,7 +40,7 @@ def spherical_harmonics( Example: >>> import cuequivariance_jax as cuex - >>> vector = cuex.randn(jax.random.key(0), cue.Irreps(cue.SO3, "1o")) + >>> vector = cuex.randn(jax.random.key(0), cue.IrrepsAndLayout(cue.Irreps(cue.O3, "1o"), cue.mul_ir)) >>> harmonics = spherical_harmonics([0, 1, 2], vector) """ ls = list(ls) From 4fcb8ea4f9877e324b3ae30ea5de0b1af1fb1cdc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:27:23 -0800 Subject: [PATCH 025/107] quick and dirty fix --- .../operations/tp_channel_wise.py | 5 ++++- .../operations/tp_fully_connected.py | 4 ++-- .../tests/operations/channel_wise_test.py | 4 +++- .../primitives/symmetric_tensor_product_test.py | 5 +++-- .../tests/primitives/tensor_product_test.py | 14 +++++++++----- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index 9bb213f5..cf672f07 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -68,7 +68,10 @@ def __init__( e = descriptors.channelwise_tensor_product( irreps_in1, irreps_in2, filter_irreps_out ) - descriptor, irreps_out = e.d, e.operands[-1].irreps + descriptor, irreps_out = ( + e.polynomial.tensor_products[0][1], + e.operands[-1].irreps, + ) assert descriptor.subscripts == "uv,iu,jv,kuv+ijk" self.irreps_in1 = irreps_in1 diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 1a003f7c..4c5b4277 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -70,13 +70,13 @@ def __init__( e = descriptors.fully_connected_tensor_product( irreps_in1, irreps_in2, irreps_out ) - assert e.d.subscripts == "uvw,iu,jv,kw+ijk" + assert e.polynomial.tensor_products[0][1].subscripts == "uvw,iu,jv,kw+ijk" self.irreps_in1 = irreps_in1 self.irreps_in2 = irreps_in2 self.irreps_out = irreps_out - self.weight_numel = e.d.operands[0].size + self.weight_numel = e.polynomial.tensor_products[0][1].operands[0].size self.shared_weights = shared_weights self.internal_weights = ( diff --git a/cuequivariance_torch/tests/operations/channel_wise_test.py b/cuequivariance_torch/tests/operations/channel_wise_test.py index 18bf586d..3b150aba 100644 --- a/cuequivariance_torch/tests/operations/channel_wise_test.py +++ 
b/cuequivariance_torch/tests/operations/channel_wise_test.py @@ -69,7 +69,9 @@ def test_channel_wise_fwd( out1 = m1(x1, x2) - d = descriptors.channelwise_tensor_product(irreps1, irreps2, irreps3).d + d = descriptors.channelwise_tensor_product( + irreps1, irreps2, irreps3 + ).polynomial.tensor_products[0][1] d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 37463f85..f7805cde 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -27,9 +27,10 @@ def make_descriptors(): - yield descriptors.symmetric_contraction( + [(_, d1), (_, d2), (_, d3)] = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] - ).ds + ).polynomial.tensor_products + yield [d1, d2, d3] d1 = stp.SegmentedTensorProduct.from_subscripts(",,") d1.add_path(None, None, None, c=2.0) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 6a19072a..337d9f71 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -31,23 +31,27 @@ def make_descriptors(): cue.Irreps("O3", "4x0e + 4x1o"), cue.Irreps("O3", "6x0e + 6x1o"), cue.Irreps("O3", "5x0e + 5x1o + 5x2e + 5x1e"), - ).d + ).polynomial.tensor_products[0][1] - yield descriptors.spherical_harmonics(cue.SO3(1), [2]).d - yield descriptors.spherical_harmonics(cue.SO3(1), [3]).d + yield descriptors.spherical_harmonics(cue.SO3(1), [2]).polynomial.tensor_products[ + 0 + ][1] + yield descriptors.spherical_harmonics(cue.SO3(1), [3]).polynomial.tensor_products[ + 0 + ][1] d = descriptors.channelwise_tensor_product( cue.Irreps("SU2", "3x1/2 + 4x1"), cue.Irreps("SU2", "1/2 + 1 + 3/2"), cue.Irreps("SU2", "1/2 + 1"), - ).d + ).polynomial.tensor_products[0][1] yield d d = descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1 + 32x2"), cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), - ).d + ).polynomial.tensor_products[0][1] yield d for subscripts in [ From cd855f136b46e29047296b21427ba7664826f8d9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 04:32:38 -0800 Subject: [PATCH 026/107] fix --- .../cuequivariance_torch/operations/linear.py | 2 +- cuequivariance_torch/tests/operations/fully_connected_test.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 0191c8b5..8872f426 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -59,7 +59,7 @@ def __init__( math_dtype = math_dtype or dtype e = descriptors.linear(irreps_in, irreps_out) - assert e.d.subscripts == "uv,iu,iv" + assert e.polynomial.tensor_products[0][1].subscripts == "uv,iu,iv" self.irreps_in = irreps_in self.irreps_out = irreps_out diff --git a/cuequivariance_torch/tests/operations/fully_connected_test.py b/cuequivariance_torch/tests/operations/fully_connected_test.py index 7d62d6a2..e62f5b53 100644 --- a/cuequivariance_torch/tests/operations/fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/fully_connected_test.py @@ 
-70,7 +70,9 @@ def test_fully_connected( out1 = m1(x1, x2) - d = descriptors.fully_connected_tensor_product(irreps1, irreps2, irreps3).d + d = descriptors.fully_connected_tensor_product( + irreps1, irreps2, irreps3 + ).polynomial.tensor_products[0][1] if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) From aa71b8dfc4fc1430dad628bfd39fa3078302523a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 05:30:47 -0800 Subject: [PATCH 027/107] fix --- .../tests/primitives/equivariant_tensor_product_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py index 99dec626..71f2b34d 100644 --- a/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/equivariant_tensor_product_test.py @@ -227,4 +227,4 @@ def test_high_degrees(use_fallback: bool, batch_size: int): for rep in e.inputs ] output = m(*inputs) - assert output.shape == (batch_size, e.output.dim) + assert output.shape == (batch_size, e.outputs[0].dim) From bb718f036d67834877abafef8fe62cb54ba568a7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:04:44 -0800 Subject: [PATCH 028/107] fix --- .../tests/primitives/primitive_export_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/cuequivariance_torch/tests/primitives/primitive_export_test.py b/cuequivariance_torch/tests/primitives/primitive_export_test.py index f188434f..b4fe2c2c 100644 --- a/cuequivariance_torch/tests/primitives/primitive_export_test.py +++ b/cuequivariance_torch/tests/primitives/primitive_export_test.py @@ -39,9 +39,10 @@ def test_script_symmetric_contraction(mode, tmp_path): if not torch.cuda.is_available(): pytest.skip("CUDA is not available") - ds = cue.descriptors.symmetric_contraction( + e = cue.descriptors.symmetric_contraction( 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] - ).ds + ) + ds = [stp for _, stp in e.polynomial.tensor_products] batch = 12 x0 = torch.randn(3, ds[0].operands[0].size, device=device, dtype=torch.float32) @@ -65,7 +66,8 @@ def test_script_fused_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .d.flatten_coefficient_modes() + .polynomial.tensor_products[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) @@ -97,7 +99,8 @@ def test_script_fused_tp_4(mode, tmp_path): cue.descriptors.fully_connected_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .d.flatten_coefficient_modes() + .polynomial.tensor_products[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") .permute_operands([1, 2, 0, 3]) ) @@ -133,7 +136,8 @@ def test_script_uniform_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .d.flatten_coefficient_modes() + .polynomial.tensor_products[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) From b9eb4245112fbf0359f71b1bdc8f7294531f8f22 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:16:32 -0800 Subject: [PATCH 029/107] fix --- cuequivariance_torch/tests/primitives/primitive_export_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/cuequivariance_torch/tests/primitives/primitive_export_test.py b/cuequivariance_torch/tests/primitives/primitive_export_test.py index b4fe2c2c..969d9543 100644 --- a/cuequivariance_torch/tests/primitives/primitive_export_test.py +++ b/cuequivariance_torch/tests/primitives/primitive_export_test.py @@ -170,7 +170,8 @@ def test_script_uniform_tp_4(mode, tmp_path): cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .d.flatten_coefficient_modes() + .polynomial.tensor_products[0][1] + .flatten_coefficient_modes() .squeeze_modes("v") ) From 148d79c713a4bb384162f28c301e7c36cbf84911 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:21:11 -0800 Subject: [PATCH 030/107] fix --- .../tests/primitives/symmetric_tensor_product_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index f7805cde..72034ca6 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -118,9 +118,10 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b if use_fallback is False and not torch.cuda.is_available(): pytest.skip("CUDA is not available") - ds = descriptors.symmetric_contraction( + e = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] - ).ds + ) + ds = [stp for _, stp in e.polynomial.tensor_products] m = cuet.IWeightedSymmetricTensorProduct( ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) From 58b0b60ec12846e5590ccd86089a96783383623a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:25:36 -0800 Subject: [PATCH 031/107] Improve coverage workflow error handling and dependency installation --- .github/workflows/coverage.yml | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index e7abad9f..ef9a90e3 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -30,7 +30,17 @@ jobs: git fetch origin ${{ github.event.pull_request.base.ref }}:${{ github.event.pull_request.base.ref }} git checkout ${{ github.event.pull_request.base.ref }} - pytest --cov=cuequivariance cuequivariance > coverage.txt + + # Install dependencies on target branch as they might be different + python -m uv pip install ./cuequivariance + + # Run tests on target branch and capture exit code + pytest --cov=cuequivariance cuequivariance > coverage.txt || { + echo "Tests failed on target branch. Cannot compare coverage values." 
+ echo "Coverage on PR branch: $coverage_pr" + exit 1 + } + coverage_target=$(cat coverage.txt | grep TOTAL | awk '{print $4}' | sed 's/%//') echo "Coverage on target branch: $coverage_target" From 082e4a266a5ed89d9b621bdf8ac2435da850f986 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:27:29 -0800 Subject: [PATCH 032/107] debug --- .github/workflows/coverage.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index ef9a90e3..3cdd70f3 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -37,7 +37,7 @@ jobs: # Run tests on target branch and capture exit code pytest --cov=cuequivariance cuequivariance > coverage.txt || { echo "Tests failed on target branch. Cannot compare coverage values." - echo "Coverage on PR branch: $coverage_pr" + cat coverage.txt exit 1 } From ba6b9b387fd6fa43c0979a61056bc64e63d254a9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 06:30:00 -0800 Subject: [PATCH 033/107] coverage CI is not working --- .github/workflows/coverage.yml | 52 ---------------------------------- 1 file changed, 52 deletions(-) delete mode 100644 .github/workflows/coverage.yml diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml deleted file mode 100644 index 3cdd70f3..00000000 --- a/.github/workflows/coverage.yml +++ /dev/null @@ -1,52 +0,0 @@ -name: Coverage on PR - -on: - pull_request: - branches: [ "main", "release" ] - - -jobs: - cuequivariance: - - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v4 - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.12" - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade uv - python -m uv pip install pytest pytest-cov - python -m uv pip install ./cuequivariance - - name: Run coverage - run: | - pytest --cov=cuequivariance cuequivariance > coverage.txt - coverage_pr=$(cat coverage.txt | grep TOTAL | awk '{print $4}' | sed 's/%//') - echo "Coverage on PR branch: $coverage_pr" - - git fetch origin ${{ github.event.pull_request.base.ref }}:${{ github.event.pull_request.base.ref }} - git checkout ${{ github.event.pull_request.base.ref }} - - # Install dependencies on target branch as they might be different - python -m uv pip install ./cuequivariance - - # Run tests on target branch and capture exit code - pytest --cov=cuequivariance cuequivariance > coverage.txt || { - echo "Tests failed on target branch. Cannot compare coverage values." 
- cat coverage.txt - exit 1 - } - - coverage_target=$(cat coverage.txt | grep TOTAL | awk '{print $4}' | sed 's/%//') - echo "Coverage on target branch: $coverage_target" - - if [ $coverage_pr -lt $coverage_target ]; then - echo "Coverage on PR branch is lower than on target branch" - echo "Coverage on PR branch: $coverage_pr" - echo "Coverage on target branch: $coverage_target" - exit 1 - fi From 1b48cec037ef21d1f81063dde8e08e7f822c5962 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 08:38:56 -0800 Subject: [PATCH 034/107] Enhance SegmentedPolynomial with buffer and output selection methods - Add `select_buffers` method to flexibly select and remap polynomial buffers - Implement `select_outputs` method to choose specific outputs - Create `compute_only` method to compute a subset of polynomial outputs - Refactor `remove_unused_buffers` to use the new `select_buffers` method --- .../cuequivariance/segmented_polynomial.py | 56 ++++++++++++++++--- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 2d836190..e5059ecd 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -263,24 +263,66 @@ def buffer_used(self) -> list[bool]: for buffer in range(self.num_inputs + self.num_outputs) ] - def remove_unused_buffers(self) -> SegmentedPolynomial: - """Remove unused buffers from the polynomial.""" - used = self.buffer_used() + def select_buffers(self, keep: list[bool]) -> SegmentedPolynomial: + """Select the buffers of the polynomial.""" + assert len(keep) == self.num_operands + + # Create a mapping from old buffer indices to new buffer indices new_index = [] i = 0 - for u in used: + for u in keep: if u: new_index.append(i) i += 1 else: new_index.append(None) + # Filter tensor products that write to buffers we want to keep + # and remap the buffer indices + new_tensor_products = [] + for ope, stp in self.tensor_products: + # Check if the operation writes to a buffer we want to keep + bid = ope.output_buffer(self.num_inputs) + if keep[bid]: + # Check if all input buffers needed by this operation are kept + if not all(keep[buffer] for buffer in ope.buffers): + raise ValueError( + f"Operation {ope} writes to buffer {bid} which is kept, but requires input buffers that are being dropped" + ) + + # Remap buffer indices + new_ope = cue.Operation([new_index[buffer] for buffer in ope.buffers]) + new_tensor_products.append((new_ope, stp)) + + # Calculate new num_inputs and num_outputs + new_num_inputs = sum(keep[: self.num_inputs]) + new_num_outputs = sum(keep[self.num_inputs :]) + return SegmentedPolynomial( - sum(used[: self.num_inputs]), - sum(used[self.num_inputs :]), + new_num_inputs, + new_num_outputs, + new_tensor_products, + ) + + def select_outputs(self, keep: list[bool]) -> SegmentedPolynomial: + """Select the outputs of the polynomial.""" + assert len(keep) == self.num_outputs + return self.select_buffers([True] * self.num_inputs + keep) + + def remove_unused_buffers(self) -> SegmentedPolynomial: + """Remove unused buffers from the polynomial.""" + return self.select_buffers(self.buffer_used()) + + def compute_only(self, keep: list[bool]) -> SegmentedPolynomial: + """Compute only the selected outputs of the polynomial.""" + assert len(keep) == self.num_outputs + return SegmentedPolynomial( + self.num_inputs, + self.num_outputs, # on purpose, we keep all outputs [ - 
(cue.Operation([new_index[buffer] for buffer in ope.buffers]), stp) + (ope, stp) for ope, stp in self.tensor_products + if keep[ope.output_buffer(self.num_inputs) - self.num_inputs] ], ) From 46b0509128b8013bd9f4bd367743f94925be5159 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 08:39:10 -0800 Subject: [PATCH 035/107] Enhance logging in segmented polynomial operations Add polynomial details to log output for improved debugging and traceability --- .../primitives/segmented_polynomial_ops_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py index 35e2cc3a..2a57138d 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py @@ -117,7 +117,7 @@ def log(msg: str): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - log("Using the uniform 1d kernel of cuequivariance_ops_jax 🚀") + log("Using the uniform 1d kernel of cuequivariance_ops_jax 🚀\n" + str(polynomial)) outputs = tensor_product_uniform_1d_jit( buffers[: polynomial.num_inputs], buffers[polynomial.num_inputs :], From 4546b7ff3096d7d9ad065a9321cff00e170a94e9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 28 Feb 2025 08:39:27 -0800 Subject: [PATCH 036/107] Add type hint and buffer usage assertion in segmented polynomial implementation - Add type hint for `used_buffers` as `list[int]` - Include an assertion to verify all polynomial buffers are used before sorting indices --- .../cuequivariance_jax/primitives/segmented_polynomial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 5601f427..d55776d4 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -197,7 +197,7 @@ def segmented_polynomial_prim( ] polynomial = polynomial.consolidate() - used_buffers = polynomial.used_buffers() + used_buffers: list[int] = polynomial.used_buffers() polynomial = polynomial.remove_unused_buffers() used_indices = sorted( @@ -283,6 +283,7 @@ def segmented_polynomial_impl( inputs, indices = inputs_and_indices[:num_inputs], inputs_and_indices[num_inputs:] del inputs_and_indices + assert all(polynomial.buffer_used()) polynomial = polynomial.sort_indices_for_identical_operands() outputs = None From f72537346dfcdba8089c81d030ac71a419b7c45f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 1 Mar 2025 04:45:56 -0800 Subject: [PATCH 037/107] clean --- .../cuequivariance/equivariant_polynomial.py | 22 --------------- .../cuequivariance/segmented_polynomial.py | 27 ++++++++++--------- .../tests/segmented_polynomial_test.py | 7 +++-- 3 files changed, 17 insertions(+), 39 deletions(-) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index 390eff7b..39ca89e8 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -130,28 +130,6 @@ def fuse_stps(self) -> EquivariantPolynomial: self.polynomial.fuse_stps(), ) - def buffer_used(self) -> list[bool]: - """Check which buffers 
are used in the polynomial. - - Returns: - list[bool]: Boolean flags indicating which buffers are used. - """ - return self.polynomial.buffer_used() - - def remove_unused_buffers(self) -> EquivariantPolynomial: - """Remove unused buffers from the polynomial. - - This method creates a new polynomial with unused buffers removed, - which can make operations more efficient. - - Returns: - EquivariantPolynomial: A new polynomial with unused buffers removed. - """ - return EquivariantPolynomial( - [rep for u, rep in zip(self.buffer_used(), self.operands) if u], - self.polynomial.remove_unused_buffers(), - ) - @classmethod def stack( cls, polys: list[EquivariantPolynomial], stacked: list[bool] diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index e5059ecd..c41b7d54 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -246,23 +246,24 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): return self.fuse_stps().map_tensor_products(f) - def used_buffers(self) -> list[int]: - """Buffers used in the polynomial. (List of integers)""" - return sorted( - set( - itertools.chain.from_iterable( - ope.buffers for ope, _ in self.tensor_products - ) - ) - ) + def used_inputs(self) -> list[bool]: + """Inputs used in the polynomial. (List of boolean values)""" + return [ + any(buffer in ope.buffers for ope, _ in self.tensor_products) + for buffer in range(self.num_inputs) + ] - def buffer_used(self) -> list[bool]: - """Buffers used in the polynomial. (List of boolean values)""" + def used_outputs(self) -> list[bool]: + """Outputs used in the polynomial. (List of boolean values)""" return [ any(buffer in ope.buffers for ope, _ in self.tensor_products) - for buffer in range(self.num_inputs + self.num_outputs) + for buffer in range(self.num_inputs, self.num_inputs + self.num_outputs) ] + def used_buffers(self) -> list[bool]: + """Buffers used in the polynomial. 
(List of boolean values)""" + return self.used_inputs() + self.used_outputs() + def select_buffers(self, keep: list[bool]) -> SegmentedPolynomial: """Select the buffers of the polynomial.""" assert len(keep) == self.num_operands @@ -311,7 +312,7 @@ def select_outputs(self, keep: list[bool]) -> SegmentedPolynomial: def remove_unused_buffers(self) -> SegmentedPolynomial: """Remove unused buffers from the polynomial.""" - return self.select_buffers(self.buffer_used()) + return self.select_buffers(self.used_buffers()) def compute_only(self, keep: list[bool]) -> SegmentedPolynomial: """Compute only the selected outputs of the polynomial.""" diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py index c5e60f02..228d30c2 100644 --- a/cuequivariance/tests/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -102,8 +102,7 @@ def test_buffer_properties(): assert poly.input_sizes == [2, 2] assert poly.output_sizes == [2, 1] - assert poly.used_buffers() == [0, 1, 2, 3] - assert poly.buffer_used() == [True, True, True, True] + assert poly.used_buffers() == [True, True, True, True] def test_remove_unused_buffers(): @@ -115,14 +114,14 @@ def test_remove_unused_buffers(): poly = cue.SegmentedPolynomial(3, 1, [(op, stp)]) # Buffer 1 is not used - assert poly.buffer_used() == [True, False, True, True] + assert poly.used_buffers() == [True, False, True, True] # Remove unused buffer cleaned_poly = poly.remove_unused_buffers() assert cleaned_poly.num_inputs == 2 assert cleaned_poly.num_outputs == 1 - assert cleaned_poly.buffer_used() == [True, True, True] + assert cleaned_poly.used_buffers() == [True, True, True] def test_consolidate(): From f974b8f6930b0a8c8c5d16682e3f36387901b082 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 1 Mar 2025 04:48:16 -0800 Subject: [PATCH 038/107] clean segmented_polynomial_prim using _dce_helper --- .../primitives/segmented_polynomial.py | 64 +++++++++++++------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index d55776d4..6abbbf82 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -173,6 +173,30 @@ def fn(x: jax.Array, shape: tuple[int, ...]) -> jax.Array: segmented_polynomial_p.multiple_results = True +def _dce_helper( + used_inputs: list[bool], + used_outputs: list[bool], + buffer_index: list[int], + num_indices: int, +) -> tuple[list[bool], list[int]]: + used_indices_id: list[int] = sorted( + { + buffer_index[i] + for i, used in enumerate(used_inputs + used_outputs) + if used and buffer_index[i] >= 0 + } + ) + used_indices: list[bool] = [i in used_indices_id for i in range(num_indices)] + + buffer_index = tuple( + used_indices_id.index(buffer_index[i]) if buffer_index[i] >= 0 else -1 + for i, used in enumerate(used_inputs + used_outputs) + if used + ) + + return used_indices, buffer_index + + def segmented_polynomial_prim( inputs: list[jax.Array], # input buffers outputs_shape_dtype: list[jax.ShapeDtypeStruct], # output shapes and dtypes @@ -196,40 +220,40 @@ def segmented_polynomial_prim( jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_shape_dtype ] + # fuse STPs, consolidate modes, squeeze modes, remove empty segments, consolidate paths, sort paths polynomial = polynomial.consolidate() - 
used_buffers: list[int] = polynomial.used_buffers() - polynomial = polynomial.remove_unused_buffers() - used_indices = sorted( - {buffer_index[i] for i in used_buffers if buffer_index[i] >= 0} + used_inputs, used_outputs = polynomial.used_inputs(), polynomial.used_outputs() + + used_indices, buffer_index = _dce_helper( + used_inputs, used_outputs, buffer_index, len(indices) ) - used_outputs = segmented_polynomial_p.bind( - *[inputs[i] for i in used_buffers[: polynomial.num_inputs]], - *[indices[i] for i in used_indices], - buffer_index=tuple( - used_indices.index(buffer_index[i]) if buffer_index[i] >= 0 else -1 - for i in used_buffers - ), + new_outputs = segmented_polynomial_p.bind( + *[v for v, used in zip(inputs, used_inputs) if used], + *[v for v, used in zip(indices, used_indices) if used], + buffer_index=buffer_index, outputs_shape_dtype=tuple( - outputs_shape_dtype[i - len(inputs)] - for i in used_buffers[polynomial.num_inputs :] + x for x, used in zip(outputs_shape_dtype, used_outputs) if used ), - polynomial=polynomial, + polynomial=polynomial.select_buffers(used_inputs + used_outputs), math_dtype=jnp.dtype(math_dtype), name=str(name), impl=impl, ) if return_none_if_empty: - outputs = [None] * len(outputs_shape_dtype) + old_outputs = [None] * len(outputs_shape_dtype) else: - outputs = [jnp.zeros(out.shape, out.dtype) for out in outputs_shape_dtype] + old_outputs = [jnp.zeros(out.shape, out.dtype) for out in outputs_shape_dtype] - for i, output in zip(used_buffers[polynomial.num_inputs :], used_outputs): - outputs[i - len(inputs)] = output + i_new = 0 + for i_old, used in enumerate(used_outputs): + if used: + old_outputs[i_old] = new_outputs[i_new] + i_new += 1 - return tuple(outputs) + return tuple(old_outputs) def map_indices( @@ -283,7 +307,7 @@ def segmented_polynomial_impl( inputs, indices = inputs_and_indices[:num_inputs], inputs_and_indices[num_inputs:] del inputs_and_indices - assert all(polynomial.buffer_used()) + assert all(polynomial.used_buffers()) polynomial = polynomial.sort_indices_for_identical_operands() outputs = None From 9c9a0779287109956d5e823eb91224ad8c481708 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 1 Mar 2025 09:13:56 -0800 Subject: [PATCH 039/107] sanitize outputs_shape_dtype argument --- .../primitives/segmented_polynomial.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 6abbbf82..ecbb4e79 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -93,6 +93,14 @@ def segmented_polynomial( assert len(inputs) == polynomial.num_inputs assert len(outputs_shape_dtype) == polynomial.num_outputs + + outputs_shape_dtype = [ + jax.ShapeDtypeStruct( + x.shape if size is None else x.shape[:-1] + (size,), x.dtype + ) + for x, size in zip(outputs_shape_dtype, polynomial.output_sizes) + ] + buffers = list(inputs) + list(outputs_shape_dtype) if indices is None: @@ -216,10 +224,6 @@ def segmented_polynomial_prim( assert len(inputs) + len(outputs_shape_dtype) == len(buffer_index) assert max(buffer_index) < len(indices) - outputs_shape_dtype = [ - jax.ShapeDtypeStruct(x.shape, x.dtype) for x in outputs_shape_dtype - ] - # fuse STPs, consolidate modes, squeeze modes, remove empty segments, consolidate paths, sort paths polynomial = polynomial.consolidate() From 
98842782b29f820c3ffea080e0f3ea44859a6934 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 1 Mar 2025 19:59:31 +0100 Subject: [PATCH 040/107] cuex: implement custom_dce rule (#88) * wip * wip * working * done * remove print * clean * clean dce code * add missing import --- .../primitives/segmented_polynomial.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index c0afddfd..3bf09d90 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -20,7 +20,7 @@ import jax.extend import jax.lax import jax.numpy as jnp -from jax.interpreters import ad, batching, mlir, xla +from jax.interpreters import ad, batching, mlir, partial_eval, xla import cuequivariance as cue from cuequivariance_jax.primitives.primitives_utils import reshape @@ -562,6 +562,49 @@ def flatten_index(x: jax.Array) -> jax.Array: return outputs, (0,) * len(outputs) +def segmented_polynomial_dce( + used_outputs: list[bool], + eqn: jax.extend.core.JaxprEqn, +) -> tuple[list[bool], jax.extend.core.JaxprEqn | None]: + assert len(used_outputs) == len(eqn.outvars) + + polynomial: cue.SegmentedPolynomial = eqn.params["polynomial"] + buffer_index = eqn.params["buffer_index"] + outputs_shape_dtype = eqn.params["outputs_shape_dtype"] + + # If no outputs are used, we can eliminate the operation entirely + if not any(used_outputs) and not eqn.effects: + return [False] * len(eqn.invars), None + + num_inputs = polynomial.num_inputs + + polynomial = polynomial.compute_only(used_outputs) + used_inputs: list[bool] = polynomial.used_inputs() + + used_indices, buffer_index = _dce_helper( + used_inputs, used_outputs, buffer_index, len(eqn.invars) - num_inputs + ) + + new_eqn = jax.extend.core.JaxprEqn( + [v for v, used in zip(eqn.invars, used_inputs + used_indices) if used], + [v for v, used in zip(eqn.outvars, used_outputs) if used], + eqn.primitive, + dict( + eqn.params, + polynomial=polynomial.select_buffers(used_inputs + used_outputs), + buffer_index=buffer_index, + outputs_shape_dtype=tuple( + x for x, used in zip(outputs_shape_dtype, used_outputs) if used + ), + ), + eqn.effects, + eqn.source_info, + eqn.ctx, + ) + + return used_inputs + used_indices, new_eqn + + segmented_polynomial_p.def_abstract_eval(segmented_polynomial_abstract_eval) segmented_polynomial_p.def_impl(partial(xla.apply_primitive, segmented_polynomial_p)) mlir.register_lowering( @@ -583,3 +626,4 @@ def flatten_index(x: jax.Array) -> jax.Array: ad.primitive_jvps[segmented_polynomial_p] = segmented_polynomial_jvp ad.primitive_transposes[segmented_polynomial_p] = segmented_polynomial_transpose batching.primitive_batchers[segmented_polynomial_p] = segmented_polynomial_batching +partial_eval.dce_rules[segmented_polynomial_p] = segmented_polynomial_dce From a8e5b643128451f222a0fd3826a6feda2c2db3a3 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sun, 2 Mar 2025 11:38:00 +0100 Subject: [PATCH 041/107] add automatic stp symmetrization --- .../descriptors/spherical_harmonics_.py | 2 ++ cuequivariance/cuequivariance/segmented_polynomial.py | 11 ++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index bd4b7070..1dd7e937 100644 --- 
a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -58,6 +58,8 @@ def spherical_harmonics( indices = poly_degrees_to_path_indices(degrees) d.add_path(*indices, i, c=coeff) + d = d.symmetrize_operands(range(ell)) + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul), diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index c41b7d54..867c6a95 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -50,9 +50,18 @@ def __init__( num_outputs: int, tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], ): + _tensor_products = [] + for ope, stp in tensor_products: + assert len(ope.buffers) == stp.num_operands + for set_of_operands in ope.operands_with_identical_buffers(): + stp = stp.symmetrize_operands(set_of_operands) + stp = stp.sort_paths() + + _tensor_products.append((ope, stp)) + object.__setattr__(self, "num_inputs", num_inputs) object.__setattr__(self, "num_outputs", num_outputs) - object.__setattr__(self, "tensor_products", sorted(tensor_products)) + object.__setattr__(self, "tensor_products", sorted(_tensor_products)) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): From 526c61ebe49f92f3fcc4f862949299e44673a00a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sun, 2 Mar 2025 12:18:14 +0100 Subject: [PATCH 042/107] make STP frozen while still modifiable inplace through modifiers --- .../segmented_tensor_product/dot.py | 4 +- .../segmented_tensor_product.py | 212 +++++++++++------- .../segmented_polynomial_ops_impl.py | 2 +- 3 files changed, 133 insertions(+), 85 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py index e54d4b19..0c215d54 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py @@ -65,9 +65,7 @@ def dot( x.coefficient_subscripts + y.coefficient_subscripts ) ) - d.operands = copy.deepcopy( - [x.operands[i] for i in x_keep] + [y.operands[i] for i in y_keep] - ) + d.set_operands([x.operands[i] for i in x_keep] + [y.operands[i] for i in y_keep]) formula = f"{x.coefficient_subscripts} , {y.coefficient_subscripts} -> {d.coefficient_subscripts}" diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 31285c57..887047d9 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) -@dataclasses.dataclass(init=False, frozen=False) +@dataclasses.dataclass(init=False, frozen=True) class SegmentedTensorProduct: """ Irreps-agnostic and dataflow-agnostic descriptor of a segmented tensor product @@ -56,12 +56,13 @@ class SegmentedTensorProduct: .. rubric:: Methods """ - operands: list[stp.Operand] - paths: list[stp.Path] + operands: tuple[stp.Operand, ...] + paths: tuple[stp.Path, ...] 
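The pattern this patch adopts, a frozen dataclass whose designated setters mutate through object.__setattr__, shown in isolation as a sketch (the class name below is hypothetical, only the mechanism matches the patch):

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class Container:
        items: tuple = ()

        def set_items(self, items):
            # frozen=True blocks ordinary attribute assignment, so the
            # sanctioned mutators go through object.__setattr__, as the
            # set_operands / set_paths helpers in this patch do
            object.__setattr__(self, "items", tuple(items))

    c = Container()
    c.set_items([1, 2])
    assert c.items == (1, 2)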
coefficient_subscripts: str ################################ Initializers ################################ + # From here we can use object.__setattr__ to modify the attributes def __init__( self, *, @@ -71,12 +72,38 @@ def __init__( ): if operands is None: operands = [] - self.operands = operands - if paths is None: paths = [] - self.paths = paths - self.coefficient_subscripts = coefficient_subscripts + + object.__setattr__( + self, "operands", tuple(copy.deepcopy(ope) for ope in operands) + ) + object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) + object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) + + def set_operands(self, operands: list[stp.Operand]): + object.__setattr__( + self, "operands", tuple(copy.deepcopy(ope) for ope in operands) + ) + + def set_operand(self, oid: int, operand: stp.Operand): + object.__setattr__( + self, + "operands", + self.operands[:oid] + (copy.deepcopy(operand),) + self.operands[oid + 1 :], + ) + + def set_paths(self, paths: list[stp.Path]): + object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) + + def insert_path_(self, path_index: int, path: stp.Path): + object.__setattr__( + self, + "paths", + self.paths[:path_index] + (copy.deepcopy(path),) + self.paths[path_index:], + ) + + # until here. Below we use dataclasses.replace or the setters to modify the attributes def assert_valid(self): assert stp.Subscripts.is_valid(self.subscripts) @@ -679,21 +706,19 @@ def insert_path( dims = {m: next(iter(dd)) for m, dd in dims.items()} - self.paths.insert( - path_index, - stp.Path( - [ - ( - (s + self.operands[oid].num_segments) - % self.operands[oid].num_segments - if isinstance(s, int) - else self.add_segment(oid, dims) - ) - for oid, s in enumerate(segments) - ], - coefficients, - ), + path = stp.Path( + [ + ( + (s + self.operands[oid].num_segments) + % self.operands[oid].num_segments + if isinstance(s, int) + else self.add_segment(oid, dims) + ) + for oid, s in enumerate(segments) + ], + coefficients, ) + self.insert_path_(path_index, path) return path_index def add_path( @@ -749,16 +774,18 @@ def insert_segment( sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) self.operands[operand].insert_segment(sid, segment) - self.paths = [ - stp.Path( - [ - s if s < sid or oid != operand else s + 1 - for oid, s in enumerate(path.indices) - ], - path.coefficients, - ) - for path in self.paths - ] + self.set_paths( + [ + stp.Path( + [ + s if s < sid or oid != operand else s + 1 + for oid, s in enumerate(path.indices) + ], + path.coefficients, + ) + for path in self.paths + ] + ) def insert_segments( self, @@ -781,21 +808,26 @@ def insert_segments( f"expected segments subscripts {segments.subscripts} to match operand subscripts {o.subscripts}." 
) - self.operands[operand] = stp.Operand( - subscripts=o.subscripts, - segments=o.segments[:sid] + segments.segments + o.segments[sid:], - _dims={m: o.get_dims(m) | segments.get_dims(m) for m in o.subscripts}, + self.set_operand( + operand, + stp.Operand( + subscripts=o.subscripts, + segments=o.segments[:sid] + segments.segments + o.segments[sid:], + _dims={m: o.get_dims(m) | segments.get_dims(m) for m in o.subscripts}, + ), + ) + self.set_paths( + [ + stp.Path( + [ + s if s < sid or oid != operand else s + segments.num_segments + for oid, s in enumerate(path.indices) + ], + path.coefficients, + ) + for path in self.paths + ], ) - self.paths = [ - stp.Path( - [ - s if s < sid or oid != operand else s + segments.num_segments - for oid, s in enumerate(path.indices) - ], - path.coefficients, - ) - for path in self.paths - ] def add_segment( self, operand: int, segment: Union[tuple[int, ...], dict[str, int]] @@ -845,8 +877,9 @@ def add_or_rename_modes( for oid, operand in enumerate(self.operands): d.add_segments(oid, operand.segments) for path in self.paths: - d.paths.append( - stp.Path(indices=path.indices, coefficients=path.coefficients) + d.insert_path_( + len(d.paths), + stp.Path(indices=path.indices, coefficients=path.coefficients), ) return d @@ -888,14 +921,15 @@ def add_or_rename_modes( for m, d in zip(D.operands[oid].subscripts, D.operands[oid][sid]): dims[m] = d - D.paths.append( + D.insert_path_( + len(D.paths), stp.Path( indices=path.indices, coefficients=np.reshape( path.coefficients, tuple(dims[m] for m in D.coefficient_subscripts), ), - ) + ), ) return D @@ -951,13 +985,15 @@ def add_or_transpose_modes( ) old, new = self.coefficient_subscripts, d.coefficient_subscripts perm = [old.index(ch) for ch in new] - for path in self.paths: - d.paths.append( + d.set_paths( + [ stp.Path( indices=path.indices, coefficients=np.transpose(path.coefficients, perm), ) - ) + for path in self.paths + ] + ) return d def append_modes_to_all_operands( @@ -1017,7 +1053,7 @@ def permute_segments( ) -> SegmentedTensorProduct: """Permute the segments of an operand.""" operand = _canonicalize_index("operand", operand, self.num_operands) - new_operands = self.operands.copy() + new_operands = list(self.operands) new_operands[operand] = stp.Operand( segments=[self.operands[operand][i] for i in perm], subscripts=self.operands[operand].subscripts, @@ -1099,7 +1135,8 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: ], ) for path in self.paths: - d.paths.append( + d.insert_path_( + len(d.paths), stp.Path( indices=path.indices, coefficients=np.reshape( @@ -1108,7 +1145,7 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: path.coefficients.shape, self.coefficient_subscripts ), ), - ) + ), ) logger.debug(f"Squeezed {self} to {d}") return d @@ -1344,14 +1381,16 @@ def empty(D, oid, sid): perm.append(sid) D = D.permute_segments(oid, perm) - D.paths = [ - path - for path in D.paths - if not any(empty(D, oid, sid) for oid, sid in enumerate(path.indices)) - ] + D.set_paths( + [ + path + for path in D.paths + if not any(empty(D, oid, sid) for oid, sid in enumerate(path.indices)) + ] + ) for oid in range(D.num_operands): - D.operands[oid] = stp.Operand( + operand = stp.Operand( subscripts=D.operands[oid].subscripts, segments=[ segment @@ -1363,6 +1402,7 @@ def empty(D, oid, sid): for m, dd in D.operands[oid]._dims.items() }, ) + D.set_operand(oid, operand) return D @@ -1625,9 +1665,11 @@ def _consolidate_pair_of_modes(self, m: str, n: str) -> 
SegmentedTensorProduct: c = np.reshape( c, c.shape[:i] + (c.shape[i] * c.shape[i + 1],) + c.shape[i + 2 :] ) - d1.paths.append(stp.Path(indices=path.indices, coefficients=c)) + d1.insert_path_( + len(d1.paths), stp.Path(indices=path.indices, coefficients=c) + ) else: - d1.paths = copy.deepcopy(d0.paths) + d1.set_paths(d0.paths) return d1 @@ -1641,13 +1683,15 @@ def round_coefficients_to_rational( max_denominator (int): The maximum denominator, ``q < max_denominator``. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path( - indices=path.indices, - coefficients=round_to_rational(path.coefficients, max_denominator), - ) - for path in d.paths - ] + d.set_paths( + [ + stp.Path( + indices=path.indices, + coefficients=round_to_rational(path.coefficients, max_denominator), + ) + for path in d.paths + ] + ) return d def round_coefficients_to_sqrt_rational( @@ -1660,13 +1704,17 @@ def round_coefficients_to_sqrt_rational( max_denominator (int): The maximum denominator, ``q < max_denominator``. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path( - indices=path.indices, - coefficients=round_to_sqrt_rational(path.coefficients, max_denominator), - ) - for path in d.paths - ] + d.set_paths( + [ + stp.Path( + indices=path.indices, + coefficients=round_to_sqrt_rational( + path.coefficients, max_denominator + ), + ) + for path in d.paths + ] + ) return d def modify_coefficients( @@ -1679,10 +1727,12 @@ def modify_coefficients( f (callable): The function to apply to the coefficients. """ d = copy.deepcopy(self) - d.paths = [ - stp.Path(indices=path.indices, coefficients=f(path.coefficients)) - for path in d.paths - ] + d.set_paths( + [ + stp.Path(indices=path.indices, coefficients=f(path.coefficients)) + for path in d.paths + ] + ) return d def __mul__(self, factor: float) -> SegmentedTensorProduct: diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py index 775c100a..8d6ac3e8 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py @@ -124,7 +124,7 @@ def log(msg: str): outputs = tensor_product_uniform_1d_jit( buffers[: polynomial.num_inputs], buffers[polynomial.num_inputs :], - indices, + list(indices), buffer_index, operations=operations, paths=paths, From 72daf14e808620c24da44f061b6b22017194dfc2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sun, 2 Mar 2025 12:42:39 +0100 Subject: [PATCH 043/107] Optimize SegmentedTensorProduct symmetry check and hashing - Add optimization to skip symmetrization if already symmetric - Simplify hash method by removing unnecessary tuple conversion - Add lru_cache to get_dimensions_dict for performance --- .../segmented_tensor_product.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 887047d9..99c610ae 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -205,9 +205,7 @@ def from_base64(cls, data: str) -> SegmentedTensorProduct: ################################ Properties ################################ def __hash__(self) -> int: - return hash( - (tuple(self.operands), 
tuple(self.paths), self.coefficient_subscripts) - ) + return hash((self.operands, self.paths, self.coefficient_subscripts)) def __eq__(self, value: SegmentedTensorProduct) -> bool: assert isinstance(value, SegmentedTensorProduct) @@ -438,6 +436,7 @@ def to_base64(self, extended: bool = False) -> str: """ return base64.b64encode(self.to_bytes(extended)).decode("ascii") + @functools.lru_cache(maxsize=None) def get_dimensions_dict(self) -> dict[str, set[int]]: """Get the dimensions of the tensor product.""" dims: dict[str, set[int]] = {ch: set() for ch in self.subscripts.modes()} @@ -1346,6 +1345,18 @@ def symmetrize_operands(self, operands: Sequence[int]) -> SegmentedTensorProduct return self permutations = list(itertools.permutations(range(len(operands)))) + + # optimization: skip if already symmetric + def make_global_perm(perm: tuple[int, ...]) -> tuple[int, ...]: + p = list(range(self.num_operands)) + for i, j in enumerate(perm): + p[operands[i]] = operands[j] + return tuple(p) + + symmetries: list[tuple[int, ...]] = self.symmetries() + if all(make_global_perm(perm) in symmetries for perm in permutations): + return self + d = self.sort_indices_for_identical_operands(operands) paths = [] From 3c6cbf7e679bc8dffb072cc2d57b49136e8fa9bf Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 10:23:17 +0100 Subject: [PATCH 044/107] Add symmetrization and unsymmetrization methods for identical operands --- .../cuequivariance/segmented_polynomial.py | 34 +++++++++++------ .../tests/segmented_polynomial_test.py | 37 +++++++++++++++++++ 2 files changed, 59 insertions(+), 12 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 867c6a95..69ec8a3c 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -50,18 +50,9 @@ def __init__( num_outputs: int, tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], ): - _tensor_products = [] - for ope, stp in tensor_products: - assert len(ope.buffers) == stp.num_operands - for set_of_operands in ope.operands_with_identical_buffers(): - stp = stp.symmetrize_operands(set_of_operands) - stp = stp.sort_paths() - - _tensor_products.append((ope, stp)) - object.__setattr__(self, "num_inputs", num_inputs) object.__setattr__(self, "num_outputs", num_outputs) - object.__setattr__(self, "tensor_products", sorted(_tensor_products)) + object.__setattr__(self, "tensor_products", sorted(tensor_products)) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): @@ -454,8 +445,27 @@ def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: raise ValueError(f"Buffer {buffer} is not used") return segments - def sort_indices_for_identical_operands(self) -> SegmentedPolynomial: - """Sort the indices of the segmented tensor products for identical operands.""" + def symmetrize_for_identical_operands(self) -> SegmentedPolynomial: + """Symmetrize the paths of the segmented tensor products for identical operands. + + This operation increases the number of paths in the segmented tensor products. 
+        """
+
+        symmetrized_tensor_products = []
+        for ope, stp in self.tensor_products:
+            for set_of_operands in ope.operands_with_identical_buffers():
+                stp = stp.symmetrize_operands(set_of_operands)
+            stp = stp.sort_paths()
+            symmetrized_tensor_products.append((ope, stp))
+        return SegmentedPolynomial(
+            self.num_inputs, self.num_outputs, symmetrized_tensor_products
+        )
+
+    def unsymmetrize_for_identical_operands(self) -> SegmentedPolynomial:
+        """Unsymmetrize the paths of the segmented tensor products for identical operands.
+
+        This operation decreases the number of paths in the segmented tensor products.
+        """
 
         def optimize_paths(ope: cue.Operation, stp: cue.SegmentedTensorProduct):
             for set_of_operands in ope.operands_with_identical_buffers():
diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py
index 228d30c2..0d7f2cd7 100644
--- a/cuequivariance/tests/segmented_polynomial_test.py
+++ b/cuequivariance/tests/segmented_polynomial_test.py
@@ -342,3 +342,40 @@ def test_backward():
     # With zero cotangent, gradients should be zero
     assert np.allclose(grad_x_zero, np.zeros_like(x))
     assert np.allclose(grad_y_zero, np.zeros_like(y))
+
+
+def test_symmetrize_identical_operands():
+    """Test symmetrization and unsymmetrization of polynomials with identical operands."""
+    stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1])
+    stp.add_path(0, 1, 0, c=1.0)  # x0 * y1 path
+
+    # Create operation that uses the same input buffer twice
+    op = cue.Operation((0, 0, 1))  # Use buffer 0 twice, write to buffer 1
+    poly = cue.SegmentedPolynomial(1, 1, [(op, stp)])
+
+    # Symmetrize the polynomial
+    sym_poly = poly.symmetrize_for_identical_operands()
+
+    # Check that we get 0.5 x0*y1 + 0.5 x1*y0
+    # This means we should have two paths with coefficient 0.5
+    [(_, sym_stp)] = sym_poly.tensor_products
+    assert len(sym_stp.paths) == 2
+    # Each path carries half of the original coefficient
+    assert sym_stp.paths[0].coefficients == 0.5
+    assert sym_stp.paths[1].coefficients == 0.5
+    # Check that the paths have different indices (operands swapped)
+    assert sym_stp.paths[0].indices == (0, 1, 0)
+    assert sym_stp.paths[1].indices == (1, 0, 0)
+
+    # Test that unsymmetrize returns to original form
+    unsym_poly = sym_poly.unsymmetrize_for_identical_operands()
+    [(_, unsym_stp)] = unsym_poly.tensor_products
+    assert len(unsym_stp.paths) == 1
+    assert unsym_stp.paths[0].coefficients == 1.0
+    assert unsym_stp.paths[0].indices == (0, 1, 0)
+
+    # Test evaluation to verify the symmetrization works correctly
+    x = np.array([1.0, 2.0])
+    [result] = poly(x)  # Original polynomial
+    [sym_result] = sym_poly(x)  # Symmetrized polynomial
+    assert np.allclose(result, sym_result)  # Results should be identical

From f17b8f471265c2c9c9c1e863127e502ee30f8896 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 3 Mar 2025 10:33:13 +0100
Subject: [PATCH 045/107] Symmetrize polynomial for JVP computation

---
 cuequivariance/cuequivariance/segmented_polynomial.py     | 5 ++++-
 .../cuequivariance_jax/primitives/segmented_polynomial.py | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py
index 69ec8a3c..9b6a5f5b 100644
--- a/cuequivariance/cuequivariance/segmented_polynomial.py
+++ b/cuequivariance/cuequivariance/segmented_polynomial.py
@@ -375,8 +375,11 @@ def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial:
         """Compute the Jacobian-vector product of the polynomial."""
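The change below leans on the symmetrized form exercised by the test above. A minimal sketch of the intended behaviour, reusing that test's construction (the fusion into a single product is inferred from the commit message, not an API guarantee):

    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1])
    stp.add_path(0, 1, 0, c=1.0)
    poly = cue.SegmentedPolynomial(1, 1, [(cue.Operation((0, 0, 1)), stp)])

    # f(x) is quadratic in its single input, so jvp(f)(x; t) = f(x, t) + f(t, x).
    # Once the operands are symmetrized, the two terms coincide up to an operand
    # permutation, which lets group_by_operational_symmetries fuse them into one
    # tensor product carrying a factor of 2.
    jvp_poly = poly.jvp([True])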
assert len(has_tangent) == self.num_inputs + # Symmetrizing the polynomial helps identify simplifications by group_by_operational_symmetries + sym_poly = self.symmetrize_for_identical_operands() + new_tps = [] - for ope, stp in self.tensor_products: + for ope, stp in sym_poly.tensor_products: jvps = ope.jvp(has_tangent) permutations: list[tuple[int, ...]] = stp.symmetries() for multiplicator, ope in cue.Operation.group_by_operational_symmetries( diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 3bf09d90..892412db 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -312,7 +312,7 @@ def segmented_polynomial_impl( del inputs_and_indices assert all(polynomial.used_buffers()) - polynomial = polynomial.sort_indices_for_identical_operands() + polynomial = polynomial.unsymmetrize_for_identical_operands() outputs = None kwargs = dict( From 5295186e6615ad9203bf03b138ff3dd6d953c43a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 10:37:38 +0100 Subject: [PATCH 046/107] rename --- .../cuequivariance_jax/primitives/segmented_polynomial.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 892412db..4745d44c 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -260,7 +260,7 @@ def segmented_polynomial_prim( return tuple(old_outputs) -def map_indices( +def _remap_indices_and_buffer_index( old_indices: list[jax.Array], old_buffer_index: list[int], mapping: list[int] ) -> tuple[list[jax.Array], list[int]]: new_indices = [] @@ -374,7 +374,7 @@ def segmented_polynomial_jvp( impl=impl, ) - jvp_indices, jvp_buffer_index = map_indices( + jvp_indices, jvp_buffer_index = _remap_indices_and_buffer_index( indices, buffer_index, [i for i, x in enumerate(primals)] @@ -416,7 +416,7 @@ def segmented_polynomial_transpose( # The cotangents replace the outputs as inputs # The undefined primal inputs become outputs - tr_indices, tr_buffer_index = map_indices( + tr_indices, tr_buffer_index = _remap_indices_and_buffer_index( indices, buffer_index, [i for i, x in enumerate(inputs) if not ad.is_undefined_primal(x)] From 1aa0cb17d52383ac9e4cfc36ba44f9262e16d428 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 10:54:59 +0100 Subject: [PATCH 047/107] fix --- .../cuequivariance_torch/primitives/symmetric_tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 9c7a5739..2f9503a7 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -48,7 +48,7 @@ def __init__( descriptors = [ stp.SegmentedTensorProduct( - operands=[stp.Operand.empty_segments(1)] + d.operands, + operands=(stp.Operand.empty_segments(1),) + d.operands, paths=[ stp.Path((0,) + path.indices, path.coefficients) for path in d.paths ], From bc55f41d4ee563f121a46d2cdb9e0fffe671e499 Mon Sep 17 00:00:00 2001 
From: Mario Geiger Date: Mon, 3 Mar 2025 10:57:52 +0100 Subject: [PATCH 048/107] fix --- .../cuequivariance/descriptors/spherical_harmonics_.py | 2 +- .../segmented_tensor_product/segmented_tensor_product.py | 2 +- .../cuequivariance_jax/primitives/equivariant_polynomial.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index 1dd7e937..4754d44f 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -41,7 +41,7 @@ def spherical_harmonics( ╭ a=1 -> B=0+1+2 │ B ───── sizes=9 num_segments=9 num_paths=1 │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 - ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=8 + ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=11 """ if len(ls) != 1: return cue.EquivariantPolynomial.stack( diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 99c610ae..b30572d2 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -239,7 +239,7 @@ def __repr__(self) -> str: "[" + ",".join(map(str, operand.segments)) + "]" for operand in self.operands ) - return f"{self.subscripts} operands={operands} paths={self.paths}" + return f"{self.subscripts} operands={operands} paths={list(self.paths)}" sizes = ",".join(f"{operand.size}" for operand in self.operands) num_segments = ",".join(f"{len(operand)}" for operand in self.operands) output = f"{self.subscripts} sizes={sizes} num_segments={num_segments} num_paths={self.num_paths}" diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py index 791ba306..ec23efe6 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py @@ -58,7 +58,7 @@ def equivariant_polynomial( ╭ a=1 -> B=0+1+2 │ B ───── sizes=9 num_segments=9 num_paths=1 │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 - ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=8 + ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=11 Basic usage with single input: From 5a45a1c62231c49adea4fb50c682178b06467bf3 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 11:03:14 +0100 Subject: [PATCH 049/107] fix --- cuequivariance/cuequivariance/segmented_tensor_product/dot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py index 0c215d54..3ed47c65 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py @@ -14,7 +14,6 @@ # limitations under the License. 
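The docstring updates above (``num_paths=8`` becoming ``num_paths=11``) appear to track the fact that these descriptors now report the path count of the symmetrized product, and symmetrizing over permutations of identical operands can split one path into several. A small sketch of that effect, using the API introduced in PATCH 044 (segment counts are arbitrary):

    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.empty_segments([3, 3, 1])
    stp.add_path(0, 2, 0, c=1.0)   # one asymmetric path
    op = cue.Operation((0, 0, 1))  # both inputs read the same buffer
    poly = cue.SegmentedPolynomial(1, 1, [(op, stp)])

    sym = poly.symmetrize_for_identical_operands()
    [(_, d)] = sym.tensor_products
    assert d.num_paths == 2        # 0.5 * x0*x2 + 0.5 * x2*x0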
 from __future__ import annotations
 
-import copy
 import itertools
 from typing import Any, Sequence
 

From 1ebf6e2762e844d192c0bf5ba8f9bdf6185f2e00 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 3 Mar 2025 11:04:24 +0100
Subject: [PATCH 050/107] update ruff version

---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 265ce6c5..c7f8ceaf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.9.3
+    rev: v0.9.9
    hooks:
      - id: ruff
        args: ["--extend-select", "I", "--fix"]

From c3d11afa3bc03fbc73fff67c63c25a221bd70843 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 3 Mar 2025 11:07:06 +0100
Subject: [PATCH 051/107] self-hosted is temporarily down

---
 .github/workflows/tests.yml | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9e9860c0..24733fad 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -61,13 +61,15 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        include:
-          - runner: "ubuntu-latest"
-            python-version: "3.10"
-          - runner: "self-hosted"
-            python-version: "3.12"
+        # include:
+        #   - runner: "ubuntu-latest"
+        #     python-version: "3.10"
+        #   - runner: "self-hosted"
+        #     python-version: "3.12"
+        python-version: ["3.10", "3.12"]
 
-    runs-on: ${{ matrix.runner }}
+    # runs-on: ${{ matrix.runner }}
+    runs-on: ubuntu-latest
 
     steps:
       - uses: actions/checkout@v4

From 1d32ab2ebec8ac2b2312bfd4ef4c6593177e2b68 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 3 Mar 2025 11:25:39 +0100
Subject: [PATCH 052/107] simplify methods flops & memory and uniformize

---
 CHANGELOG.md                               |  5 +++
 .../equivariant_tensor_product.py          |  2 +-
 .../cuequivariance/segmented_polynomial.py |  2 +-
 .../segmented_tensor_product.py            | 45 ++++++-------
 4 files changed, 20 insertions(+), 34 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3c2743cd..8b347e8d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,10 @@
 ## Latest Changes
 
+### Breaking Changes
+- Rename `SegmentedTensorProduct.flop_cost` to `flops`
+- Rename `SegmentedTensorProduct.memory_cost` to `memory`
+
+
 ## 0.3.0-rc1
 
 ### Breaking Changes
diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py
index da52f9a7..db753cac 100644
--- a/cuequivariance/cuequivariance/equivariant_tensor_product.py
+++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py
@@ -316,7 +316,7 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct:
 
     def flop_cost(self, batch_size: int) -> int:
         """Compute the number of flops of the tensor product."""
-        return sum(d.flop_cost(-1) for d in self.ds) * batch_size
+        return sum(d.flops(-1) for d in self.ds) * batch_size
 
     def memory_cost(
         self, batch_sizes: tuple[int, ...], itemsize: Union[int, tuple[int, ...]]
diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py
index 9b6a5f5b..73ba69c3 100644
--- a/cuequivariance/cuequivariance/segmented_polynomial.py
+++ b/cuequivariance/cuequivariance/segmented_polynomial.py
@@ -425,7 +425,7 @@ def flops(self, batch_size: int = 1) -> int:
         n = 0
         for ope, stp in self.tensor_products:
             oid, _ = ope.output_operand_buffer(self.num_inputs)
-            n += stp.flop_cost(oid)
+            n += stp.flops(oid)
         return batch_size * n
 
     def memory(self,
batch_sizes: list[int]) -> int: diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index b30572d2..c4027ede 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -346,8 +346,8 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str: + "]" ) - out += f"\nFlop cost: {' '.join(f'{oid}->{self.flop_cost(oid)}' for oid in range(self.num_operands))}" - out += f"\nMemory cost: {self.memory_cost('global')}" + out += f"\nFlop cost: {' '.join(f'{oid}->{self.flops(oid)}' for oid in range(self.num_operands))}" + out += f"\nMemory cost: {self.memory()}" if len(self.paths) > 0: out += "\nPath indices: " + ", ".join( @@ -399,16 +399,13 @@ def to_dict(self, extended: bool = False) -> dict[str, Any]: "size": ope.size, "segment_offsets": [sl.start for sl in slices], "segment_sizes": [sl.stop - sl.start for sl in slices], - "flop_cost": self.flop_cost(oid), + "flops": self.flops(oid), } for oid, ope, slices in zip( range(self.num_operands), self.operands, segment_slices ) ], - "memory_cost": { - algorithm: self.memory_cost(algorithm) - for algorithm in ["sequential", "global"] - }, + "memory": self.memory(), "paths": paths, } return extended_dict @@ -567,12 +564,15 @@ def coefficients_equal_one(self) -> bool: """Check if all coefficients are equal to one.""" return np.all(self.stacked_coefficients == 1) - def flop_cost(self, operand: int, algorithm: str = "optimal") -> int: + def flops( + self, operand: int, batch_size: int = 1, algorithm: str = "optimal" + ) -> int: """ Compute the number of flops needed to compute the specified operand. Args: operand (int): The operand for which to compute the flop cost. + batch_size (int, optional): The batch size for the computation. Defaults to 1. algorithm (str, optional): The algorithm to use to compute the cost. Can be 'optimal' or 'naive'. Returns: @@ -606,31 +606,12 @@ def compute_cost(segment_shapes: tuple[tuple[int, ...], ...]) -> int: dims = d.get_path_dimensions_dict(path) coeff_shape = tuple(dims[ch] for ch in d.coefficient_subscripts) cost += compute_cost((coeff_shape,) + shapes) - return cost - - def memory_cost(self, algorithm: str = "sequential") -> int: - """ - Compute the number of memory accesses needed to compute the specified operand. + return cost * batch_size - Args: - operand (int): The operand for which to compute the memory cost. - algorithm (str, optional): The algorithm to use to compute the cost. Can be 'sequential' or 'global'. - - Returns: - int: The number of memory accesses needed to compute the specified operand. 
- """ - if algorithm == "sequential": - return sum( - sum( - math.prod(self.get_segment_shape(oid, path)) - for oid in range(self.num_operands) - ) - for path in self.paths - ) - elif algorithm == "global": - return sum(operand.size for operand in self.operands) - else: - raise ValueError(f"unknown algorithm {algorithm}.") + def memory(self, batch_sizes: list[int]) -> int: + """Compute the memory usage of the tensor product.""" + assert len(batch_sizes) == self.num_operands + return sum(Z * size for Z, size in zip(batch_sizes, self.operands)) ################################ Modifiers ################################ From 70382fb48c40aacff8c17272dee0416a2671e9a9 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 14:27:48 +0100 Subject: [PATCH 053/107] add asserts --- cuequivariance/cuequivariance/segmented_polynomial.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 73ba69c3..73419e25 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -50,6 +50,11 @@ def __init__( num_outputs: int, tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], ): + for ope, stp in tensor_products: + assert isinstance(ope, cue.Operation) + assert isinstance(stp, cue.SegmentedTensorProduct) + assert len(ope.buffers) == stp.num_operands + object.__setattr__(self, "num_inputs", num_inputs) object.__setattr__(self, "num_outputs", num_outputs) object.__setattr__(self, "tensor_products", sorted(tensor_products)) From 2cde3a01e739144323922c183e8a5bfb6c7214eb Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 14:41:36 +0100 Subject: [PATCH 054/107] rename Operand in SegmentedOperand and move at root --- cuequivariance/cuequivariance/__init__.py | 2 ++ .../descriptors/transposition.py | 8 ++--- .../equivariant_tensor_product.py | 2 +- .../operand.py => segmented_operand.py} | 18 +++++----- .../segmented_tensor_product/__init__.py | 2 -- .../segmented_tensor_product/dot.py | 18 +++++----- .../segmented_tensor_product.py | 34 ++++++++++--------- .../segmented_polynomial_vanilla_impl.py | 2 +- .../primitives/symmetric_tensor_product.py | 25 ++++++++------ .../primitives/transpose.py | 2 +- 10 files changed, 57 insertions(+), 56 deletions(-) rename cuequivariance/cuequivariance/{segmented_tensor_product/operand.py => segmented_operand.py} (94%) diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index e683da28..b489c02b 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -48,6 +48,7 @@ ) from cuequivariance.operation import Operation +from cuequivariance.segmented_operand import SegmentedOperand from cuequivariance.segmented_tensor_product import SegmentedTensorProduct from cuequivariance.segmented_polynomial import SegmentedPolynomial from cuequivariance.equivariant_polynomial import EquivariantPolynomial @@ -80,6 +81,7 @@ "reduced_symmetric_tensor_product_basis", "reduced_antisymmetric_tensor_product_basis", "Operation", + "SegmentedOperand", "SegmentedTensorProduct", "SegmentedPolynomial", "EquivariantPolynomial", diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/descriptors/transposition.py index e7191b65..749fa0d8 100644 --- a/cuequivariance/cuequivariance/descriptors/transposition.py +++ 
b/cuequivariance/cuequivariance/descriptors/transposition.py @@ -21,12 +21,8 @@ def transpose( """Transpose the irreps layout of a tensor.""" d = cue.SegmentedTensorProduct( operands=[ - cue.segmented_tensor_product.Operand( - subscripts="ui" if source == cue.mul_ir else "iu" - ), - cue.segmented_tensor_product.Operand( - subscripts="ui" if target == cue.mul_ir else "iu" - ), + cue.SegmentedOperand(subscripts="ui" if source == cue.mul_ir else "iu"), + cue.SegmentedOperand(subscripts="ui" if target == cue.mul_ir else "iu"), ] ) for mul, ir in irreps: diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py index db753cac..eed6581f 100644 --- a/cuequivariance/cuequivariance/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py @@ -369,7 +369,7 @@ def backward(self, input: int) -> tuple[EquivariantTensorProduct, tuple[int, ... e = EquivariantTensorProduct(ds, tuple(self.operands[i] for i in oids)) return e, tuple(oids) - def stp_operand(self, oid: int) -> Optional[stp.Operand]: + def stp_operand(self, oid: int) -> Optional[cue.SegmentedOperand]: # output if oid == self.num_operands - 1: return self.ds[0].operands[-1] diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py b/cuequivariance/cuequivariance/segmented_operand.py similarity index 94% rename from cuequivariance/cuequivariance/segmented_tensor_product/operand.py rename to cuequivariance/cuequivariance/segmented_operand.py index 0c6627fe..ab3d7770 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/operand.py +++ b/cuequivariance/cuequivariance/segmented_operand.py @@ -20,11 +20,11 @@ from cuequivariance import segmented_tensor_product as stp -from .dimensions_dict import format_dimensions_dict +from .segmented_tensor_product.dimensions_dict import format_dimensions_dict @dataclasses.dataclass(init=False, frozen=True) -class Operand: +class SegmentedOperand: """A tensor product operand. 
It is a list of segments and subscripts.""" _segments: list[tuple[int, ...]] @@ -52,7 +52,7 @@ def __init__( object.__setattr__(self, "_dims", _dims) @classmethod - def empty_segments(cls, num_segments: int) -> Operand: + def empty_segments(cls, num_segments: int) -> SegmentedOperand: """Create an operand with empty subscripts""" return cls(subscripts="", segments=[()] * num_segments, _dims=dict()) @@ -102,10 +102,10 @@ def add_segment(self, segment: Union[tuple[int, ...], dict[str, int]]) -> int: def __hash__(self) -> int: return hash((tuple(self.segments), self.subscripts)) - def __eq__(self, other: Operand) -> bool: + def __eq__(self, other: SegmentedOperand) -> bool: return self.subscripts == other.subscripts and self.segments == other.segments - def __lt__(self, other: Operand) -> bool: + def __lt__(self, other: SegmentedOperand) -> bool: return (self.subscripts, self.segments) < (other.subscripts, other.segments) def __repr__(self) -> str: @@ -169,7 +169,7 @@ def get_segment_shape( def transpose_modes( self, subscripts: Union[str, Sequence[str], Sequence[int]] - ) -> Operand: + ) -> SegmentedOperand: """Transpose the channels of the operand.""" if not isinstance(subscripts, Sequence): raise TypeError("channels must be a sequence.") @@ -187,7 +187,7 @@ def transpose_modes( ) segments = [tuple(segment[i] for i in subscripts) for segment in self.segments] - return Operand( + return SegmentedOperand( subscripts="".join(self.subscripts[i] for i in subscripts), segments=segments, _dims=self._dims, @@ -211,10 +211,10 @@ def segment_size(self) -> int: raise ValueError("Segments do not have the same shape.") return math.prod(self.segments[0]) - def __add__(self, other: Operand) -> Operand: + def __add__(self, other: SegmentedOperand) -> SegmentedOperand: if self.subscripts != other.subscripts: raise ValueError("subscripts do not match.") - return Operand( + return SegmentedOperand( subscripts=self.subscripts, segments=self.segments + other.segments, _dims={m: self.get_dims(m) | other.get_dims(m) for m in self.subscripts}, diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py b/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py index 67519aeb..3ff358a5 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
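For downstream code, the rename in this patch mostly amounts to a new name and import location. A minimal sketch of the new spelling at this point in the series (segment shapes are arbitrary):

    import cuequivariance as cue

    # previously: from cuequivariance.segmented_tensor_product import Operand
    ope = cue.SegmentedOperand(subscripts="ui", segments=[(2, 3), (4, 3)])
    assert ope.num_segments == 2
    assert ope.size == 2 * 3 + 4 * 3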
from .subscripts import Subscripts -from .operand import Operand from .path import Path from .segmented_tensor_product import SegmentedTensorProduct from .dot import dot, trace @@ -24,7 +23,6 @@ __all__ = [ "Subscripts", - "Operand", "Path", "SegmentedTensorProduct", "dot", diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py index 3ed47c65..4165682c 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py @@ -19,7 +19,7 @@ import numpy as np -from cuequivariance import segmented_tensor_product as stp +import cuequivariance as cue def stable_unique(xs: Sequence[Any]) -> Sequence[Any]: @@ -33,10 +33,10 @@ def stable_unique(xs: Sequence[Any]) -> Sequence[Any]: def dot( - x: stp.SegmentedTensorProduct, - y: stp.SegmentedTensorProduct, + x: cue.SegmentedTensorProduct, + y: cue.SegmentedTensorProduct, *contraction: tuple[int, int], -) -> stp.SegmentedTensorProduct: +) -> cue.SegmentedTensorProduct: """ Compute the dot product of two segmented tensor products. @@ -59,7 +59,7 @@ def dot( oid for oid in range(y.num_operands) if all(oid != j for _, j in contraction) ] - d = stp.SegmentedTensorProduct( + d = cue.SegmentedTensorProduct( coefficient_subscripts=stable_unique( x.coefficient_subscripts + y.coefficient_subscripts ) @@ -103,8 +103,8 @@ def dot( def trace( - d: stp.SegmentedTensorProduct, *contraction: tuple[int, int] -) -> stp.SegmentedTensorProduct: + d: cue.SegmentedTensorProduct, *contraction: tuple[int, int] +) -> cue.SegmentedTensorProduct: """ Compute the trace of a segmented tensor product. @@ -134,10 +134,10 @@ def trace( coefficients_subscripts_renamed = f(d.coefficient_subscripts) coefficients_subscripts_compressed = stable_unique(coefficients_subscripts_renamed) - dout = stp.SegmentedTensorProduct( + dout = cue.SegmentedTensorProduct( coefficient_subscripts=coefficients_subscripts_compressed, operands=[ - stp.Operand( + cue.SegmentedOperand( subscripts=f(ope.subscripts), segments=ope.segments, ) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index c4027ede..f0bb175e 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -56,7 +56,7 @@ class SegmentedTensorProduct: .. rubric:: Methods """ - operands: tuple[stp.Operand, ...] + operands: tuple[cue.SegmentedOperand, ...] paths: tuple[stp.Path, ...] 
coefficient_subscripts: str @@ -66,7 +66,7 @@ class SegmentedTensorProduct: def __init__( self, *, - operands: Optional[list[stp.Operand]] = None, + operands: Optional[list[cue.SegmentedOperand]] = None, paths: Optional[list[stp.Path]] = None, coefficient_subscripts: str = "", ): @@ -81,12 +81,12 @@ def __init__( object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) - def set_operands(self, operands: list[stp.Operand]): + def set_operands(self, operands: list[cue.SegmentedOperand]): object.__setattr__( self, "operands", tuple(copy.deepcopy(ope) for ope in operands) ) - def set_operand(self, oid: int, operand: stp.Operand): + def set_operand(self, oid: int, operand: cue.SegmentedOperand): object.__setattr__( self, "operands", @@ -154,7 +154,9 @@ def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct: uv,ui,vj+ij operands=[(2, 3)],[(2, 5)],[(3, 4)] paths=[op0[0]*op1[0]*op2[0]*c c.shape=(5, 4) c.nnz=20] """ subscripts = stp.Subscripts(subscripts) - operands = [stp.Operand(subscripts=operand) for operand in subscripts.operands] + operands = [ + cue.SegmentedOperand(subscripts=operand) for operand in subscripts.operands + ] return cls( operands=operands, paths=[], coefficient_subscripts=subscripts.coefficients @@ -170,7 +172,7 @@ def empty_segments(cls, num_segments: list[int]) -> SegmentedTensorProduct: ,, sizes=2,3,4 num_segments=2,3,4 num_paths=0 """ return cls( - operands=[stp.Operand.empty_segments(num) for num in num_segments], + operands=[cue.SegmentedOperand.empty_segments(num) for num in num_segments], paths=[], coefficient_subscripts="", ) @@ -347,7 +349,7 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str: ) out += f"\nFlop cost: {' '.join(f'{oid}->{self.flops(oid)}' for oid in range(self.num_operands))}" - out += f"\nMemory cost: {self.memory()}" + out += f"\nMemory cost: {self.memory([1] * self.num_operands)}" if len(self.paths) > 0: out += "\nPath indices: " + ", ".join( @@ -405,7 +407,7 @@ def to_dict(self, extended: bool = False) -> dict[str, Any]: range(self.num_operands), self.operands, segment_slices ) ], - "memory": self.memory(), + "memory": self.memory([1] * self.num_operands), "paths": paths, } return extended_dict @@ -611,7 +613,7 @@ def compute_cost(segment_shapes: tuple[tuple[int, ...], ...]) -> int: def memory(self, batch_sizes: list[int]) -> int: """Compute the memory usage of the tensor product.""" assert len(batch_sizes) == self.num_operands - return sum(Z * size for Z, size in zip(batch_sizes, self.operands)) + return sum(Z * ope.size for Z, ope in zip(batch_sizes, self.operands)) ################################ Modifiers ################################ @@ -771,14 +773,14 @@ def insert_segments( self, operand: int, sid: int, - segments: Union[list[tuple[int, ...]], stp.Operand], + segments: Union[list[tuple[int, ...]], cue.SegmentedOperand], ): """Insert segments at a specific index.""" operand = _canonicalize_index("operand", operand, self.num_operands) sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) - if not isinstance(segments, stp.Operand): - segments = stp.Operand( + if not isinstance(segments, cue.SegmentedOperand): + segments = cue.SegmentedOperand( subscripts=self.operands[operand].subscripts, segments=segments ) @@ -790,7 +792,7 @@ def insert_segments( self.set_operand( operand, - stp.Operand( + cue.SegmentedOperand( subscripts=o.subscripts, segments=o.segments[:sid] + 
segments.segments + o.segments[sid:], _dims={m: o.get_dims(m) | segments.get_dims(m) for m in o.subscripts}, @@ -1034,7 +1036,7 @@ def permute_segments( """Permute the segments of an operand.""" operand = _canonicalize_index("operand", operand, self.num_operands) new_operands = list(self.operands) - new_operands[operand] = stp.Operand( + new_operands[operand] = cue.SegmentedOperand( segments=[self.operands[operand][i] for i in perm], subscripts=self.operands[operand].subscripts, ) @@ -1382,7 +1384,7 @@ def empty(D, oid, sid): ) for oid in range(D.num_operands): - operand = stp.Operand( + operand = cue.SegmentedOperand( subscripts=D.operands[oid].subscripts, segments=[ segment @@ -1461,7 +1463,7 @@ def flatten_modes( rm_shape_per_operand.append(rm_shapes) new_operands.append( - stp.Operand(segments=new_segments, subscripts=new_subscripts) + cue.SegmentedOperand(segments=new_segments, subscripts=new_subscripts) ) def ravel_multi_index(indices: tuple[int, ...], shape: tuple[int, ...]) -> int: diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py index d401e67f..e425ef8f 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py @@ -86,7 +86,7 @@ def flatten(x: jax.Array, axis: int) -> jax.Array: def sum_cat_list_list( - operand: cue.segmented_tensor_product.Operand, + operand: cue.SegmentedOperand, list_list: list[list[jax.Array]] | jax.Array, batch_shape: tuple[int, ...], dtype: jnp.dtype, diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 2f9503a7..bd257dd4 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -19,7 +19,7 @@ import torch import torch.fx -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue import cuequivariance_torch as cuet logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ class SymmetricTensorProduct(torch.nn.Module): def __init__( self, - descriptors: list[stp.SegmentedTensorProduct], + descriptors: list[cue.SegmentedTensorProduct], *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -47,10 +47,13 @@ def __init__( self.descriptors = descriptors descriptors = [ - stp.SegmentedTensorProduct( - operands=(stp.Operand.empty_segments(1),) + d.operands, + cue.SegmentedTensorProduct( + operands=(cue.SegmentedOperand.empty_segments(1),) + d.operands, paths=[ - stp.Path((0,) + path.indices, path.coefficients) for path in d.paths + cue.segmented_tensor_product.Path( + (0,) + path.indices, path.coefficients + ) + for path in d.paths ], coefficient_subscripts=d.coefficient_subscripts, ) @@ -99,7 +102,7 @@ class IWeightedSymmetricTensorProduct(torch.nn.Module): Parameters ---------- - descriptors : list[stp.SegmentedTensorProduct] + descriptors : list[cue.SegmentedTensorProduct] The list of SegmentedTensorProduct descriptors math_dtype : torch.dtype, optional The data type of the coefficients and calculations @@ -107,7 +110,7 @@ class IWeightedSymmetricTensorProduct(torch.nn.Module): def __init__( self, - descriptors: list[stp.SegmentedTensorProduct], + descriptors: 
list[cue.SegmentedTensorProduct], *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -197,7 +200,7 @@ def forward( return self.f(x0, i0, x1) -def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): +def _check_descriptors(descriptors: list[cue.SegmentedTensorProduct]): if len(descriptors) == 0: raise ValueError("stps must contain at least one STP.") @@ -218,7 +221,7 @@ def _check_descriptors(descriptors: list[stp.SegmentedTensorProduct]): class CUDAKernel(torch.nn.Module): def __init__( self, - ds: list[stp.SegmentedTensorProduct], + ds: list[cue.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -239,7 +242,7 @@ def __init__( if len({d.operands[-1].num_segments for d in ds}) != 1: raise ValueError("All STPs must have the same number of segments in x2.") - def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: + def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: d = d.move_operand(0, -2) d = d.flatten_coefficient_modes(force=True) d = d.flatten_modes( @@ -336,7 +339,7 @@ def forward( class FallbackImpl(torch.nn.Module): def __init__( self, - stps: list[stp.SegmentedTensorProduct], + stps: list[cue.SegmentedTensorProduct], device: Optional[torch.device], math_dtype: Optional[torch.dtype], ): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 3bc312d4..5c6b2039 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -145,7 +145,7 @@ def _transpose_segments_fx(segments: list[tuple[int, int]]) -> torch.nn.Module: x = torch.fx.Proxy(graph.placeholder("input"), tracer) outputs = [] - source = cue.segmented_tensor_product.Operand(subscripts="ij", segments=segments) + source = cue.SegmentedOperand(subscripts="ij", segments=segments) for sl, (u, v) in zip(source.segment_slices(), source.segments): outputs += [ x[..., sl] From cb23e6f2d91f5d3248f7860b1757a75f7e348ee0 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 15:04:05 +0100 Subject: [PATCH 055/107] SegmentedOperand compatible with dataclasses.replace --- .../cuequivariance/segmented_operand.py | 33 ++++++++++--------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_operand.py b/cuequivariance/cuequivariance/segmented_operand.py index ab3d7770..d5e772bb 100644 --- a/cuequivariance/cuequivariance/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_operand.py @@ -27,8 +27,8 @@ class SegmentedOperand: """A tensor product operand. It is a list of segments and subscripts.""" - _segments: list[tuple[int, ...]] subscripts: stp.Subscripts + segments: tuple[tuple[int, ...]] _dims: dict[str, set[int]] def __init__( @@ -42,7 +42,7 @@ def __init__( if segments is None: segments = [] - object.__setattr__(self, "_segments", segments) + object.__setattr__(self, "segments", tuple(segments)) if _dims is None: _dims = dict() @@ -88,8 +88,20 @@ def insert_segment( f"segment has {len(segment)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}." ) + if index < 0: + index = len(self.segments) + index + + if index < 0 or index > len(self.segments): + raise ValueError( + f"index {index} is out of bounds for segments {self.segments}." 
+ ) + segment = tuple(int(d) for d in segment) - self._segments.insert(index, segment) + object.__setattr__( + self, + "segments", + self.segments[:index] + (segment,) + self.segments[index:], + ) for m, d in zip(self.subscripts, segment): self._dims.setdefault(m, set()).add(d) @@ -100,12 +112,14 @@ def add_segment(self, segment: Union[tuple[int, ...], dict[str, int]]) -> int: return len(self.segments) - 1 def __hash__(self) -> int: - return hash((tuple(self.segments), self.subscripts)) + return hash((self.segments, self.subscripts)) def __eq__(self, other: SegmentedOperand) -> bool: + assert isinstance(other, SegmentedOperand) return self.subscripts == other.subscripts and self.segments == other.segments def __lt__(self, other: SegmentedOperand) -> bool: + assert isinstance(other, SegmentedOperand) return (self.subscripts, self.segments) < (other.subscripts, other.segments) def __repr__(self) -> str: @@ -121,11 +135,6 @@ def __len__(self) -> int: def __iter__(self): return iter(self.segments) - @property - def segments(self) -> tuple[tuple[int, ...], ...]: - """The segments of the operand.""" - return tuple(self._segments) - @property def num_segments(self) -> int: """The number of segments in the operand.""" @@ -161,12 +170,6 @@ def get_dims(self, m: str) -> set[int]: """Return the dimensions for a given channel.""" return self._dims.get(m, set()).copy() - def get_segment_shape( - self, dims: dict[str, int], *, default: int = -1 - ) -> tuple[int, ...]: - """Return the shape of a potential segment.""" - return tuple(dims.get(ch, default) for ch in self.subscripts) - def transpose_modes( self, subscripts: Union[str, Sequence[str], Sequence[int]] ) -> SegmentedOperand: From f742a02d3ea5f58ea6d76c5171e88f6c0b47965c Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 15:16:07 +0100 Subject: [PATCH 056/107] Refactor SegmentedPolynomial to use tuple instead of list for tensor_products --- .../cuequivariance/segmented_polynomial.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 73419e25..8af2b4d5 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -42,7 +42,7 @@ class SegmentedPolynomial: num_inputs: int num_outputs: int - tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] + tensor_products: tuple[tuple[cue.Operation, cue.SegmentedTensorProduct], ...] 
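With the field stored as a tuple, the frozen dataclass hashes and compares structurally end to end. A sketch of the intended property (the constructor still accepts any sequence and normalizes it, as the `__init__` just below shows):

    import cuequivariance as cue

    stp = cue.SegmentedTensorProduct.empty_segments([2, 2, 1])
    stp.add_path(0, 1, 0, c=1.0)
    op = cue.Operation((0, 0, 1))

    p1 = cue.SegmentedPolynomial(1, 1, [(op, stp)])   # list input
    p2 = cue.SegmentedPolynomial(1, 1, ((op, stp),))  # tuple input
    assert p1 == p2 and hash(p1) == hash(p2)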
def __init__( self, @@ -57,18 +57,18 @@ def __init__( object.__setattr__(self, "num_inputs", num_inputs) object.__setattr__(self, "num_outputs", num_outputs) - object.__setattr__(self, "tensor_products", sorted(tensor_products)) + object.__setattr__(self, "tensor_products", tuple(sorted(tensor_products))) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): return cls( stp.num_operands - 1, 1, - [(cue.Operation(tuple(range(stp.num_operands))), stp)], + ((cue.Operation(tuple(range(stp.num_operands))), stp),), ) def __hash__(self) -> int: - return hash((self.num_inputs, self.num_outputs, tuple(self.tensor_products))) + return hash((self.num_inputs, self.num_outputs, self.tensor_products)) def __eq__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) @@ -94,7 +94,7 @@ def __mul__(self, factor: float) -> SegmentedPolynomial: return SegmentedPolynomial( self.num_inputs, self.num_outputs, - [(ope, factor * stp) for ope, stp in self.tensor_products], + tuple((ope, factor * stp) for ope, stp in self.tensor_products), ) def __rmul__(self, factor: float) -> SegmentedPolynomial: @@ -206,9 +206,9 @@ def map_tensor_products( ], ) -> SegmentedPolynomial: new_tensor_products = [f(ope, stp) for ope, stp in self.tensor_products] - new_tensor_products = [ + new_tensor_products = tuple( ope_stp for ope_stp in new_tensor_products if ope_stp is not None - ] + ) return SegmentedPolynomial( self.num_inputs, self.num_outputs, new_tensor_products ) @@ -219,7 +219,7 @@ def fuse_stps(self) -> SegmentedPolynomial: self.tensor_products, key=lambda x: (x[0], x[1].operands, x[1].coefficient_subscripts), ) - new_tensor_products = [ + new_tensor_products = tuple( ( ope, cue.SegmentedTensorProduct( @@ -229,7 +229,7 @@ def fuse_stps(self) -> SegmentedPolynomial: ).consolidate_paths(), ) for (ope, operands, coefficient_subscripts), elements in groups - ] + ) return SegmentedPolynomial( self.num_inputs, self.num_outputs, new_tensor_products ) From 6aec8d16eeaa898146db7d342c8059eeed8a7239 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 17:17:43 +0100 Subject: [PATCH 057/107] Big! 
remove the field subscripts: stp.Subscripts from SegmentedOperand --- .../descriptors/transposition.py | 6 +- .../equivariant_tensor_product.py | 4 +- .../cuequivariance/segmented_operand.py | 114 +++------- .../cuequivariance/segmented_polynomial.py | 20 +- .../segmented_tensor_product/dot.py | 19 +- .../segmented_tensor_product.py | 205 ++++++++---------- .../segmented_polynomial_vanilla_impl.py | 5 +- 7 files changed, 166 insertions(+), 207 deletions(-) diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/descriptors/transposition.py index 749fa0d8..a7ae0b0e 100644 --- a/cuequivariance/cuequivariance/descriptors/transposition.py +++ b/cuequivariance/cuequivariance/descriptors/transposition.py @@ -20,9 +20,9 @@ def transpose( ) -> cue.EquivariantPolynomial: """Transpose the irreps layout of a tensor.""" d = cue.SegmentedTensorProduct( - operands=[ - cue.SegmentedOperand(subscripts="ui" if source == cue.mul_ir else "iu"), - cue.SegmentedOperand(subscripts="ui" if target == cue.mul_ir else "iu"), + operands_and_subscripts=[ + (cue.SegmentedOperand(ndim=2), "ui" if source == cue.mul_ir else "iu"), + (cue.SegmentedOperand(ndim=2), "ui" if target == cue.mul_ir else "iu"), ] ) for mul, ir in irreps: diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/equivariant_tensor_product.py index eed6581f..2bac0331 100644 --- a/cuequivariance/cuequivariance/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/equivariant_tensor_product.py @@ -413,9 +413,9 @@ def stack( for oid in range(d.num_operands): if stacked[ii[oid]]: for e_ in reversed(es[:eid]): - d.insert_segments(oid, 0, e_.stp_operand(ii[oid])) + d.insert_segments(oid, 0, e_.stp_operand(ii[oid]).segments) for e_ in es[eid + 1 :]: - d.insert_segments(oid, -1, e_.stp_operand(ii[oid])) + d.insert_segments(oid, -1, e_.stp_operand(ii[oid]).segments) if d.num_operands not in new_ds: new_ds[d.num_operands] = d diff --git a/cuequivariance/cuequivariance/segmented_operand.py b/cuequivariance/cuequivariance/segmented_operand.py index d5e772bb..4d3b1341 100644 --- a/cuequivariance/cuequivariance/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_operand.py @@ -16,29 +16,26 @@ import dataclasses import math -from typing import Optional, Sequence, Union - -from cuequivariance import segmented_tensor_product as stp from .segmented_tensor_product.dimensions_dict import format_dimensions_dict @dataclasses.dataclass(init=False, frozen=True) class SegmentedOperand: - """A tensor product operand. 
It is a list of segments and subscripts."""
+    """A segmented operand is a list of segment shapes."""
 
-    subscripts: stp.Subscripts
+    ndim: int
     segments: tuple[tuple[int, ...]]
-    _dims: dict[str, set[int]]
+    _dims: dict[int, set[int]]
 
     def __init__(
         self,
         *,
-        subscripts: stp.Subscripts,
-        segments: Optional[list[tuple[int, ...]]] = None,
-        _dims: Optional[dict[str, set[int]]] = None,
+        ndim: int,
+        segments: list[tuple[int, ...]] | None = None,
+        _dims: dict[int, set[int]] | None = None,
     ):
-        object.__setattr__(self, "subscripts", stp.Subscripts(subscripts))
+        object.__setattr__(self, "ndim", ndim)
 
         if segments is None:
             segments = []
@@ -47,45 +44,37 @@ def __init__(
         if _dims is None:
             _dims = dict()
             for segment in self.segments:
-                for m, d in zip(self.subscripts, segment):
-                    _dims.setdefault(m, set()).add(d)
+                for i, d in enumerate(segment):
+                    _dims.setdefault(i, set()).add(d)
 
         object.__setattr__(self, "_dims", _dims)
 
     @classmethod
     def empty_segments(cls, num_segments: int) -> SegmentedOperand:
-        """Create an operand with empty subscripts"""
-        return cls(subscripts="", segments=[()] * num_segments, _dims=dict())
+        """Create an operand with ndim=0"""
+        return cls(ndim=0, segments=[()] * num_segments, _dims=dict())
 
     def assert_valid(self):
         """Assert that the operand is valid."""
-        if not all(m.isalpha() and m.islower() for m in self.subscripts):
-            raise ValueError(f"subscripts {self.subscripts} is not valid.")
-
         for segment in self.segments:
-            if not all(isinstance(dim, int) and dim > 0 for dim in segment):
-                raise ValueError(f"segment {segment} is not valid.")
-
-            if len(segment) != len(self.subscripts):
+            if len(segment) != self.ndim:
                 raise ValueError(
-                    f"segment {segment} has {len(segment)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}."
+                    f"segment {segment} has {len(segment)} dimensions, expected {self.ndim}."
) if index < 0: @@ -103,28 +92,28 @@ def insert_segment( self.segments[:index] + (segment,) + self.segments[index:], ) - for m, d in zip(self.subscripts, segment): - self._dims.setdefault(m, set()).add(d) + for i, d in enumerate(segment): + self._dims.setdefault(i, set()).add(d) - def add_segment(self, segment: Union[tuple[int, ...], dict[str, int]]) -> int: + def add_segment(self, segment: tuple[int, ...]) -> int: """Add a segment to the operand.""" self.insert_segment(len(self.segments), segment) return len(self.segments) - 1 def __hash__(self) -> int: - return hash((self.segments, self.subscripts)) + return hash((self.ndim, self.segments)) def __eq__(self, other: SegmentedOperand) -> bool: assert isinstance(other, SegmentedOperand) - return self.subscripts == other.subscripts and self.segments == other.segments + return self.ndim == other.ndim and self.segments == other.segments def __lt__(self, other: SegmentedOperand) -> bool: assert isinstance(other, SegmentedOperand) - return (self.subscripts, self.segments) < (other.subscripts, other.segments) + return (self.ndim, self.segments) < (other.ndim, other.segments) def __repr__(self) -> str: dims = format_dimensions_dict(self.get_dimensions_dict()) - return f"Operand(subscripts={self.subscripts} num_segments={self.num_segments} {dims})" + return f"Operand(ndim={self.ndim} num_segments={self.num_segments} {dims})" def __getitem__(self, index: int) -> tuple[int, ...]: return self.segments[index] @@ -148,11 +137,6 @@ def size(self) -> int: return sum(math.prod(segment) for segment in self.segments) - @property - def ndim(self) -> int: - """The number of segment dimensions.""" - return len(self.subscripts) - def segment_slices(self) -> list[slice]: """Return slice object for each segment.""" offset = 0 @@ -162,39 +146,13 @@ def segment_slices(self) -> list[slice]: offset += math.prod(segment) return slices - def get_dimensions_dict(self) -> dict[str, set[int]]: + def get_dimensions_dict(self) -> dict[int, set[int]]: """Return a dictionary of dimensions for each channel.""" return self._dims.copy() - def get_dims(self, m: str) -> set[int]: + def get_dims(self, i: int) -> set[int]: """Return the dimensions for a given channel.""" - return self._dims.get(m, set()).copy() - - def transpose_modes( - self, subscripts: Union[str, Sequence[str], Sequence[int]] - ) -> SegmentedOperand: - """Transpose the channels of the operand.""" - if not isinstance(subscripts, Sequence): - raise TypeError("channels must be a sequence.") - - if isinstance(subscripts[0], str): - subscripts = "".join(subscripts) - subscripts = stp.Subscripts.complete_wildcards(subscripts, self.subscripts) - subscripts = [self.subscripts.index(m) for m in subscripts] - - subscripts: list[int] = list(subscripts) - - if len(subscripts) != len(self.subscripts): - raise ValueError( - f"channels has {len(subscripts)} dimensions, expected {len(self.subscripts)} for subscripts {self.subscripts}." - ) - - segments = [tuple(segment[i] for i in subscripts) for segment in self.segments] - return SegmentedOperand( - subscripts="".join(self.subscripts[i] for i in subscripts), - segments=segments, - _dims=self._dims, - ) + return self._dims.get(i, set()).copy() def all_same_segment_shape(self) -> bool: """Check if all segments have the same shape. 
Returns False if there are no segments.""" @@ -215,10 +173,10 @@ def segment_size(self) -> int: return math.prod(self.segments[0]) def __add__(self, other: SegmentedOperand) -> SegmentedOperand: - if self.subscripts != other.subscripts: - raise ValueError("subscripts do not match.") + if self.ndim != other.ndim: + raise ValueError("ndim do not match.") return SegmentedOperand( - subscripts=self.subscripts, + ndim=self.ndim, segments=self.segments + other.segments, - _dims={m: self.get_dims(m) | other.get_dims(m) for m in self.subscripts}, + _dims={i: self.get_dims(i) | other.get_dims(i) for i in range(self.ndim)}, ) diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 8af2b4d5..0fe250aa 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -215,20 +215,32 @@ def map_tensor_products( def fuse_stps(self) -> SegmentedPolynomial: """Fuse segmented tensor products with identical operations and operands.""" + poly = self.map_tensor_products( + lambda ope, stp: (ope, stp.canonicalize_subscripts()) + ) + groups = itertools.groupby( - self.tensor_products, - key=lambda x: (x[0], x[1].operands, x[1].coefficient_subscripts), + poly.tensor_products, + key=lambda x: ( + x[0], + x[1].operands_and_subscripts, + x[1].coefficient_subscripts, + ), ) new_tensor_products = tuple( ( ope, cue.SegmentedTensorProduct( - operands=operands, + operands_and_subscripts=operands_and_subscripts, coefficient_subscripts=coefficient_subscripts, paths=[path for _, stp in elements for path in stp.paths], ).consolidate_paths(), ) - for (ope, operands, coefficient_subscripts), elements in groups + for ( + ope, + operands_and_subscripts, + coefficient_subscripts, + ), elements in groups ) return SegmentedPolynomial( self.num_inputs, self.num_outputs, new_tensor_products diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py index 4165682c..0d43669b 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/dot.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import annotations +import dataclasses import itertools from typing import Any, Sequence @@ -49,7 +50,7 @@ def dot( The segmented tensor product resulting from the dot product. 
""" for oidx, oidy in contraction: - if x.operands[oidx] != y.operands[oidy]: + if x.operands_and_subscripts[oidx] != y.operands_and_subscripts[oidy]: raise ValueError("Operands to contract must be the same.") x_keep = [ @@ -64,7 +65,11 @@ def dot( x.coefficient_subscripts + y.coefficient_subscripts ) ) - d.set_operands([x.operands[i] for i in x_keep] + [y.operands[i] for i in y_keep]) + d = dataclasses.replace( + d, + operands_and_subscripts=[x.operands_and_subscripts[i] for i in x_keep] + + [y.operands_and_subscripts[i] for i in y_keep], + ) formula = f"{x.coefficient_subscripts} , {y.coefficient_subscripts} -> {d.coefficient_subscripts}" @@ -127,7 +132,7 @@ def trace( mapping = { chj: chi for i, j in contraction - for chi, chj in zip(d.operands[i].subscripts, d.operands[j].subscripts) + for chi, chj in zip(d.subscripts.operands[i], d.subscripts.operands[j]) } f = lambda subscripts: "".join(mapping.get(ch, ch) for ch in subscripts) # noqa @@ -136,12 +141,8 @@ def trace( dout = cue.SegmentedTensorProduct( coefficient_subscripts=coefficients_subscripts_compressed, - operands=[ - cue.SegmentedOperand( - subscripts=f(ope.subscripts), - segments=ope.segments, - ) - for ope in (d.operands[i] for i in keep) + operands_and_subscripts=[ + (ope, f(ss)) for ope, ss in (d.operands_and_subscripts[i] for i in keep) ], ) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index f0bb175e..758ea429 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -56,7 +56,7 @@ class SegmentedTensorProduct: .. rubric:: Methods """ - operands: tuple[cue.SegmentedOperand, ...] + operands_and_subscripts: tuple[tuple[cue.SegmentedOperand, stp.Subscripts], ...] paths: tuple[stp.Path, ...] coefficient_subscripts: str @@ -66,31 +66,36 @@ class SegmentedTensorProduct: def __init__( self, *, - operands: Optional[list[cue.SegmentedOperand]] = None, + operands_and_subscripts: Optional[ + list[tuple[cue.SegmentedOperand, stp.Subscripts]] + ] = None, paths: Optional[list[stp.Path]] = None, coefficient_subscripts: str = "", ): - if operands is None: - operands = [] + if operands_and_subscripts is None: + operands_and_subscripts = [] if paths is None: paths = [] object.__setattr__( - self, "operands", tuple(copy.deepcopy(ope) for ope in operands) + self, + "operands_and_subscripts", + tuple( + (copy.deepcopy(ope), stp.Subscripts(ss)) + for ope, ss in operands_and_subscripts + ), ) object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) - def set_operands(self, operands: list[cue.SegmentedOperand]): - object.__setattr__( - self, "operands", tuple(copy.deepcopy(ope) for ope in operands) - ) - def set_operand(self, oid: int, operand: cue.SegmentedOperand): + assert oid < len(self.operands_and_subscripts) object.__setattr__( self, - "operands", - self.operands[:oid] + (copy.deepcopy(operand),) + self.operands[oid + 1 :], + "operands_and_subscripts", + self.operands_and_subscripts[:oid] + + ((copy.deepcopy(operand), self.operands_and_subscripts[oid][1]),) + + self.operands_and_subscripts[oid + 1 :], ) def set_paths(self, paths: list[stp.Path]): @@ -105,6 +110,10 @@ def insert_path_(self, path_index: int, path: stp.Path): # until here. 
Below we use dataclasses.replace or the setters to modify the attributes + @property + def operands(self) -> tuple[cue.SegmentedOperand, ...]: + return tuple(ope for ope, _ in self.operands_and_subscripts) + def assert_valid(self): assert stp.Subscripts.is_valid(self.subscripts) @@ -155,11 +164,13 @@ def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct: """ subscripts = stp.Subscripts(subscripts) operands = [ - cue.SegmentedOperand(subscripts=operand) for operand in subscripts.operands + cue.SegmentedOperand(ndim=len(operand)) for operand in subscripts.operands ] return cls( - operands=operands, paths=[], coefficient_subscripts=subscripts.coefficients + operands_and_subscripts=list(zip(operands, subscripts.operands)), + paths=[], + coefficient_subscripts=subscripts.coefficients, ) @classmethod @@ -172,7 +183,9 @@ def empty_segments(cls, num_segments: list[int]) -> SegmentedTensorProduct: ,, sizes=2,3,4 num_segments=2,3,4 num_paths=0 """ return cls( - operands=[cue.SegmentedOperand.empty_segments(num) for num in num_segments], + operands_and_subscripts=[ + (cue.SegmentedOperand.empty_segments(num), "") for num in num_segments + ], paths=[], coefficient_subscripts="", ) @@ -253,7 +266,7 @@ def __repr__(self) -> str: @property def num_operands(self) -> int: """Number of operands.""" - return len(self.operands) + return len(self.operands_and_subscripts) @property def num_paths(self) -> int: @@ -264,7 +277,7 @@ def num_paths(self) -> int: def subscripts(self) -> stp.Subscripts: """Subscripts of the tensor product.""" return stp.Subscripts.from_operands( - [operand.subscripts for operand in self.operands], + [subscripts for _, subscripts in self.operands_and_subscripts], self.coefficient_subscripts, ) @@ -336,9 +349,9 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str: """ out = f"{self}" dims = self.get_dimensions_dict() - for oid, operand in enumerate(self.operands): - out += f"\noperand #{oid} subscripts={operand.subscripts}" - for i, ch in enumerate(operand.subscripts): + for oid, (operand, subscripts) in enumerate(self.operands_and_subscripts): + out += f"\noperand #{oid} subscripts={subscripts}" + for i, ch in enumerate(subscripts): if len(dims[ch]) == 1: out += f"\n | {ch}: [{operand.segments[0][i]}] * {len(operand.segments)}" else: @@ -396,15 +409,17 @@ def to_dict(self, extended: bool = False) -> dict[str, Any]: "coefficient_subscripts": self.coefficient_subscripts, "operands": [ { - "subscripts": ope.subscripts, + "subscripts": ss, "segments": ope.segments, "size": ope.size, "segment_offsets": [sl.start for sl in slices], "segment_sizes": [sl.stop - sl.start for sl in slices], "flops": self.flops(oid), } - for oid, ope, slices in zip( - range(self.num_operands), self.operands, segment_slices + for oid, (ope, ss), slices in zip( + range(self.num_operands), + self.operands_and_subscripts, + segment_slices, ) ], "memory": self.memory([1] * self.num_operands), @@ -439,9 +454,9 @@ def to_base64(self, extended: bool = False) -> str: def get_dimensions_dict(self) -> dict[str, set[int]]: """Get the dimensions of the tensor product.""" dims: dict[str, set[int]] = {ch: set() for ch in self.subscripts.modes()} - for operand in self.operands: - for m, dd in operand.get_dimensions_dict().items(): - dims[m].update(dd) + for operand, subscripts in self.operands_and_subscripts: + for i, dd in operand.get_dimensions_dict().items(): + dims[subscripts[i]].update(dd) # Note: no need to go through the coefficients since must be contracted with the operands return 
dims @@ -475,7 +490,7 @@ def get_path_dimensions_dict( m: {d} for m, d in zip(self.coefficient_subscripts, path.coefficients.shape) } for oid, sid in enumerate(path.indices): - for m, d in zip(self.operands[oid].subscripts, self.operands[oid][sid]): + for m, d in zip(self.subscripts.operands[oid], self.operands[oid][sid]): dims.setdefault(m, set()).add(d) if returns_sets: @@ -583,9 +598,9 @@ def flops( d = self.move_operand_last(operand) subscripts = ( d.coefficient_subscripts - + "".join("," + operand.subscripts for operand in d.operands[:-1]) + + "".join("," + ss for ss in d.subscripts.operands[:-1]) + "->" - + d.operands[-1].subscripts + + d.subscripts.operands[-1] ) @functools.lru_cache(maxsize=None) @@ -651,7 +666,7 @@ def insert_path( dims.setdefault(m, set()).add(d) for oid, s in enumerate(segments): - subscripts = self.operands[oid].subscripts + subscripts = self.subscripts.operands[oid] if isinstance(s, int): if not ( -self.operands[oid].num_segments @@ -745,64 +760,26 @@ def add_path( """ return self.insert_path(len(self.paths), *segments, c=c, dims=dims) - def insert_segment( - self, - operand: int, - sid: int, - segment: Union[tuple[int, ...], dict[str, int]], - ): - """Insert a segment at a specific index.""" - operand = _canonicalize_index("operand", operand, self.num_operands) - sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) - - self.operands[operand].insert_segment(sid, segment) - self.set_paths( - [ - stp.Path( - [ - s if s < sid or oid != operand else s + 1 - for oid, s in enumerate(path.indices) - ], - path.coefficients, - ) - for path in self.paths - ] - ) - - def insert_segments( - self, - operand: int, - sid: int, - segments: Union[list[tuple[int, ...]], cue.SegmentedOperand], - ): + def insert_segments(self, operand: int, sid: int, segments: list[tuple[int, ...]]): """Insert segments at a specific index.""" operand = _canonicalize_index("operand", operand, self.num_operands) sid = _canonicalize_index("sid", sid, self.operands[operand].num_segments + 1) - if not isinstance(segments, cue.SegmentedOperand): - segments = cue.SegmentedOperand( - subscripts=self.operands[operand].subscripts, segments=segments - ) - o = self.operands[operand] - if o.subscripts != segments.subscripts: - raise ValueError( - f"expected segments subscripts {segments.subscripts} to match operand subscripts {o.subscripts}." 
- ) - + n = cue.SegmentedOperand(ndim=o.ndim, segments=segments) self.set_operand( operand, cue.SegmentedOperand( - subscripts=o.subscripts, - segments=o.segments[:sid] + segments.segments + o.segments[sid:], - _dims={m: o.get_dims(m) | segments.get_dims(m) for m in o.subscripts}, + ndim=o.ndim, + segments=o.segments[:sid] + n.segments + o.segments[sid:], + _dims={m: o.get_dims(m) | n.get_dims(m) for m in range(o.ndim)}, ), ) self.set_paths( [ stp.Path( [ - s if s < sid or oid != operand else s + segments.num_segments + s if s < sid or oid != operand else s + n.num_segments for oid, s in enumerate(path.indices) ], path.coefficients, @@ -815,6 +792,8 @@ def add_segment( self, operand: int, segment: Union[tuple[int, ...], dict[str, int]] ) -> int: """Add a segment to the descriptor.""" + if isinstance(segment, dict): + segment = tuple(segment[m] for m in self.subscripts.operands[operand]) return self.operands[operand].add_segment(segment) def add_segments( @@ -953,8 +932,8 @@ def add_or_transpose_modes( d = SegmentedTensorProduct.from_subscripts(subscripts) for oid in range(self.num_operands): old, new = ( - self.operands[oid].subscripts, - d.operands[oid].subscripts, + self.subscripts.operands[oid], + d.subscripts.operands[oid], ) perm = [old.index(ch) if ch in old else ch for ch in new] for sid in range(self.operands[oid].num_segments): @@ -1008,7 +987,7 @@ def permute_operands(self, perm: tuple[int, ...]) -> SegmentedTensorProduct: assert set(perm) == set(range(self.num_operands)) return dataclasses.replace( self, - operands=[self.operands[i] for i in perm], + operands_and_subscripts=[self.operands_and_subscripts[i] for i in perm], paths=[path.permute_operands(perm) for path in self.paths], ) @@ -1037,8 +1016,8 @@ def permute_segments( operand = _canonicalize_index("operand", operand, self.num_operands) new_operands = list(self.operands) new_operands[operand] = cue.SegmentedOperand( + ndim=self.operands[operand].ndim, segments=[self.operands[operand][i] for i in perm], - subscripts=self.operands[operand].subscripts, ) new_paths = [ stp.Path( @@ -1050,7 +1029,11 @@ def permute_segments( ) for path in self.paths ] - return dataclasses.replace(self, operands=new_operands, paths=new_paths) + return dataclasses.replace( + self, + operands_and_subscripts=list(zip(new_operands, self.subscripts.operands)), + paths=new_paths, + ) def sort_paths( self, operands_ordering: Optional[Union[int, Sequence[int]]] = None @@ -1112,7 +1095,7 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: d.add_segments( oid, [ - filter_shape(segment, operand.subscripts) + filter_shape(segment, self.subscripts.operands[oid]) for segment in operand.segments ], ) @@ -1240,7 +1223,7 @@ def normalize_paths_for_operand(self, operand: int) -> SegmentedTensorProduct: # ^^^^^^^-- (0, "uvw") (1, "iu") (2, "jv") (3, "kw") (4, "ijk") if i != operand # e.g. discard (3, "kw") for u in subscripts # u v w i u j v i j k - if u not in self.operands[operand].subscripts # e.g. discard k and w + if u not in self.subscripts.operands[operand] # e.g. discard k and w } # e.g. 
{u, v, i, j} for path in self.paths: @@ -1385,15 +1368,15 @@ def empty(D, oid, sid): for oid in range(D.num_operands): operand = cue.SegmentedOperand( - subscripts=D.operands[oid].subscripts, + ndim=D.operands[oid].ndim, segments=[ segment for sid, segment in enumerate(D.operands[oid].segments) if not empty(D, oid, sid) ], _dims={ - m: {d for d in dd if d > 0} - for m, dd in D.operands[oid]._dims.items() + i: {d for d in dd if d > 0} + for i, dd in D.operands[oid]._dims.items() }, ) D.set_operand(oid, operand) @@ -1420,8 +1403,7 @@ def flatten_modes( if force: extra = "" for m in modes: - for ope in self.operands: - sm = ope.subscripts + for _, sm in self.operands_and_subscripts: if m in sm: extra += sm[: sm.index(m)] modes += extra @@ -1431,16 +1413,16 @@ def flatten_modes( return self pattern = re.compile(rf"^([{modes}]*)([^{modes}]*)$") - new_operands = [] + new_operands_and_subscripts = [] offsets_per_operand = [] rm_shape_per_operand = [] rm_modes_per_operand = [] - for operand in self.operands: - ma = pattern.match(operand.subscripts) + for operand, subscripts in self.operands_and_subscripts: + ma = pattern.match(subscripts) if ma is None: raise ValueError( f"expected modes {modes} to be at the beginning of the segment subscripts." - f" Got {operand.subscripts}." + f" Got {subscripts}." ) rm_modes, new_subscripts = ma.groups() @@ -1462,8 +1444,13 @@ def flatten_modes( offsets_per_operand.append(offsets) rm_shape_per_operand.append(rm_shapes) - new_operands.append( - cue.SegmentedOperand(segments=new_segments, subscripts=new_subscripts) + new_operands_and_subscripts.append( + ( + cue.SegmentedOperand( + ndim=len(new_subscripts), segments=new_segments + ), + new_subscripts, + ) ) def ravel_multi_index(indices: tuple[int, ...], shape: tuple[int, ...]) -> int: @@ -1540,7 +1527,7 @@ def make_new_path( ) d = dataclasses.replace( self, - operands=new_operands, + operands_and_subscripts=new_operands_and_subscripts, paths=new_paths, coefficient_subscripts="".join( ch for ch in self.coefficient_subscripts if ch not in modes @@ -1574,11 +1561,11 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu # look for opportunities to consolidate for m in self.subscripts.modes(): neighbors: set[str] = set() - for operand in self.operands: - if m in operand.subscripts: - i = operand.subscripts.index(m) - if i < len(operand.subscripts) - 1: - neighbors.add(operand.subscripts[i + 1]) + for _, subscripts in self.operands_and_subscripts: + if m in subscripts: + i = subscripts.index(m) + if i < len(subscripts) - 1: + neighbors.add(subscripts[i + 1]) else: neighbors.add(".") if len(neighbors) != 1: @@ -1589,13 +1576,13 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu # Zuvw_Ziu_Zjv_Zkw+ijk ok = True - for operand in self.operands: - if m in operand.subscripts: - if n not in operand.subscripts: + for _, subscripts in self.operands_and_subscripts: + if m in subscripts: + if n not in subscripts: ok = False break else: - if n in operand.subscripts: + if n in subscripts: ok = False break @@ -1613,11 +1600,11 @@ def consolidate_modes(self, modes: Optional[str] = None) -> SegmentedTensorProdu return self - for operand in self.operands: - if modes not in operand.subscripts: - if any(ch in operand.subscripts for ch in modes): + for _, subscripts in self.operands_and_subscripts: + if modes not in subscripts: + if any(ch in subscripts for ch in modes): raise ValueError( - f"expected {modes} to be contiguous in the subscripts {operand.subscripts}." 
+ f"expected {modes} to be contiguous in the subscripts {subscripts}." ) d = self @@ -1642,10 +1629,10 @@ def _consolidate_pair_of_modes(self, m: str, n: str) -> SegmentedTensorProduct: d1 = SegmentedTensorProduct.from_subscripts(d0.subscripts.replace(m + n, m)) - for oid, operand in enumerate(d0.operands): + for oid, (operand, subscripts) in enumerate(d0.operands_and_subscripts): for segment in operand: - if m in operand.subscripts: - i = operand.subscripts.index(m) + if m in subscripts: + i = subscripts.index(m) segment = list(segment) segment[i] *= segment[i + 1] segment.pop(i + 1) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py index e425ef8f..4aadbca4 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py @@ -164,6 +164,7 @@ def tp_list_list( d = d.sort_paths(-1) pids = d.compressed_path_segment(-1) ope_out = d.operands[-1] + ss_out = d.subscripts.operands[-1] def ein( coefficients: jax.Array, segments: list[jax.Array], mode: str = "normal" @@ -181,14 +182,14 @@ def ein( term_out = ( "".join(m for m, s in zip(batch_modes, out_batch_shape) if s > 1) + path_out - + ope_out.subscripts + + ss_out ) terms = [path_in + d.coefficient_subscripts] + terms_in + [term_out] formula = ",".join(terms[:-1]) + "->" + terms[-1] segments = [x.astype(coefficients.dtype) for x in segments] segment = jnp.einsum(formula, coefficients, *segments, precision=precision) - segment_shape = segment.shape[segment.ndim - len(ope_out.subscripts) :] + segment_shape = segment.shape[segment.ndim - len(ss_out) :] if mode == "vectorized": num_paths = coefficients.shape[0] From 668b16edd3d1abdf46567495c37754e6cb3448bc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 3 Mar 2025 17:25:04 +0100 Subject: [PATCH 058/107] fix --- .../cuequivariance_torch/primitives/tensor_product.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index daf668fb..781f1107 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -223,9 +223,7 @@ def _tensor_product_fx( for i in range(num_inputs) ] - operand_subscripts = [ - f"Z{operand.subscripts}" for operand in descriptor.operands - ] + operand_subscripts = [f"Z{ss}" for ss in descriptor.subscripts.operands] formula = ( ",".join([descriptor.coefficient_subscripts] + operand_subscripts[:-1]) From 3da13f36319924632b41b6970d6bc71eafc58411 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 00:31:00 +0100 Subject: [PATCH 059/107] SegmentedPolynomial has known operands from the start --- .../descriptors/spherical_harmonics_.py | 6 +- .../descriptors/symmetric_contractions.py | 11 +- .../cuequivariance/equivariant_polynomial.py | 11 +- .../mace/symmetric_contractions.py | 10 +- .../cuequivariance/segmented_operand.py | 20 ++- .../cuequivariance/segmented_polynomial.py | 160 ++++++++++-------- .../tests/segmented_polynomial_test.py | 107 ++++++++---- .../primitives/segmented_polynomial.py | 6 +- .../primitives/segmented_polynomial_test.py | 12 +- 9 files changed, 224 insertions(+), 119 deletions(-) diff --git 
a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py index 4754d44f..82cf9c13 100644 --- a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py @@ -65,7 +65,11 @@ def spherical_harmonics( cue.IrrepsAndLayout(cue.Irreps(ir_vec), cue.ir_mul), cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul), ], - cue.SegmentedPolynomial(1, 1, [(cue.Operation([0] * ell + [1]), d)]), + cue.SegmentedPolynomial( + [cue.SegmentedOperand(ndim=0, segments=[()] * 3)], + [cue.SegmentedOperand(ndim=0, segments=[()] * ir.dim)], + [(cue.Operation([0] * ell + [1]), d)], + ), ) diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py index da1d0a87..64384e32 100644 --- a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py @@ -70,6 +70,8 @@ def symmetric_contraction( input_operands = range(1, degree + 1) output_operand = degree + 1 + input_operand = cue.SegmentedOperand(ndim=1, segments=[(mul,)] * irreps_in.dim) + if degree == 0: d = stp.SegmentedTensorProduct.from_subscripts("i_i") for _, ir in irreps_out: @@ -103,11 +105,18 @@ def symmetric_contraction( d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) + for i in input_operands: + assert d.operands[i] == input_operand + return cue.EquivariantPolynomial( [ cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], - cue.SegmentedPolynomial(2, 1, [(cue.Operation([0] + [1] * degree + [2]), d)]), + cue.SegmentedPolynomial( + [d.operands[0], input_operand], + [d.operands[-1]], + [(cue.Operation([0] + [1] * degree + [2]), d)], + ), ) diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/equivariant_polynomial.py index 39ca89e8..f2310ac1 100644 --- a/cuequivariance/cuequivariance/equivariant_polynomial.py +++ b/cuequivariance/cuequivariance/equivariant_polynomial.py @@ -42,16 +42,15 @@ def __init__(self, operands: list[cue.Rep], polynomial: cue.SegmentedPolynomial) assert isinstance(polynomial, cue.SegmentedPolynomial) object.__setattr__(self, "operands", tuple(operands)) object.__setattr__(self, "polynomial", polynomial) - if ( - len(self.operands) - != self.polynomial.num_inputs + self.polynomial.num_outputs - ): + if len(self.operands) != self.polynomial.num_operands: raise ValueError( f"Number of operands {len(self.operands)} must equal the number of inputs" f" {self.polynomial.num_inputs} plus the number of outputs {self.polynomial.num_outputs}" ) - for rep, size in zip(self.operands, self.polynomial.buffer_sizes): - assert size is None or size == rep.dim + for rep, ope in zip(self.operands, self.polynomial.operands): + assert ope.size == rep.dim, ( + f"{ope} incompatible with {rep}. 
{ope.size=} != {rep.dim=}" + ) def __hash__(self) -> int: return hash((self.operands, self.polynomial)) diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py index c01597f6..b4959842 100644 --- a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py @@ -146,13 +146,19 @@ def _symmetric_contraction( d = d.flatten_coefficient_modes() d = d.append_modes_to_all_operands("u", {"u": mul}) + + assert d.num_operands >= 3 + [w, x], y = d.operands[:2], d.operands[-1] + return cue.EquivariantPolynomial( [ - cue.IrrepsAndLayout(irreps_in.new_scalars(d.operands[0].size), cue.ir_mul), + cue.IrrepsAndLayout(irreps_in.new_scalars(w.size), cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_in, cue.ir_mul), cue.IrrepsAndLayout(mul * irreps_out, cue.ir_mul), ], - cue.SegmentedPolynomial(2, 1, [(cue.Operation([0] + [1] * degree + [2]), d)]), + cue.SegmentedPolynomial( + [w, x], [y], [(cue.Operation([0] + [1] * degree + [2]), d)] + ), ) diff --git a/cuequivariance/cuequivariance/segmented_operand.py b/cuequivariance/cuequivariance/segmented_operand.py index 4d3b1341..2cd1dd77 100644 --- a/cuequivariance/cuequivariance/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_operand.py @@ -53,6 +53,24 @@ def empty_segments(cls, num_segments: int) -> SegmentedOperand: """Create an operand with ndim=0""" return cls(ndim=0, segments=[()] * num_segments, _dims=dict()) + @classmethod + def stack(cls, operands: list[SegmentedOperand]) -> SegmentedOperand: + """Stack a list of operands together.""" + assert len(operands) > 0 + ndim = operands[0].ndim + assert all(ope.ndim == ndim for ope in operands) + + _dims = dict() + for ope in operands: + for i, d in ope.get_dimensions_dict().items(): + _dims.setdefault(i, set()).update(d) + + return cls( + ndim=ndim, + segments=sum([list(ope.segments) for ope in operands], []), + _dims=_dims, + ) + def assert_valid(self): """Assert that the operand is valid.""" for segment in self.segments: @@ -113,7 +131,7 @@ def __lt__(self, other: SegmentedOperand) -> bool: def __repr__(self) -> str: dims = format_dimensions_dict(self.get_dimensions_dict()) - return f"Operand(ndim={self.ndim} num_segments={self.num_segments} {dims})" + return f"Operand(ndim={self.ndim} num_segments={self.num_segments} dims={dims})" def __getitem__(self, index: int) -> tuple[int, ...]: return self.segments[index] diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index 0fe250aa..f0a75f0d 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -34,66 +34,87 @@ class SegmentedPolynomial: input tensors to output tensors through these tensor products. Args: - num_inputs (int): Number of input tensors. - num_outputs (int): Number of output tensors. + inputs (tuple of SegmentedOperand): Input operands. + outputs (tuple of SegmentedOperand): Output operands. tensor_products (list of tuple of Operation and SegmentedTensorProduct): List of operation and tensor product pairs that define the polynomial transformation. """ - num_inputs: int - num_outputs: int + inputs: tuple[cue.SegmentedOperand, ...] + outputs: tuple[cue.SegmentedOperand, ...] tensor_products: tuple[tuple[cue.Operation, cue.SegmentedTensorProduct], ...] 
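
    # A minimal construction sketch (segment sizes are illustrative and
    # ``np`` is assumed to be NumPy): the operands passed as ``inputs`` and
    # ``outputs`` must match, buffer by buffer, the operands referenced by
    # each (Operation, SegmentedTensorProduct) pair:
    #
    #     stp = cue.SegmentedTensorProduct.from_subscripts("a,b+ab")
    #     stp.add_segment(0, (3,))
    #     stp.add_segment(1, (4,))
    #     stp.add_path(0, 0, c=np.ones((3, 4)))
    #     poly = cue.SegmentedPolynomial(
    #         inputs=[stp.operands[0]],
    #         outputs=[stp.operands[1]],
    #         tensor_products=[(cue.Operation((0, 1)), stp)],
    #     )
    #     # equivalent to cue.SegmentedPolynomial.eval_last_operand(stp)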
def __init__( self, - num_inputs: int, - num_outputs: int, + inputs: tuple[cue.SegmentedOperand, ...], + outputs: tuple[cue.SegmentedOperand, ...], tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], ): + buffers = list(inputs) + list(outputs) + for ope, stp in tensor_products: assert isinstance(ope, cue.Operation) assert isinstance(stp, cue.SegmentedTensorProduct) assert len(ope.buffers) == stp.num_operands + for buffer_id, operand in zip(ope.buffers, stp.operands): + assert operand == buffers[buffer_id] - object.__setattr__(self, "num_inputs", num_inputs) - object.__setattr__(self, "num_outputs", num_outputs) + object.__setattr__(self, "inputs", tuple(inputs)) + object.__setattr__(self, "outputs", tuple(outputs)) object.__setattr__(self, "tensor_products", tuple(sorted(tensor_products))) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): return cls( - stp.num_operands - 1, - 1, + stp.operands[:-1], + (stp.operands[-1],), ((cue.Operation(tuple(range(stp.num_operands))), stp),), ) + @classmethod + def from_default_buffers( + cls, + inputs: tuple[cue.SegmentedOperand, ...], + outputs: tuple[cue.SegmentedOperand, ...], + tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], + ): + buffers = list(inputs) + list(outputs) + for ope, stp in tensor_products: + assert isinstance(ope, cue.Operation) + assert isinstance(stp, cue.SegmentedTensorProduct) + assert len(ope.buffers) == stp.num_operands + for buffer_id, operand in zip(ope.buffers, stp.operands): + buffers[buffer_id] = operand + + return cls(buffers[: len(inputs)], buffers[len(inputs) :], tensor_products) + def __hash__(self) -> int: - return hash((self.num_inputs, self.num_outputs, self.tensor_products)) + return hash((self.inputs, self.outputs, self.tensor_products)) def __eq__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) return ( - self.num_inputs == value.num_inputs - and self.num_outputs == value.num_outputs + self.inputs == value.inputs + and self.outputs == value.outputs and self.tensor_products == value.tensor_products ) def __lt__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) return ( - self.num_inputs, - self.num_outputs, + self.inputs, + self.outputs, self.tensor_products, ) < ( - value.num_inputs, - value.num_outputs, + value.inputs, + value.outputs, value.tensor_products, ) def __mul__(self, factor: float) -> SegmentedPolynomial: return SegmentedPolynomial( - self.num_inputs, - self.num_outputs, + self.inputs, + self.outputs, tuple((ope, factor * stp) for ope, stp in self.tensor_products), ) @@ -101,7 +122,7 @@ def __rmul__(self, factor: float) -> SegmentedPolynomial: return self.__mul__(factor) def __repr__(self): - return self.to_string() + return self.to_string([f"[{ope.size}]" for ope in self.operands]) def to_string(self, buffer_names: list[str] | None = None) -> str: buffer_txts = ( @@ -155,8 +176,8 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: inferred_shape = np.broadcast_shapes(*[x.shape[:-1] for x in inputs]) inferred_dtype = np.result_type(*[x.dtype for x in inputs]) outputs = [ - np.zeros(inferred_shape + (size,), dtype=inferred_dtype) - for size in self.output_sizes + np.zeros(inferred_shape + (ope.size,), dtype=inferred_dtype) + for ope in self.outputs ] for ope, stp in self.tensor_products: oid, bid = ope.output_operand_buffer(self.num_inputs) @@ -170,33 +191,21 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: return outputs @property - def num_operands(self) -> 
int: - """Number of operands in the polynomial.""" - return self.num_inputs + self.num_outputs + def operands(self) -> tuple[cue.SegmentedOperand, ...]: + return self.inputs + self.outputs @property - def buffer_sizes(self) -> list[int | None]: - """Sizes of the buffers in the polynomial.""" - sizes = [None] * (self.num_inputs + self.num_outputs) - for ope, stp in self.tensor_products: - for buffer, operand in zip(ope.buffers, stp.operands): - if sizes[buffer] is None: - sizes[buffer] = operand.size - if sizes[buffer] != operand.size: - raise ValueError( - f"Buffer {buffer} has inconsistent sizes: {sizes[buffer]} vs {operand.size}" - ) - return sizes + def num_inputs(self) -> int: + return len(self.inputs) @property - def input_sizes(self) -> list[int | None]: - """Sizes of the input buffers in the polynomial.""" - return self.buffer_sizes[: self.num_inputs] + def num_outputs(self) -> int: + return len(self.outputs) @property - def output_sizes(self) -> list[int | None]: - """Sizes of the output buffers in the polynomial.""" - return self.buffer_sizes[self.num_inputs :] + def num_operands(self) -> int: + """Number of operands in the polynomial.""" + return self.num_inputs + self.num_outputs def map_tensor_products( self, @@ -209,8 +218,8 @@ def map_tensor_products( new_tensor_products = tuple( ope_stp for ope_stp in new_tensor_products if ope_stp is not None ) - return SegmentedPolynomial( - self.num_inputs, self.num_outputs, new_tensor_products + return SegmentedPolynomial.from_default_buffers( + self.inputs, self.outputs, new_tensor_products ) def fuse_stps(self) -> SegmentedPolynomial: @@ -242,9 +251,7 @@ def fuse_stps(self) -> SegmentedPolynomial: coefficient_subscripts, ), elements in groups ) - return SegmentedPolynomial( - self.num_inputs, self.num_outputs, new_tensor_products - ) + return SegmentedPolynomial(self.inputs, self.outputs, new_tensor_products) def consolidate(self) -> SegmentedPolynomial: """Consolidate the segmented tensor products.""" @@ -312,13 +319,9 @@ def select_buffers(self, keep: list[bool]) -> SegmentedPolynomial: new_ope = cue.Operation([new_index[buffer] for buffer in ope.buffers]) new_tensor_products.append((new_ope, stp)) - # Calculate new num_inputs and num_outputs - new_num_inputs = sum(keep[: self.num_inputs]) - new_num_outputs = sum(keep[self.num_inputs :]) - return SegmentedPolynomial( - new_num_inputs, - new_num_outputs, + [x for x, k in zip(self.inputs, keep[: self.num_inputs]) if k], + [x for x, k in zip(self.outputs, keep[self.num_inputs :]) if k], new_tensor_products, ) @@ -335,8 +338,8 @@ def compute_only(self, keep: list[bool]) -> SegmentedPolynomial: """Compute only the selected outputs of the polynomial.""" assert len(keep) == self.num_outputs return SegmentedPolynomial( - self.num_inputs, - self.num_outputs, # on purpose, we keep all outputs + self.inputs, + self.outputs, # on purpose, we keep all outputs [ (ope, stp) for ope, stp in self.tensor_products @@ -356,6 +359,17 @@ def stack( assert all(pol.num_outputs == num_outputs for pol in polys) assert len(stacked) == num_inputs + num_outputs + operands = [] + for bid in range(num_inputs + num_outputs): + if stacked[bid]: + operands.append( + cue.SegmentedOperand.stack([pol.operands[bid] for pol in polys]) + ) + else: + ope = polys[0].operands[bid] + assert all(pol.operands[bid] == ope for pol in polys) + operands.append(ope) + tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] for index, pol in enumerate(polys): for ope, stp in pol.tensor_products: @@ -367,21 +381,24 
@@ def stack( for p in polys[index + 1 :]: stp.insert_segments(oid, -1, p.buffer_segments(buffer)) tensor_products.append((ope, stp)) - return cls(num_inputs, num_outputs, tensor_products).consolidate() + + return cls( + operands[:num_inputs], operands[num_inputs:], tensor_products + ).consolidate() def squeeze_modes(self) -> SegmentedPolynomial: """Squeeze the modes of the segmented tensor products.""" - return SegmentedPolynomial( - self.num_inputs, - self.num_outputs, + return SegmentedPolynomial.from_default_buffers( + self.inputs, + self.outputs, [(ope, stp.squeeze_modes()) for ope, stp in self.tensor_products], ) def flatten_coefficient_modes(self) -> SegmentedPolynomial: """Flatten the coefficient modes of the segmented tensor products.""" - return SegmentedPolynomial( - self.num_inputs, - self.num_outputs, + return SegmentedPolynomial.from_default_buffers( + self.inputs, + self.outputs, [ (ope, stp.flatten_coefficient_modes()) for ope, stp in self.tensor_products @@ -404,7 +421,9 @@ def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: ): new_tps.append((ope, multiplicator * stp)) return SegmentedPolynomial( - self.num_inputs + sum(has_tangent), self.num_outputs, new_tps + list(self.inputs) + [x for has, x in zip(has_tangent, self.inputs) if has], + self.outputs, + new_tps, ) def transpose( @@ -422,8 +441,12 @@ def transpose( if ope is not None: new_tps.append((ope, stp)) return SegmentedPolynomial( - sum(map(lambda u: not u, is_undefined_primal)) + sum(has_cotangent), - sum(is_undefined_primal), + # defined inputs + [x for undef, x in zip(is_undefined_primal, self.inputs) if not undef] + # cotangent outputs + + [x for has, x in zip(has_cotangent, self.outputs) if has], + # undefined inputs + [x for undef, x in zip(is_undefined_primal, self.inputs) if undef], new_tps, ) @@ -448,7 +471,7 @@ def flops(self, batch_size: int = 1) -> int: def memory(self, batch_sizes: list[int]) -> int: """Compute the memory usage of the polynomial.""" assert len(batch_sizes) == self.num_operands - return sum(Z * size for Z, size in zip(batch_sizes, self.buffer_sizes)) + return sum(Z * ope.size for Z, ope in zip(batch_sizes, self.operands)) def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: segments = None @@ -477,8 +500,9 @@ def symmetrize_for_identical_operands(self) -> SegmentedPolynomial: stp = stp.symmetrize_operands(set_of_operands) stp = stp.sort_paths() symmetrized_tensor_products.append((ope, stp)) + return SegmentedPolynomial( - self.num_inputs, self.num_outputs, symmetrized_tensor_products + self.inputs, self.outputs, symmetrized_tensor_products ) def unsymmetrize_for_identical_operands(self) -> SegmentedPolynomial: diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomial_test.py index 0d7f2cd7..01f200d8 100644 --- a/cuequivariance/tests/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomial_test.py @@ -36,26 +36,23 @@ def make_simple_dot_product_stp() -> cue.SegmentedTensorProduct: def test_init_segmented_polynomial(): """Test initialization of SegmentedPolynomial.""" stp = make_simple_stp() - op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial.eval_last_operand(stp) assert poly.num_inputs == 2 assert poly.num_outputs == 1 assert poly.num_operands == 3 assert len(poly.tensor_products) == 1 - assert poly.tensor_products[0] == (op, stp) + assert poly.tensor_products[0] == (cue.Operation((0, 1, 2)), stp) def test_polynomial_equality(): 
"""Test equality comparison of polynomials.""" stp1 = make_simple_stp() stp2 = make_simple_stp() - op1 = cue.Operation((0, 1, 2)) - op2 = cue.Operation((0, 1, 2)) - poly1 = cue.SegmentedPolynomial(2, 1, [(op1, stp1)]) - poly2 = cue.SegmentedPolynomial(2, 1, [(op2, stp2)]) - poly3 = cue.SegmentedPolynomial(2, 1, [(op2, 2 * stp2)]) + poly1 = cue.SegmentedPolynomial.eval_last_operand(stp1) + poly2 = cue.SegmentedPolynomial.eval_last_operand(stp2) + poly3 = cue.SegmentedPolynomial.eval_last_operand(2 * stp2) assert poly1 == poly2 assert poly1 != poly3 @@ -72,8 +69,7 @@ def test_call_function(): i2 = stp.add_segment(2, (1,)) stp.add_path(i0, i1, i2, c=np.eye(3).reshape(3, 3, 1)) - op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial.eval_last_operand(stp) # Test evaluation a = np.array([1.0, 2.0, 3.0]) @@ -95,12 +91,20 @@ def test_buffer_properties(): stp2.add_path(0, 0, c=1.0) op2 = cue.Operation((0, 3)) - poly = cue.SegmentedPolynomial(2, 2, [(op1, stp1), (op2, stp2)]) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(1), + ], + [(op1, stp1), (op2, stp2)], + ) # Test buffer properties - assert poly.buffer_sizes == [2, 2, 2, 1] - assert poly.input_sizes == [2, 2] - assert poly.output_sizes == [2, 1] + assert [ope.size for ope in poly.operands] == [2, 2, 2, 1] assert poly.used_buffers() == [True, True, True, True] @@ -111,7 +115,15 @@ def test_remove_unused_buffers(): # Use operation that doesn't use buffer 1 op = cue.Operation((0, 2, 3)) # Note: buffer 1 is not used - poly = cue.SegmentedPolynomial(3, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), # unused + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) # Buffer 1 is not used assert poly.used_buffers() == [True, False, True, True] @@ -132,7 +144,14 @@ def test_consolidate(): op = cue.Operation((0, 1, 2)) # Create a polynomial with duplicate operations - poly = cue.SegmentedPolynomial(2, 1, [(op, stp1), (op, stp2)]) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp1), (op, stp2)], + ) # Consolidate the polynomial consolidated = poly.consolidate() @@ -151,11 +170,25 @@ def test_stack(): # Create two simple polynomials using make_simple_stp stp = make_simple_stp() op1 = cue.Operation((0, 1, 2)) - poly1 = cue.SegmentedPolynomial(2, 1, [(op1, stp)]) + poly1 = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op1, stp)], + ) stp2 = make_simple_stp() op2 = cue.Operation((0, 1, 2)) - poly2 = cue.SegmentedPolynomial(2, 1, [(op2, stp2)]) + poly2 = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op2, stp2)], + ) # Stack the polynomials with the output being stacked stacked = cue.SegmentedPolynomial.stack([poly1, poly2], [False, False, True]) @@ -163,7 +196,7 @@ def test_stack(): assert stacked.num_inputs == 2 assert stacked.num_outputs == 1 - assert stacked.buffer_sizes == [2, 2, 4] + assert [ope.size for ope in 
stacked.operands] == [2, 2, 4] [(_, stp)] = stacked.tensor_products assert stp.operands[0].num_segments == 2 @@ -180,8 +213,14 @@ def test_flops_and_memory(): """Test computation of FLOPS and memory usage.""" stp = make_simple_stp() op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) - + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) # Test FLOPS calculation flops = poly.flops(batch_size=100) assert flops > 0 @@ -195,9 +234,7 @@ def test_jvp(): """Test Jacobian-vector product computation.""" # Create a simple polynomial for testing: f(x,y) = x^T * y (dot product) stp = make_simple_dot_product_stp() - - op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial.eval_last_operand(stp) # Input values x = np.array([1.0, 2.0, 3.0]) @@ -240,9 +277,7 @@ def test_transpose_linear(): # Here we use f(x, y) = x^T * y (dot product) # This is linear in both x and y stp = make_simple_dot_product_stp() - - op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial.eval_last_operand(stp) # Input values x = np.array([1.0, 2.0, 3.0]) @@ -284,7 +319,13 @@ def test_transpose_nonlinear(): # Create a non-linear polynomial stp = make_simple_stp() op = cue.Operation((0, 0, 1)) # Note: using the same buffer twice (x^2) - poly = cue.SegmentedPolynomial(1, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial( + [ + cue.SegmentedOperand.empty_segments(2), + ], + [cue.SegmentedOperand.empty_segments(2)], + [(op, stp)], + ) # Try to transpose the non-linear polynomial # This should raise a ValueError since there are multiple undefined primals @@ -297,8 +338,7 @@ def test_backward(): """Test the backward method for gradient computation.""" # Create a linear polynomial for testing: f(x,y) = x^T * y (dot product) stp = make_simple_dot_product_stp() - op = cue.Operation((0, 1, 2)) - poly = cue.SegmentedPolynomial(2, 1, [(op, stp)]) + poly = cue.SegmentedPolynomial.eval_last_operand(stp) # Input values x = np.array([1.0, 2.0, 3.0]) @@ -351,8 +391,11 @@ def test_symmetrize_identical_operands(): # Create operation that uses the same input buffer twice op = cue.Operation((0, 0, 1)) # Use buffer 0 twice, write to buffer 1 - poly = cue.SegmentedPolynomial(1, 1, [(op, stp)]) - + poly = cue.SegmentedPolynomial( + [cue.SegmentedOperand.empty_segments(2)], + [cue.SegmentedOperand.empty_segments(1)], + [(op, stp)], + ) # Symmetrize the polynomial sym_poly = poly.symmetrize_for_identical_operands() diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py index 4745d44c..dc4ba3a1 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py @@ -95,10 +95,8 @@ def segmented_polynomial( assert len(outputs_shape_dtype) == polynomial.num_outputs outputs_shape_dtype = [ - jax.ShapeDtypeStruct( - x.shape if size is None else x.shape[:-1] + (size,), x.dtype - ) - for x, size in zip(outputs_shape_dtype, polynomial.output_sizes) + jax.ShapeDtypeStruct(x.shape[:-1] + (ope.size,), x.dtype) + for x, ope in zip(outputs_shape_dtype, polynomial.outputs) ] buffers = list(inputs) + list(outputs_shape_dtype) diff --git 
a/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py b/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py index a19a2def..df3571bb 100644 --- a/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py +++ b/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py @@ -25,7 +25,9 @@ def test_one_operand(): d = cue.SegmentedTensorProduct.empty_segments([1]) [out] = cuex.segmented_polynomial( - cue.SegmentedPolynomial(0, 1, [(cue.Operation([0]), d)]), + cue.SegmentedPolynomial( + [], [cue.SegmentedOperand.empty_segments(1)], [(cue.Operation([0]), d)] + ), [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)], ) @@ -33,7 +35,9 @@ def test_one_operand(): d.add_path(0, c=123) [out] = cuex.segmented_polynomial( - cue.SegmentedPolynomial(0, 1, [(cue.Operation([0]), d)]), + cue.SegmentedPolynomial( + [], [cue.SegmentedOperand.empty_segments(1)], [(cue.Operation([0]), d)] + ), [], [jax.ShapeDtypeStruct((2, 1), jnp.float32)], ) @@ -96,8 +100,8 @@ def test_vmap(): def f(x1, x2, i1): return cuex.segmented_polynomial( cue.SegmentedPolynomial( - 2, - 2, + d.operands[:2], + [d.operands[2], d.operands[2]], [ (cue.Operation([0, 1, 2]), d), (cue.Operation([0, 1, 3]), d), From a39053eb8e9e0dd1f36a0ef3dec8b29001979645 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 00:46:06 +0100 Subject: [PATCH 060/107] fix --- .../segmented_tensor_product.py | 16 +-- .../descriptor_test.py | 97 +++++++++++++++++++ .../primitives/symmetric_tensor_product.py | 4 +- .../primitives/transpose.py | 2 +- 4 files changed, 109 insertions(+), 10 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 758ea429..090843a3 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -1132,10 +1132,10 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: if mode not in self.subscripts: return self - for oid, operand in enumerate(self.operands): - if mode in operand.subscripts and not operand.subscripts.startswith(mode): + for oid, ss in enumerate(self.subscripts.operands): + if mode in ss and not ss.startswith(mode): raise ValueError( - f"mode {mode} is not the first mode in operand {oid} ({operand.subscripts})." + f"mode {mode} is not the first mode in operand {oid} ({ss})." 
) if not all(dim % size == 0 for dim in self.get_dims(mode)): @@ -1159,14 +1159,14 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: d = SegmentedTensorProduct.from_subscripts(self.subscripts) offsets_per_operand = [] - for oid, operand in enumerate(self.operands): - if mode not in operand.subscripts: + for oid, (operand, ss) in enumerate(self.operands_and_subscripts): + if mode not in ss: for segment in operand: d.add_segment(oid, segment) offsets_per_operand.append(None) continue - assert operand.subscripts.startswith(mode) + assert ss.startswith(mode) offsets = [] for segment in operand: @@ -1190,7 +1190,9 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: coefficients = path.coefficients if self.coefficient_subscripts.startswith(mode): coefficients = np.split(coefficients, num_subdivisions, axis=0)[i] - d.paths.append(stp.Path(indices=indices, coefficients=coefficients)) + d.insert_path_( + len(d.paths), stp.Path(indices=indices, coefficients=coefficients) + ) logger.debug(f"Split {mode} in {self}: got {d}") return d diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py index bfb6773e..5cae1a64 100644 --- a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py +++ b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py @@ -261,3 +261,100 @@ def test_hash(): assert hash(d) != hash(d2) d2.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) assert hash(d) == hash(d2) + + +def test_split_mode(): + # Create a descriptor with a mode that has dimensions divisible by the desired split size + d = stp.SegmentedTensorProduct.from_subscripts("ua,ub+ab") + + # Add a segment with u dimension = 6 (divisible by 2 and 3) + d.add_segment(0, (6, 4)) + d.add_segment(1, (6, 5)) + + # Add a path + d.add_path(0, 0, c=np.ones((4, 5))) + d.assert_valid() + + # Split mode 'u' with size 2 + d_split = d.split_mode("u", 2) + d_split.assert_valid() + + # Check that the dimensions are correctly split + assert d_split.operands[0].num_segments == 3 # 6/2 = 3 segments + assert d_split.operands[1].num_segments == 3 # 6/2 = 3 segments + + # Check that the subscripts are preserved + assert d_split.subscripts == "ua,ub+ab" + + # Check that the segments have the correct shape + for segment in d_split.operands[0]: + assert segment[0] == 2 # First dimension should be 2 + assert segment[1] == 4 # Second dimension should be 4 + + for segment in d_split.operands[1]: + assert segment[0] == 2 # First dimension should be 2 + assert segment[1] == 5 # Second dimension should be 5 + + # Test with a different split size + d_split_3 = d.split_mode("u", 3) + d_split_3.assert_valid() + + assert d_split_3.operands[0].num_segments == 2 # 6/3 = 2 segments + assert d_split_3.operands[1].num_segments == 2 # 6/3 = 2 segments + + # Test error case: split size not divisible by dimension + with pytest.raises(ValueError): + d.split_mode("u", 5) # 6 is not divisible by 5 + + # Test case where mode is not in descriptor + d_unchanged = d.split_mode("v", 2) # 'v' is not in the descriptor + assert d_unchanged == d + + # Test case where mode is not at the beginning of the operand + d_complex = stp.SegmentedTensorProduct.from_subscripts("au,bu+ab") + d_complex.add_segment(0, (3, 6)) + d_complex.add_segment(1, (4, 6)) + d_complex.add_path(0, 0, c=np.ones((3, 4))) + + with pytest.raises(ValueError): + d_complex.split_mode("u", 2) # 'u' is not the first mode in operands + + # Test with 
coefficient subscripts + d_coeff = stp.SegmentedTensorProduct.from_subscripts("ua,ub,ab+ab") + d_coeff.add_segment(0, (6, 4)) + d_coeff.add_segment(1, (6, 5)) + d_coeff.add_segment(2, (4, 5)) + d_coeff.add_path(0, 0, 0, c=np.ones((4, 5))) + + d_coeff_split = d_coeff.split_mode("u", 2) + d_coeff_split.assert_valid() + + assert d_coeff_split.operands[0].num_segments == 3 + assert d_coeff_split.operands[1].num_segments == 3 + assert d_coeff_split.operands[2].num_segments == 1 # Not affected by u split + + # Check that computation results are equivalent + # Create a simple descriptor with just two operands for testing compute_last_operand + d_compute = stp.SegmentedTensorProduct.from_subscripts("a,b+ab") + d_compute.add_segment(0, (4,)) + d_compute.add_segment(1, (5,)) + d_compute.add_path(0, 0, c=np.ones((4, 5))) + + # Test computation on original descriptor + x_input = np.random.randn(d_compute.operands[0].size) + result_original = stp.compute_last_operand(d_compute, x_input) + + # Verify split_mode works by first flattening the results to remove 'u' mode indices + d_ua = stp.SegmentedTensorProduct.from_subscripts("ua,b+ab") + d_ua.add_segment(0, (6, 4)) + d_ua.add_segment(1, (5,)) + d_ua.add_path(0, 0, c=np.ones((4, 5))) + d_ua_split = d_ua.split_mode("u", 2) + + # Input for the split descriptor - we need a tensor with the right shape + x_input_split = np.random.randn(d_ua_split.operands[0].size) + result_split = stp.compute_last_operand(d_ua_split, x_input_split) + + # Verify the shapes are consistent with our expectations + assert result_original.shape == (5,) + assert result_split.shape == (5,) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index bd257dd4..0e0afb2c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -249,7 +249,7 @@ def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: [ m for m in d.subscripts.modes() - if not all(m in ope.subscripts for ope in d.operands) + if not all(m in ss for ss in d.subscripts.operands) ] ) d = d.consolidate_modes() @@ -264,7 +264,7 @@ def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: m = d.subscripts.modes()[0] - if not all(ope.subscripts == m for ope in d.operands): + if not all(ss == m for ss in d.subscripts.operands): raise NotImplementedError("Different subscripts are not supported.") d = d.split_mode(m, math.gcd(*d.get_dims(m))) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py index 5c6b2039..632e2092 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/transpose.py @@ -145,7 +145,7 @@ def _transpose_segments_fx(segments: list[tuple[int, int]]) -> torch.nn.Module: x = torch.fx.Proxy(graph.placeholder("input"), tracer) outputs = [] - source = cue.SegmentedOperand(subscripts="ij", segments=segments) + source = cue.SegmentedOperand(ndim=2, segments=segments) for sl, (u, v) in zip(source.segment_slices(), source.segments): outputs += [ x[..., sl] From 92f501ce2801a63eab70c0af136ab0c6d68412a2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 00:53:19 +0100 Subject: [PATCH 061/107] fix --- .../primitives/symmetric_tensor_product.py | 3 ++- 1 file changed, 2 
insertions(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index 0e0afb2c..d4fcaf31 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -48,7 +48,8 @@ def __init__( descriptors = [ cue.SegmentedTensorProduct( - operands=(cue.SegmentedOperand.empty_segments(1),) + d.operands, + operands_and_subscripts=[(cue.SegmentedOperand.empty_segments(1), "")] + + list(d.operands_and_subscripts), paths=[ cue.segmented_tensor_product.Path( (0,) + path.indices, path.coefficients From 35387bb5b70bc25b6836c61f945f4a89858f4aea Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 00:55:57 +0100 Subject: [PATCH 062/107] fix? --- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 781f1107..0cc9985c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -580,7 +580,7 @@ def __init__( import cuequivariance_ops_torch as ops self._f = ops.FusedTensorProductOp4( - operand_segment_modes=[ope.subscripts for ope in descriptor.operands], + operand_segment_modes=descriptor.subscripts.operands, operand_segment_offsets=[ [s.start for s in ope.segment_slices()] for ope in descriptor.operands ], From 374e26890d472b6caedc6e75568f4bb01a8f4456 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 00:58:27 +0100 Subject: [PATCH 063/107] fix --- .../segmented_tensor_product/segmented_tensor_product.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index 090843a3..c035719c 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -859,7 +859,7 @@ def add_or_rename_modes( ) mapping = mappings[0] - if not all(ch in mapping for ope in self.operands for ch in ope.subscripts): + if not all(ch in mapping for ss in self.subscripts.operands for ch in ss): raise ValueError( f"expected all segment modes to be in the mapping {mapping}." 
) @@ -872,14 +872,14 @@ def add_or_rename_modes( for oid in range(self.num_operands): for sid in range(self.operands[oid].num_segments): dims = collections.defaultdict(lambda: 1) - for m, d in zip(self.operands[oid].subscripts, self.operands[oid][sid]): + for m, d in zip(self.subscripts.operands[oid], self.operands[oid][sid]): dims[mapping[m]] = d D.add_segment(oid, dims) for path in self.paths: dims: dict[str, int] = dict() for oid, sid in enumerate(path.indices): - for m, d in zip(D.operands[oid].subscripts, D.operands[oid][sid]): + for m, d in zip(D.subscripts.operands[oid], D.operands[oid][sid]): dims[m] = d D.insert_path_( From 1682b171bf33cb4cb720bd21624baf1749c2675b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 08:26:33 +0100 Subject: [PATCH 064/107] improve STP testing --- .../descriptor_test.py | 124 ++++++++++++++---- 1 file changed, 96 insertions(+), 28 deletions(-) diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py index 5cae1a64..306e4846 100644 --- a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py +++ b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py @@ -15,11 +15,11 @@ import numpy as np import pytest -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue def test_user_friendly(): - d = stp.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk") assert ( str(d) == "ia,jb,kab+ijk sizes=0,0,0 num_segments=0,0,0 num_paths=0 a= b= i= j= k=" @@ -58,7 +58,7 @@ def test_user_friendly(): def test_squeeze(): - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") d.add_segment(0, (1,)) d.add_segment(1, (20,)) d.add_path(0, 0, c=np.ones((1, 20))) @@ -66,7 +66,7 @@ def test_squeeze(): assert d.squeeze_modes().subscripts == ",j+j" - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") d.add_segment(0, (1,)) d.add_segment(0, (2,)) d.add_segment(1, (20,)) @@ -81,7 +81,7 @@ def test_squeeze(): def test_normalize_paths_for_operand(): - d = stp.SegmentedTensorProduct.from_subscripts("i_j+ij") + d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") d.add_segments(0, 2 * [(2,)]) d.add_segments(1, 2 * [(3,)]) @@ -105,7 +105,7 @@ def test_normalize_paths_for_operand(): def make_example_descriptor(): - d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_jv+ij") + d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_jv+ij") d.add_path( None, None, @@ -137,11 +137,14 @@ def test_flatten(): x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) for channels in ["i", "j", "ij", "ui", "iju", "uvij"]: np.testing.assert_allclose( - x2, stp.compute_last_operand(d.flatten_modes(channels), x0, x1) + x2, + cue.segmented_tensor_product.compute_last_operand( + d.flatten_modes(channels), x0, x1 + ), ) @@ -161,7 +164,7 @@ def test_flatten_coefficients(): def test_consolidate(): - d = stp.SegmentedTensorProduct.from_subscripts("ab_ab") + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab") d.add_segment(0, (2, 3)) d.add_segment(1, (2, 3)) d.add_path(0, 0, c=1.0) @@ -169,7 +172,7 @@ def test_consolidate(): assert d.consolidate_modes().subscripts == "a,a" - d = 
stp.SegmentedTensorProduct.from_subscripts("ab_ab_a") + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab_a") d.add_segment(0, (2, 3)) d.add_segment(1, (2, 3)) d.add_segment(2, (2,)) @@ -178,7 +181,7 @@ def test_consolidate(): assert d.consolidate_modes() == d - d = stp.SegmentedTensorProduct.from_subscripts("ab_iab+abi") + d = cue.SegmentedTensorProduct.from_subscripts("ab_iab+abi") d.add_segment(0, (2, 3)) d.add_segment(1, (4, 2, 3)) d.add_path(0, 0, c=np.ones((2, 3, 4))) @@ -186,7 +189,7 @@ def test_consolidate(): assert d.consolidate_modes().subscripts == "a,ia+ai" - d = stp.SegmentedTensorProduct.from_subscripts("ab,iab+abi") + d = cue.SegmentedTensorProduct.from_subscripts("ab,iab+abi") d.add_segment(0, (2, 3)) d.add_segment(1, (4, 2, 3)) d.add_path(0, 0, c=np.ones((2, 3, 4))) @@ -195,7 +198,7 @@ def test_consolidate(): def test_stacked_coefficients(): - d = stp.SegmentedTensorProduct.from_subscripts("ab_ab+ab") + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab+ab") d.add_segment(0, (2, 3)) d.add_segment(1, (2, 3)) np.testing.assert_allclose(d.stacked_coefficients, np.ones((0, 2, 3))) @@ -210,7 +213,7 @@ def test_stacked_coefficients(): @pytest.mark.parametrize("extended", [False, True]) def test_data_transfer(extended: bool): - d = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) d.assert_valid() @@ -220,14 +223,14 @@ def test_data_transfer(extended: bool): bin = d.to_bytes(extended) b64 = d.to_base64(extended) - assert d == stp.SegmentedTensorProduct.from_dict(dict) - assert d == stp.SegmentedTensorProduct.from_json(json) - assert d == stp.SegmentedTensorProduct.from_bytes(bin) - assert d == stp.SegmentedTensorProduct.from_base64(b64) + assert d == cue.SegmentedTensorProduct.from_dict(dict) + assert d == cue.SegmentedTensorProduct.from_json(json) + assert d == cue.SegmentedTensorProduct.from_bytes(bin) + assert d == cue.SegmentedTensorProduct.from_base64(b64) def test_to_text(): - d = stp.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) d = d.flatten_modes("ijk") @@ -250,12 +253,12 @@ def test_to_text(): def test_hash(): - d = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) assert hash(d) == hash(d) - d2 = stp.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") + d2 = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") assert hash(d) != hash(d2) d2.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) assert hash(d) != hash(d2) @@ -265,7 +268,7 @@ def test_hash(): def test_split_mode(): # Create a descriptor with a mode that has dimensions divisible by the desired split size - d = stp.SegmentedTensorProduct.from_subscripts("ua,ub+ab") + d = cue.SegmentedTensorProduct.from_subscripts("ua,ub+ab") # Add a segment with u dimension = 6 (divisible by 2 and 3) d.add_segment(0, (6, 4)) @@ -311,7 +314,7 @@ def test_split_mode(): assert d_unchanged == d # Test case where mode is not at the beginning of the operand - d_complex = 
stp.SegmentedTensorProduct.from_subscripts("au,bu+ab") + d_complex = cue.SegmentedTensorProduct.from_subscripts("au,bu+ab") d_complex.add_segment(0, (3, 6)) d_complex.add_segment(1, (4, 6)) d_complex.add_path(0, 0, c=np.ones((3, 4))) @@ -320,7 +323,7 @@ def test_split_mode(): d_complex.split_mode("u", 2) # 'u' is not the first mode in operands # Test with coefficient subscripts - d_coeff = stp.SegmentedTensorProduct.from_subscripts("ua,ub,ab+ab") + d_coeff = cue.SegmentedTensorProduct.from_subscripts("ua,ub,ab+ab") d_coeff.add_segment(0, (6, 4)) d_coeff.add_segment(1, (6, 5)) d_coeff.add_segment(2, (4, 5)) @@ -335,17 +338,19 @@ def test_split_mode(): # Check that computation results are equivalent # Create a simple descriptor with just two operands for testing compute_last_operand - d_compute = stp.SegmentedTensorProduct.from_subscripts("a,b+ab") + d_compute = cue.SegmentedTensorProduct.from_subscripts("a,b+ab") d_compute.add_segment(0, (4,)) d_compute.add_segment(1, (5,)) d_compute.add_path(0, 0, c=np.ones((4, 5))) # Test computation on original descriptor x_input = np.random.randn(d_compute.operands[0].size) - result_original = stp.compute_last_operand(d_compute, x_input) + result_original = cue.segmented_tensor_product.compute_last_operand( + d_compute, x_input + ) # Verify split_mode works by first flattening the results to remove 'u' mode indices - d_ua = stp.SegmentedTensorProduct.from_subscripts("ua,b+ab") + d_ua = cue.SegmentedTensorProduct.from_subscripts("ua,b+ab") d_ua.add_segment(0, (6, 4)) d_ua.add_segment(1, (5,)) d_ua.add_path(0, 0, c=np.ones((4, 5))) @@ -353,8 +358,71 @@ def test_split_mode(): # Input for the split descriptor - we need a tensor with the right shape x_input_split = np.random.randn(d_ua_split.operands[0].size) - result_split = stp.compute_last_operand(d_ua_split, x_input_split) + result_split = cue.segmented_tensor_product.compute_last_operand( + d_ua_split, x_input_split + ) # Verify the shapes are consistent with our expectations assert result_original.shape == (5,) assert result_split.shape == (5,) + + +def test_add_or_transpose_modes(): + # Test 1: Simple mode transposition + d = cue.SegmentedTensorProduct.from_subscripts("ia,ja+ij") + d.add_segment(0, (3, 4)) + d.add_segment(1, (5, 4)) + d.add_path(0, 0, c=np.ones((3, 5))) + d.assert_valid() + + # Transpose modes in first operand + d_trans = d.add_or_transpose_modes("ai,ja+ij") + d_trans.assert_valid() + assert d_trans.subscripts == "ai,ja+ij" + assert d_trans.operands[0][0] == (4, 3) and d_trans.operands[1][0] == (5, 4) + np.testing.assert_allclose(d_trans.paths[0].coefficients, d.paths[0].coefficients) + + # Test 2: Adding new modes + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=np.ones((3, 4))) + d.assert_valid() + + # Add new modes with specified dimensions + d_new = d.add_or_transpose_modes("ia,ja+ij", dims={"a": 5}) + d_new.assert_valid() + assert d_new.subscripts == "ia,ja+ij" + assert d_new.operands[0][0] == (3, 5) and d_new.operands[1][0] == (4, 5) + + # Test 3 & 4: Error cases + with pytest.raises(ValueError): + d.add_or_transpose_modes("ia,ja+ij") # Missing dims for a + with pytest.raises(ValueError): + d.add_or_transpose_modes("i,j+i") # Removing j from coefficients + + # Test 5: Transposing coefficient modes + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=np.ones((3, 4))) + d.assert_valid() + + # Transpose coefficient modes + 
d_coeff = d.add_or_transpose_modes("i,j+ji") + d_coeff.assert_valid() + assert d_coeff.subscripts == "i,j+ji" + np.testing.assert_allclose(d_coeff.paths[0].coefficients, d.paths[0].coefficients.T) + + # Test 6: Adding batch dimensions + d = cue.SegmentedTensorProduct.from_subscripts("ui,uj+ij") + d.add_segment(0, (5, 3)) + d.add_segment(1, (5, 4)) + d.add_path(0, 0, c=np.ones((3, 4))) + d.assert_valid() + + # Add batch dimension to both operands + d_batch = d.add_or_transpose_modes("bui,buj+ij", dims={"b": 8}) + d_batch.assert_valid() + assert d_batch.subscripts == "bui,buj+ij" + assert d_batch.operands[0][0] == (8, 5, 3) and d_batch.operands[1][0] == (8, 5, 4) From f49197eaf92ed91d1a5b97c1803128bd68d5c93d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 08:37:58 +0100 Subject: [PATCH 065/107] imports --- .../segmented_tensor_product/dispatch.py | 2 +- .../segmented_tensor_product/evaluate.py | 4 ++-- .../compute_last_operand_test.py | 24 ++++++++++--------- .../segmented_tensor_product/dot_test.py | 4 ++-- 4 files changed, 18 insertions(+), 16 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py b/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py index ecdce832..6e5b3749 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py @@ -16,7 +16,7 @@ import math from typing import Generator, Tuple -import cuequivariance.segmented_tensor_product as stp +import cuequivariance.segmented_tensor_product as stp # we cannot import cuequivariance as cue because of circular import def dispatch( diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py b/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py index 289a722d..631890ec 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py @@ -19,11 +19,11 @@ import numpy as np -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue def compute_last_operand( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, *inputs: np.ndarray, segment_axes: Union[int, list[int]] = -1, dtype: Optional[np.dtype] = None, diff --git a/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py b/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py index d57b3390..f4ed6816 100644 --- a/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py +++ b/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py @@ -14,40 +14,42 @@ # limitations under the License. 
import numpy as np -import cuequivariance.segmented_tensor_product as stp +import cuequivariance as cue def test_compute_last_operand_1(): - d = stp.SegmentedTensorProduct.from_subscripts("uv_vw_uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv_vw_uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(2, 3) @ x1.reshape(3, 4)).reshape(-1) np.testing.assert_allclose(x2_, x2) def test_compute_last_operand_2(): - d = stp.SegmentedTensorProduct.from_subscripts("uv,vw,uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(10, d.operands[0].size) x1 = np.random.randn(10, d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(10, 2, 3) @ x1.reshape(10, 3, 4)).reshape(10, -1) np.testing.assert_allclose(x2_, x2) def test_compute_last_operand_3(): - d = stp.SegmentedTensorProduct.from_subscripts("uv,vw,uw") + d = cue.SegmentedTensorProduct.from_subscripts("uv,vw,uw") d.add_path(None, None, None, c=1.0, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(1, d.operands[0].size, 5) x1 = np.random.randn(10, d.operands[1].size, 1) - x2 = stp.compute_last_operand(d, x0, x1, segment_axes=[1, 1, 1]) + x2 = cue.segmented_tensor_product.compute_last_operand( + d, x0, x1, segment_axes=[1, 1, 1] + ) x2_ = np.einsum( "ZuvA,ZvwA->ZuwA", x0.reshape(1, 2, 3, 5), x1.reshape(10, 3, 4, 1) @@ -56,13 +58,13 @@ def test_compute_last_operand_3(): def test_compute_last_operand_4(): - d = stp.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") c = np.random.randn(2, 3, 4) d.add_path(None, None, None, c=c, dims={"u": 2, "v": 3, "w": 4}) x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = stp.compute_last_operand(d, x0, x1) + x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) x2_ = np.einsum( "ijk,iuv,jvw->kuw", c, x0.reshape(2, 2, 3), x1.reshape(3, 3, 4) @@ -71,7 +73,7 @@ def test_compute_last_operand_4(): def test_primitive_compute_last_operand(): - d = stp.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iuv_jvw_kuw+ijk") c = np.random.randn(2, 3, 4) d.add_path(None, None, None, c=c, dims={"u": 2, "v": 3, "w": 4}) @@ -84,7 +86,7 @@ def test_primitive_compute_last_operand(): d = d.to_dict(True) - x2 = stp.primitive_compute_last_operand( + x2 = cue.segmented_tensor_product.primitive_compute_last_operand( [ope["subscripts"] for ope in d["operands"]], d["coefficient_subscripts"], [ope["segments"] for ope in d["operands"]], diff --git a/cuequivariance/tests/segmented_tensor_product/dot_test.py b/cuequivariance/tests/segmented_tensor_product/dot_test.py index 18d237de..417570a0 100644 --- a/cuequivariance/tests/segmented_tensor_product/dot_test.py +++ b/cuequivariance/tests/segmented_tensor_product/dot_test.py @@ -20,12 +20,12 @@ def test_dot1(): - d1 = stp.SegmentedTensorProduct.from_subscripts("iab_jb_ak+ijk") + d1 = cue.SegmentedTensorProduct.from_subscripts("iab_jb_ak+ijk") d1.add_path(None, None, None, c=np.random.randn(2, 2, 2), dims={"a": 2, "b": 3}) d1.add_path(None, None, None, c=np.random.randn(2, 3, 2), dims={"a": 
4, "b": 3}) d1.add_path(0, 1, 0, c=np.random.randn(2, 3, 2)) - d2 = stp.SegmentedTensorProduct.from_subscripts("jb_b_+j") + d2 = cue.SegmentedTensorProduct.from_subscripts("jb_b_+j") d2.add_path(None, None, None, c=np.random.randn(2), dims={"b": 3}) d2.add_path(None, 0, None, c=np.random.randn(3)) From 559ac8213e1d8b9e88b8ad7ee551697776f352e7 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 09:05:54 +0100 Subject: [PATCH 066/107] add tests --- .../descriptor_test.py | 98 +++++++++++++------ 1 file changed, 67 insertions(+), 31 deletions(-) diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py index 306e4846..2621b43f 100644 --- a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py +++ b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py @@ -18,6 +18,12 @@ import cuequivariance as cue +def make_coeffs(shape): + n = np.prod(shape) + c = np.arange(n) + 1.0 + return c.reshape(shape) + + def test_user_friendly(): d = cue.SegmentedTensorProduct.from_subscripts("ia_jb_kab+ijk") assert ( @@ -26,7 +32,7 @@ def test_user_friendly(): ) with pytest.raises(ValueError): - d.add_path(0, 0, 0, c=np.ones((2, 2, 3))) # need to add segments first + d.add_path(0, 0, 0, c=make_coeffs((2, 2, 3))) # need to add segments first with pytest.raises(ValueError): d.add_segment(0, (2, 2, 2)) # wrong number of dimensions @@ -37,14 +43,14 @@ def test_user_friendly(): d.add_segment(2, (4, 32, 32)) with pytest.raises(ValueError): - d.add_path(0, 0, 0, c=np.ones((5, 5, 5))) # wrong dimension for k + d.add_path(0, 0, 0, c=make_coeffs((5, 5, 5))) # wrong dimension for k assert ( str(d) == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=0 a={16, 32} b=32 i=5 j=5 k=4" ) - d.add_path(0, 0, 0, c=np.ones((5, 5, 4))) + d.add_path(0, 0, 0, c=make_coeffs((5, 5, 4))) assert ( str(d) == "ia,jb,kab+ijk sizes=80,160,6144 num_segments=1,1,2 num_paths=1 a={16, 32} b=32 i=5 j=5 k=4" @@ -61,7 +67,7 @@ def test_squeeze(): d = cue.SegmentedTensorProduct.from_subscripts("i_j+ij") d.add_segment(0, (1,)) d.add_segment(1, (20,)) - d.add_path(0, 0, c=np.ones((1, 20))) + d.add_path(0, 0, c=make_coeffs((1, 20))) d.assert_valid() assert d.squeeze_modes().subscripts == ",j+j" @@ -70,8 +76,8 @@ def test_squeeze(): d.add_segment(0, (1,)) d.add_segment(0, (2,)) d.add_segment(1, (20,)) - d.add_path(0, 0, c=np.ones((1, 20))) - d.add_path(1, 0, c=np.ones((2, 20))) + d.add_path(0, 0, c=make_coeffs((1, 20))) + d.add_path(1, 0, c=make_coeffs((2, 20))) d.assert_valid() assert d.squeeze_modes().subscripts == "i,j+ij" @@ -184,7 +190,7 @@ def test_consolidate(): d = cue.SegmentedTensorProduct.from_subscripts("ab_iab+abi") d.add_segment(0, (2, 3)) d.add_segment(1, (4, 2, 3)) - d.add_path(0, 0, c=np.ones((2, 3, 4))) + d.add_path(0, 0, c=make_coeffs((2, 3, 4))) d.assert_valid() assert d.consolidate_modes().subscripts == "a,ia+ai" @@ -192,7 +198,7 @@ def test_consolidate(): d = cue.SegmentedTensorProduct.from_subscripts("ab,iab+abi") d.add_segment(0, (2, 3)) d.add_segment(1, (4, 2, 3)) - d.add_path(0, 0, c=np.ones((2, 3, 4))) + d.add_path(0, 0, c=make_coeffs((2, 3, 4))) assert d.consolidate_modes().subscripts == "a,ia+ai" @@ -201,21 +207,22 @@ def test_stacked_coefficients(): d = cue.SegmentedTensorProduct.from_subscripts("ab_ab+ab") d.add_segment(0, (2, 3)) d.add_segment(1, (2, 3)) - np.testing.assert_allclose(d.stacked_coefficients, np.ones((0, 2, 3))) + np.testing.assert_allclose(d.stacked_coefficients, make_coeffs((0, 
2, 3))) - d.add_path(0, 0, c=np.ones((2, 3))) - d.add_path(0, 0, c=np.ones((2, 3))) - np.testing.assert_allclose(d.stacked_coefficients, np.ones((2, 2, 3))) + d.add_path(0, 0, c=make_coeffs((2, 3))) + d.add_path(0, 0, c=make_coeffs((2, 3))) + expected = np.stack([make_coeffs((2, 3)), make_coeffs((2, 3))], axis=0) + np.testing.assert_allclose(d.stacked_coefficients, expected) d = d.consolidate_paths() - np.testing.assert_allclose(d.stacked_coefficients, 2 * np.ones((1, 2, 3))) + np.testing.assert_allclose(d.stacked_coefficients, 2 * make_coeffs((1, 2, 3))) @pytest.mark.parametrize("extended", [False, True]) def test_data_transfer(extended: bool): d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) d.assert_valid() dict = d.to_dict(extended) @@ -231,8 +238,8 @@ def test_data_transfer(extended: bool): def test_to_text(): d = cue.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) d = d.flatten_modes("ijk") text = d.to_text() @@ -248,21 +255,21 @@ def test_to_text(): Flop cost: 0->704 1->704 2->704 Memory cost: 164 Path indices: 0 0 0, 0 0 1, 0 0 2, 0 1 0, 0 1 1, 0 1 2, 0 2 0, 0 2 1, 0 2 2, 1 0 0, 1 0 1, 1 0 2, 1 1 0, 1 1 1, 1 1 2, 1 2 0, 1 2 1, 1 2 2, 2 0 0, 2 0 1, 2 0 2, 2 1 0, 2 1 1, 2 1 2, 2 2 0, 2 2 1, 2 2 2, 3 3 3, 3 4 3 -Path coefficients: [1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0]""" +Path coefficients: [1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0 11.0 12.0 13.0 14.0 15.0 16.0 17.0 18.0 19.0 20.0 21.0 22.0 23.0 24.0 25.0 26.0 27.0 1.0 2.0]""" ) def test_hash(): d = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") - d.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) - d.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) + d.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) + d.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) assert hash(d) == hash(d) d2 = cue.SegmentedTensorProduct.from_subscripts("ui,uj,uk+ijk") assert hash(d) != hash(d2) - d2.add_path(None, None, None, c=np.ones((3, 3, 3)), dims={"u": 12}) + d2.add_path(None, None, None, c=make_coeffs((3, 3, 3)), dims={"u": 12}) assert hash(d) != hash(d2) - d2.add_path(None, None, None, c=np.ones((1, 2, 1)), dims={"u": 14}) + d2.add_path(None, None, None, c=make_coeffs((1, 2, 1)), dims={"u": 14}) assert hash(d) == hash(d2) @@ -275,7 +282,7 @@ def test_split_mode(): d.add_segment(1, (6, 5)) # Add a path - d.add_path(0, 0, c=np.ones((4, 5))) + d.add_path(0, 0, c=make_coeffs((4, 5))) d.assert_valid() # Split mode 'u' with size 2 @@ -317,7 +324,7 @@ def test_split_mode(): d_complex = cue.SegmentedTensorProduct.from_subscripts("au,bu+ab") d_complex.add_segment(0, (3, 6)) d_complex.add_segment(1, (4, 6)) - d_complex.add_path(0, 0, c=np.ones((3, 4))) + d_complex.add_path(0, 0, c=make_coeffs((3, 4))) with pytest.raises(ValueError): d_complex.split_mode("u", 2) # 'u' is not the first mode in operands @@ -327,7 +334,7 @@ def 
test_split_mode(): d_coeff.add_segment(0, (6, 4)) d_coeff.add_segment(1, (6, 5)) d_coeff.add_segment(2, (4, 5)) - d_coeff.add_path(0, 0, 0, c=np.ones((4, 5))) + d_coeff.add_path(0, 0, 0, c=make_coeffs((4, 5))) d_coeff_split = d_coeff.split_mode("u", 2) d_coeff_split.assert_valid() @@ -341,7 +348,7 @@ def test_split_mode(): d_compute = cue.SegmentedTensorProduct.from_subscripts("a,b+ab") d_compute.add_segment(0, (4,)) d_compute.add_segment(1, (5,)) - d_compute.add_path(0, 0, c=np.ones((4, 5))) + d_compute.add_path(0, 0, c=make_coeffs((4, 5))) # Test computation on original descriptor x_input = np.random.randn(d_compute.operands[0].size) @@ -353,7 +360,7 @@ def test_split_mode(): d_ua = cue.SegmentedTensorProduct.from_subscripts("ua,b+ab") d_ua.add_segment(0, (6, 4)) d_ua.add_segment(1, (5,)) - d_ua.add_path(0, 0, c=np.ones((4, 5))) + d_ua.add_path(0, 0, c=make_coeffs((4, 5))) d_ua_split = d_ua.split_mode("u", 2) # Input for the split descriptor - we need a tensor with the right shape @@ -372,7 +379,7 @@ def test_add_or_transpose_modes(): d = cue.SegmentedTensorProduct.from_subscripts("ia,ja+ij") d.add_segment(0, (3, 4)) d.add_segment(1, (5, 4)) - d.add_path(0, 0, c=np.ones((3, 5))) + d.add_path(0, 0, c=make_coeffs((3, 5))) d.assert_valid() # Transpose modes in first operand @@ -386,7 +393,7 @@ def test_add_or_transpose_modes(): d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") d.add_segment(0, (3,)) d.add_segment(1, (4,)) - d.add_path(0, 0, c=np.ones((3, 4))) + d.add_path(0, 0, c=make_coeffs((3, 4))) d.assert_valid() # Add new modes with specified dimensions @@ -405,7 +412,7 @@ def test_add_or_transpose_modes(): d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") d.add_segment(0, (3,)) d.add_segment(1, (4,)) - d.add_path(0, 0, c=np.ones((3, 4))) + d.add_path(0, 0, c=make_coeffs((3, 4))) d.assert_valid() # Transpose coefficient modes @@ -418,7 +425,7 @@ def test_add_or_transpose_modes(): d = cue.SegmentedTensorProduct.from_subscripts("ui,uj+ij") d.add_segment(0, (5, 3)) d.add_segment(1, (5, 4)) - d.add_path(0, 0, c=np.ones((3, 4))) + d.add_path(0, 0, c=make_coeffs((3, 4))) d.assert_valid() # Add batch dimension to both operands @@ -426,3 +433,32 @@ def test_add_or_transpose_modes(): d_batch.assert_valid() assert d_batch.subscripts == "bui,buj+ij" assert d_batch.operands[0][0] == (8, 5, 3) and d_batch.operands[1][0] == (8, 5, 4) + + +def test_add_or_rename_modes(): + d = cue.SegmentedTensorProduct.from_subscripts("i,j+ij") + d.add_segment(0, (3,)) + d.add_segment(1, (4,)) + d.add_path(0, 0, c=make_coeffs((3, 4))) + d.assert_valid() + d_id = d.add_or_rename_modes("i,j+ij") + d_id.assert_valid() + assert d_id.subscripts == "i,j+ij" + x_input = np.random.randn(d.operands[0].size) + res0 = cue.segmented_tensor_product.compute_last_operand(d, x_input) + res_id = cue.segmented_tensor_product.compute_last_operand(d_id, x_input) + np.testing.assert_allclose(res0, res_id) + d_ren = d.add_or_rename_modes("a,b+ab") + d_ren.assert_valid() + assert d_ren.subscripts == "a,b+ab" + np.testing.assert_allclose( + res0, cue.segmented_tensor_product.compute_last_operand(d_ren, x_input) + ) + with pytest.raises(ValueError): + d.add_or_rename_modes("i+ij") + d_sup = d.add_or_rename_modes("bi,bj+ij", mapping={"i": "i", "j": "j"}) + d_sup.assert_valid() + assert d_sup.subscripts == "bi,bj+ij" + np.testing.assert_allclose( + res0, cue.segmented_tensor_product.compute_last_operand(d_sup, x_input) + ) From ea990a5742e1153594223b44b2070c87fca5a6e8 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 
2025 09:55:00 +0100 Subject: [PATCH 067/107] add test --- .../tests/segmented_tensor_product/descriptor_test.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py index 2621b43f..95b4c634 100644 --- a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py +++ b/cuequivariance/tests/segmented_tensor_product/descriptor_test.py @@ -462,3 +462,13 @@ def test_add_or_rename_modes(): np.testing.assert_allclose( res0, cue.segmented_tensor_product.compute_last_operand(d_sup, x_input) ) + + +def test_consolidate_with_optional_argument(): + d = cue.SegmentedTensorProduct.from_subscripts("ab_ab") + d.add_segment(0, (2, 3)) + d.add_segment(1, (2, 3)) + d.add_path(0, 0, c=1.0) + d.assert_valid() + d_consol = d.consolidate_modes("ab") + assert d_consol.subscripts == "a,a" From cb481b496210af757ba0bd3b91771b237317ad8a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 10:00:08 +0100 Subject: [PATCH 068/107] add test --- cuequivariance/tests/equivariant_polynomial_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/cuequivariance/tests/equivariant_polynomial_test.py b/cuequivariance/tests/equivariant_polynomial_test.py index 19ef6382..04c43260 100644 --- a/cuequivariance/tests/equivariant_polynomial_test.py +++ b/cuequivariance/tests/equivariant_polynomial_test.py @@ -96,3 +96,15 @@ def enc(th: float): y2 = A @ B @ C @ x np.testing.assert_allclose(y1, y2) + + +def test_elementwise_tensor_product(): + irreps1 = cue.Irreps("O3", "8x0e + 7x1o + 4x2e") + irreps2 = cue.Irreps("O3", "8x0e + 7x1o + 4x0e") + irreps3 = cue.Irreps("O3", "0e + 1o + 1e") + poly = cue.descriptors.elementwise_tensor_product(irreps1, irreps2, irreps3) + # Check that squeezing then flattening equals flattening then squeezing + assert ( + poly.squeeze_modes().flatten_coefficient_modes() + == poly.flatten_coefficient_modes().squeeze_modes() + ) From 1460ff7de7a67e3b65a5716529a553c50b837bf3 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 10:03:54 +0100 Subject: [PATCH 069/107] add test --- cuequivariance/tests/equivariant_polynomial_test.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/cuequivariance/tests/equivariant_polynomial_test.py b/cuequivariance/tests/equivariant_polynomial_test.py index 04c43260..532a3cdb 100644 --- a/cuequivariance/tests/equivariant_polynomial_test.py +++ b/cuequivariance/tests/equivariant_polynomial_test.py @@ -103,8 +103,17 @@ def test_elementwise_tensor_product(): irreps2 = cue.Irreps("O3", "8x0e + 7x1o + 4x0e") irreps3 = cue.Irreps("O3", "0e + 1o + 1e") poly = cue.descriptors.elementwise_tensor_product(irreps1, irreps2, irreps3) - # Check that squeezing then flattening equals flattening then squeezing assert ( poly.squeeze_modes().flatten_coefficient_modes() == poly.flatten_coefficient_modes().squeeze_modes() ) + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 + + +def test_symmetric_contraction(): + irreps_in = 16 * cue.Irreps("SO3", "0 + 1 + 2") + irreps_out = 16 * cue.Irreps("SO3", "0 + 1") + poly = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, [0, 1, 2, 3]) + assert poly.num_inputs == 2 + assert poly.num_outputs == 1 From d9e31b21eaadb4b1745e078c645ea1cc89b474c2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 11:10:49 +0100 Subject: [PATCH 070/107] optimize SegmentedTensorProduct.symmetries --- 
.../cuequivariance/misc/permutations.py | 44 ++++++++++ .../cuequivariance/segmented_operand.py | 10 +++ .../cuequivariance/segmented_polynomial.py | 1 - .../segmented_tensor_product/path.py | 3 +- .../segmented_tensor_product.py | 81 ++++++++++++------- .../tests/equivariant_polynomial_test.py | 7 ++ 6 files changed, 117 insertions(+), 29 deletions(-) create mode 100644 cuequivariance/cuequivariance/misc/permutations.py diff --git a/cuequivariance/cuequivariance/misc/permutations.py b/cuequivariance/cuequivariance/misc/permutations.py new file mode 100644 index 00000000..9d3fda6a --- /dev/null +++ b/cuequivariance/cuequivariance/misc/permutations.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Sequence + + +def compose_permutations(p1: tuple[int, ...], p2: tuple[int, ...]) -> tuple[int, ...]: + """Compose two permutations""" + return tuple(p1[p2[i]] for i in range(len(p1))) + + +def inverse_permutation(p: tuple[int, ...]) -> tuple[int, ...]: + """Invert a permutation""" + return tuple(p.index(i) for i in range(len(p))) + + +def generate_permutations_from( + generators: Sequence[tuple[int, ...]], +) -> set[tuple[int, ...]]: + """Generate the closure of a set of permutations under composition""" + result = set(generators) + + while True: + n = len(result) + new_result = result.copy() + for g in result: + for h in result: + new_result.add(compose_permutations(g, h)) + if len(new_result) == n: + break + result = new_result + + return result diff --git a/cuequivariance/cuequivariance/segmented_operand.py b/cuequivariance/cuequivariance/segmented_operand.py index 2cd1dd77..44f8e523 100644 --- a/cuequivariance/cuequivariance/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_operand.py @@ -46,6 +46,8 @@ def __init__( for segment in self.segments: for i, d in enumerate(segment): _dims.setdefault(i, set()).add(d) + else: + _dims = _dims.copy() object.__setattr__(self, "_dims", _dims) @classmethod @@ -71,6 +73,14 @@ def stack(cls, operands: list[SegmentedOperand]) -> SegmentedOperand: _dims=_dims, ) + def copy(self) -> SegmentedOperand: + """Copy the operand.""" + return SegmentedOperand( + ndim=self.ndim, + segments=self.segments, + _dims=self._dims, + ) + def assert_valid(self): """Assert that the operand is valid.""" for segment in self.segments: diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomial.py index f0a75f0d..eb84eb1d 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomial.py @@ -262,7 +262,6 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): .squeeze_modes() .remove_empty_segments() .consolidate_paths() - .sort_paths() ) if stp.num_paths == 0: return None diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/path.py 
b/cuequivariance/cuequivariance/segmented_tensor_product/path.py index 8af08de5..8b59ea11 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/path.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/path.py @@ -39,7 +39,8 @@ class Path: def __init__(self, indices, coefficients): super().__setattr__("indices", tuple(int(i) for i in indices)) super().__setattr__( - "coefficients", np.asarray(coefficients, dtype=np.float64, order="C").copy() + "coefficients", + np.asarray(coefficients, dtype=np.float64, order="C", copy=True), ) def assert_valid(self): diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py index c035719c..1bfb1dc5 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py @@ -33,6 +33,10 @@ import cuequivariance as cue # noqa: F401 from cuequivariance import segmented_tensor_product as stp from cuequivariance.misc.linalg import round_to_rational, round_to_sqrt_rational +from cuequivariance.misc.permutations import ( + generate_permutations_from, + inverse_permutation, +) from .dimensions_dict import format_dimensions_dict @@ -66,10 +70,9 @@ class SegmentedTensorProduct: def __init__( self, *, - operands_and_subscripts: Optional[ - list[tuple[cue.SegmentedOperand, stp.Subscripts]] - ] = None, - paths: Optional[list[stp.Path]] = None, + operands_and_subscripts: Sequence[tuple[cue.SegmentedOperand, stp.Subscripts]] + | None = None, + paths: Sequence[stp.Path] | None = None, coefficient_subscripts: str = "", ): if operands_and_subscripts is None: @@ -81,11 +84,10 @@ def __init__( self, "operands_and_subscripts", tuple( - (copy.deepcopy(ope), stp.Subscripts(ss)) - for ope, ss in operands_and_subscripts + (ope.copy(), stp.Subscripts(ss)) for ope, ss in operands_and_subscripts ), ) - object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) + object.__setattr__(self, "paths", tuple(paths)) object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) def set_operand(self, oid: int, operand: cue.SegmentedOperand): @@ -450,7 +452,7 @@ def to_base64(self, extended: bool = False) -> str: """ return base64.b64encode(self.to_bytes(extended)).decode("ascii") - @functools.lru_cache(maxsize=None) + @functools.cache def get_dimensions_dict(self) -> dict[str, set[int]]: """Get the dimensions of the tensor product.""" dims: dict[str, set[int]] = {ch: set() for ch in self.subscripts.modes()} @@ -560,22 +562,34 @@ def compressed_path_segment(self, operand: int) -> np.ndarray: return np.append(0, np.bincount(i, minlength=n).cumsum()) - @functools.lru_cache(maxsize=None) + def operands_with_identical_segments(self) -> frozenset[frozenset[int]]: + """Groups of operands sharing the same segments.""" + operand_to_oid = collections.defaultdict(list) + for oid, ope in enumerate(self.operands): + operand_to_oid[ope].append(oid) + return frozenset(map(frozenset, operand_to_oid.values())) + + @functools.cache def symmetries(self) -> list[tuple[int, ...]]: """List of permutations that leave the tensor product invariant.""" - def functionally_equivalent( - d1: SegmentedTensorProduct, d2: SegmentedTensorProduct - ) -> bool: - d1 = d1.consolidate_paths().sort_paths() - d2 = d2.consolidate_paths().sort_paths() - return d1 == d2 + d = self.consolidate_paths() - ps = [] - for p in 
itertools.permutations(range(self.num_operands)): - if functionally_equivalent(self, self.permute_operands(p)): - ps.append(p) - return ps + ps = set() + for group in self.operands_with_identical_segments(): + group = sorted(group) + for p in itertools.permutations(range(len(group))): + p = tuple( + group[p[group.index(i)]] if i in group else i + for i in range(self.num_operands) + ) + if p in ps: + continue + if d == d.permute_operands(p).consolidate_paths(): + ps.add(p) + ps.add(inverse_permutation(p)) + ps = generate_permutations_from(ps) + return sorted(ps) def coefficients_equal_one(self) -> bool: """Check if all coefficients are equal to one.""" @@ -603,7 +617,7 @@ def flops( + d.subscripts.operands[-1] ) - @functools.lru_cache(maxsize=None) + @functools.cache def compute_cost(segment_shapes: tuple[tuple[int, ...], ...]) -> int: _, info = opt_einsum.contract_path( subscripts, *segment_shapes, optimize="optimal", shapes=True @@ -984,7 +998,7 @@ def append_modes_to_all_operands( def permute_operands(self, perm: tuple[int, ...]) -> SegmentedTensorProduct: """Permute the operands of the descriptor.""" - assert set(perm) == set(range(self.num_operands)) + # assert set(perm) == set(range(self.num_operands)) # removed for performance return dataclasses.replace( self, operands_and_subscripts=[self.operands_and_subscripts[i] for i in perm], @@ -1258,11 +1272,10 @@ def fuse_paths_with_same_indices(self) -> SegmentedTensorProduct: """Fuse paths with the same indices.""" paths = dict() for path in self.paths: - indices = tuple(path.indices) - if indices in paths: - paths[indices] += path.coefficients + if path.indices in paths: + paths[path.indices] += path.coefficients else: - paths[indices] = path.coefficients + paths[path.indices] = path.coefficients return dataclasses.replace( self, @@ -1272,9 +1285,23 @@ def fuse_paths_with_same_indices(self) -> SegmentedTensorProduct: ], ) + @functools.cache def consolidate_paths(self) -> SegmentedTensorProduct: """Consolidate the paths by merging duplicates and removing zeros.""" - return self.fuse_paths_with_same_indices().remove_zero_paths() + # equivalent to self.fuse_paths_with_same_indices().remove_zero_paths().sort_paths() + paths = dict() + for path in self.paths: + if path.indices in paths: + paths[path.indices] += path.coefficients + else: + paths[path.indices] = path.coefficients + paths = [ + stp.Path(indices=indices, coefficients=coefficients) + for indices, coefficients in paths.items() + if not np.all(coefficients == 0) + ] + paths = sorted(paths, key=lambda path: path.indices) + return dataclasses.replace(self, paths=paths) def sort_indices_for_identical_operands( self, operands: Sequence[int] diff --git a/cuequivariance/tests/equivariant_polynomial_test.py b/cuequivariance/tests/equivariant_polynomial_test.py index 532a3cdb..72d5cbfb 100644 --- a/cuequivariance/tests/equivariant_polynomial_test.py +++ b/cuequivariance/tests/equivariant_polynomial_test.py @@ -117,3 +117,10 @@ def test_symmetric_contraction(): poly = cue.descriptors.symmetric_contraction(irreps_in, irreps_out, [0, 1, 2, 3]) assert poly.num_inputs == 2 assert poly.num_outputs == 1 + + [_, _, _, (_, d)] = poly.polynomial.tensor_products + assert d.num_paths == 437 + + poly = poly.polynomial.unsymmetrize_for_identical_operands() + [_, _, _, (_, d)] = poly.tensor_products + assert d.num_paths == 105 From 3eda8061d39b2cdbf5c22f9522a12d6ba0742d99 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 11:31:50 +0100 Subject: [PATCH 071/107] add TODOs --- 
.../primitives/segmented_polynomial_ops_impl.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py index 8d6ac3e8..c874d8b5 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py @@ -71,6 +71,7 @@ def log(msg: str): return log(f"Unsupported buffer type: {b.dtype}") for i in indices: + # TODO: this restriction will be removed by MR109 if i.dtype.type != jnp.int32: return log(f"Unsupported index type: {i.dtype}") @@ -80,6 +81,7 @@ def log(msg: str): if len({b.shape[2] for b in buffers}.union({1})) != 2: return log(f"Buffer shapes not compatible {[b.shape for b in buffers]}") + # TODO: this restriction will be removed by MR109 if max(b.shape[2] for b in buffers) % 32 != 0: return log(f"Extend must be a multiple of 32, got {[b.shape for b in buffers]}") @@ -94,7 +96,7 @@ def log(msg: str): elif b.shape[0] != 1: batch_size = b.shape[0] - # TODO: remove if the backend supports atomic operations for float16/bfloat16 + # TODO: this restriction will be removed by MR109 for i, b in zip( buffer_index[polynomial.num_inputs :], buffers[polynomial.num_inputs :] ): @@ -110,8 +112,8 @@ def log(msg: str): Path, tensor_product_uniform_1d_jit, ) - except ImportError: - return log("cuequivariance_ops_jax is not installed") + except ImportError as e: + return log(f"cuequivariance_ops_jax is not installed: {e}") operations = [] paths = [] From fdbc603d9c74d7671eb89d7ff8cb9cf24e76ce27 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 11:34:08 +0100 Subject: [PATCH 072/107] Remove `IrrepsArray` in favor of `RepArray` --- CHANGELOG.md | 2 +- cuequivariance_jax/cuequivariance_jax/__init__.py | 3 +-- .../cuequivariance_jax/rep_array/jax_rep_array.py | 2 -- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b347e8d..1cd7d28c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ### Breaking Changes - Rename `SegmentedTensorProduct.flop_cost` in `flops` - Rename `SegmentedTensorProduct.memory_cost` in `memory` - +- Removed `IrrepsArray` in favor of `RepArray` ## 0.3.0-rc1 diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 8d6aab26..83f30bed 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -19,7 +19,7 @@ ) -from .rep_array.jax_rep_array import RepArray, from_segments, IrrepsArray +from .rep_array.jax_rep_array import RepArray, from_segments from .rep_array.vmap import vmap from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan @@ -39,7 +39,6 @@ __all__ = [ "RepArray", "from_segments", - "IrrepsArray", "vmap", "concatenate", "randn", diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py index 9cee2d62..35f539f1 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py @@ -603,8 +603,6 @@ def decode_rep_array(static, data) -> RepArray: jax.tree_util.register_pytree_node(RepArray, encode_rep_array, decode_rep_array) -IrrepsArray = RepArray # TODO: do we deprecate IrrepsArray? 
- def from_segments( irreps: cue.Irreps | str, From 239e002370822e8ace64c3ea2e1032d8f74ada14 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 12:31:33 +0100 Subject: [PATCH 073/107] move files --- cuequivariance/cuequivariance/__init__.py | 36 ++++----- .../cuequivariance/{misc => etc}/linalg.py | 0 .../{misc => etc}/permutations.py | 0 .../{misc => etc}/sympy_utils.py | 0 .../cuequivariance/group_theory/__init__.py | 75 +++++++++++++++++++ .../descriptors/__init__.py | 0 .../descriptors/irreps_tp.py | 0 .../descriptors/rotations.py | 0 .../descriptors/spherical_harmonics_.py | 0 .../descriptors/symmetric_contractions.py | 0 .../descriptors/transposition.py | 0 .../equivariant_polynomial.py | 0 .../equivariant_tensor_product.py | 0 .../experimental/__init__.py | 0 .../{ => group_theory}/experimental/e3nn.py | 0 .../{ => group_theory}/experimental/escn.py | 0 .../{ => group_theory}/experimental/gatr.py | 0 .../experimental/mace/__init__.py | 0 .../mace/symmetric_contractions.py | 0 .../irreps_array/__init__.py | 0 .../irreps_array/context_decorator.py | 0 .../irreps_array/context_irrep_class.py | 0 .../irreps_array/context_layout.py | 0 .../irreps_array/irrep_utils.py | 0 .../{ => group_theory}/irreps_array/irreps.py | 0 .../irreps_array/irreps_and_layout.py | 0 .../irreps_array/irreps_layout.py | 0 .../irreps_array/misc_ui.py | 0 .../irreps_array/numpy_irreps_array.py | 0 .../irreps_array/reduced_tensor_product.py | 0 .../representations}/__init__.py | 0 .../representations}/irrep.py | 0 .../representations}/irrep_o3.py | 0 .../representations}/irrep_so3.py | 0 .../representations}/irrep_su2.py | 0 .../representations}/rep.py | 0 .../__init__.py | 9 ++- .../dimensions_dict.py | 0 .../dispatch.py | 0 .../dot.py | 0 .../evaluate.py | 0 .../{ => segmented_polynomials}/operation.py | 0 .../path.py | 0 .../segmented_operand.py | 0 .../segmented_polynomial.py | 2 +- .../segmented_tensor_product.py | 0 .../subscripts.py | 0 .../{linalg => etc}/round_to_rational_test.py | 0 .../tests/{ => group_theory}/context_test.py | 0 .../equivariant_polynomial_test.py | 0 .../experimental/escn_test.py | 0 .../experimental/gatr_test.py | 0 .../experimental/mace_test.py | 0 .../{ => group_theory}/irreps/irreps_test.py | 0 .../{ => group_theory}/irreps_array_test.py | 0 .../reduced_tensor_product_test.py | 0 .../dot_test.py | 0 .../evaluate_test.py} | 0 .../operation_test.py | 0 .../segmented_polynomial_test.py | 0 .../segmented_tensor_product_test.py} | 0 .../subscripts_test.py | 0 62 files changed, 100 insertions(+), 22 deletions(-) rename cuequivariance/cuequivariance/{misc => etc}/linalg.py (100%) rename cuequivariance/cuequivariance/{misc => etc}/permutations.py (100%) rename cuequivariance/cuequivariance/{misc => etc}/sympy_utils.py (100%) create mode 100644 cuequivariance/cuequivariance/group_theory/__init__.py rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/__init__.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/irreps_tp.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/rotations.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/spherical_harmonics_.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/symmetric_contractions.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/descriptors/transposition.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/equivariant_polynomial.py (100%) rename cuequivariance/cuequivariance/{ => 
group_theory}/equivariant_tensor_product.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/__init__.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/e3nn.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/escn.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/gatr.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/mace/__init__.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/experimental/mace/symmetric_contractions.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/__init__.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/context_decorator.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/context_irrep_class.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/context_layout.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/irrep_utils.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/irreps.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/irreps_and_layout.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/irreps_layout.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/misc_ui.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/numpy_irreps_array.py (100%) rename cuequivariance/cuequivariance/{ => group_theory}/irreps_array/reduced_tensor_product.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/__init__.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/irrep.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/irrep_o3.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/irrep_so3.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/irrep_su2.py (100%) rename cuequivariance/cuequivariance/{representation => group_theory/representations}/rep.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/__init__.py (79%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/dimensions_dict.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/dispatch.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/dot.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/evaluate.py (100%) rename cuequivariance/cuequivariance/{ => segmented_polynomials}/operation.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/path.py (100%) rename cuequivariance/cuequivariance/{ => segmented_polynomials}/segmented_operand.py (100%) rename cuequivariance/cuequivariance/{ => segmented_polynomials}/segmented_polynomial.py (99%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/segmented_tensor_product.py (100%) rename cuequivariance/cuequivariance/{segmented_tensor_product => segmented_polynomials}/subscripts.py (100%) rename cuequivariance/tests/{linalg => etc}/round_to_rational_test.py (100%) rename cuequivariance/tests/{ => group_theory}/context_test.py (100%) rename 
cuequivariance/tests/{ => group_theory}/equivariant_polynomial_test.py (100%) rename cuequivariance/tests/{ => group_theory}/experimental/escn_test.py (100%) rename cuequivariance/tests/{ => group_theory}/experimental/gatr_test.py (100%) rename cuequivariance/tests/{ => group_theory}/experimental/mace_test.py (100%) rename cuequivariance/tests/{ => group_theory}/irreps/irreps_test.py (100%) rename cuequivariance/tests/{ => group_theory}/irreps_array_test.py (100%) rename cuequivariance/tests/{ => group_theory}/reduced_tensor_product_test.py (100%) rename cuequivariance/tests/{segmented_tensor_product => segmented_polynomials}/dot_test.py (100%) rename cuequivariance/tests/{segmented_tensor_product/compute_last_operand_test.py => segmented_polynomials/evaluate_test.py} (100%) rename cuequivariance/tests/{ => segmented_polynomials}/operation_test.py (100%) rename cuequivariance/tests/{ => segmented_polynomials}/segmented_polynomial_test.py (100%) rename cuequivariance/tests/{segmented_tensor_product/descriptor_test.py => segmented_polynomials/segmented_tensor_product_test.py} (100%) rename cuequivariance/tests/{segmented_tensor_product => segmented_polynomials}/subscripts_test.py (100%) diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index b489c02b..bf8baec5 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -18,7 +18,14 @@ importlib.resources.files(__package__).joinpath("VERSION").read_text().strip() ) -from cuequivariance.representation import ( +from .segmented_polynomials import ( + Operation, + SegmentedOperand, + SegmentedTensorProduct, + SegmentedPolynomial, +) + +from .group_theory import ( Rep, Irrep, clebsch_gordan, @@ -27,9 +34,6 @@ SU2, SO3, O3, -) - -from cuequivariance.irreps_array import ( get_irrep_scope, MulIrrep, Irreps, @@ -47,16 +51,14 @@ reduced_antisymmetric_tensor_product_basis, ) -from cuequivariance.operation import Operation -from cuequivariance.segmented_operand import SegmentedOperand -from cuequivariance.segmented_tensor_product import SegmentedTensorProduct -from cuequivariance.segmented_polynomial import SegmentedPolynomial -from cuequivariance.equivariant_polynomial import EquivariantPolynomial -from cuequivariance.equivariant_tensor_product import EquivariantTensorProduct - -from cuequivariance import segmented_tensor_product, descriptors +from cuequivariance import segmented_polynomials, group_theory __all__ = [ + "__version__", + "Operation", + "SegmentedOperand", + "SegmentedTensorProduct", + "SegmentedPolynomial", "Rep", "Irrep", "clebsch_gordan", @@ -80,12 +82,6 @@ "reduced_tensor_product_basis", "reduced_symmetric_tensor_product_basis", "reduced_antisymmetric_tensor_product_basis", - "Operation", - "SegmentedOperand", - "SegmentedTensorProduct", - "SegmentedPolynomial", - "EquivariantPolynomial", - "EquivariantTensorProduct", - "segmented_tensor_product", - "descriptors", + "segmented_polynomials", + "group_theory", ] diff --git a/cuequivariance/cuequivariance/misc/linalg.py b/cuequivariance/cuequivariance/etc/linalg.py similarity index 100% rename from cuequivariance/cuequivariance/misc/linalg.py rename to cuequivariance/cuequivariance/etc/linalg.py diff --git a/cuequivariance/cuequivariance/misc/permutations.py b/cuequivariance/cuequivariance/etc/permutations.py similarity index 100% rename from cuequivariance/cuequivariance/misc/permutations.py rename to cuequivariance/cuequivariance/etc/permutations.py diff --git 
a/cuequivariance/cuequivariance/misc/sympy_utils.py b/cuequivariance/cuequivariance/etc/sympy_utils.py similarity index 100% rename from cuequivariance/cuequivariance/misc/sympy_utils.py rename to cuequivariance/cuequivariance/etc/sympy_utils.py diff --git a/cuequivariance/cuequivariance/group_theory/__init__.py b/cuequivariance/cuequivariance/group_theory/__init__.py new file mode 100644 index 00000000..4e221b98 --- /dev/null +++ b/cuequivariance/cuequivariance/group_theory/__init__.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .representations import ( + Rep, + Irrep, + clebsch_gordan, + selection_rule_product, + selection_rule_power, + SU2, + SO3, + O3, +) + +from .irreps_array import ( + get_irrep_scope, + MulIrrep, + Irreps, + IrrepsLayout, + mul_ir, + ir_mul, + IrrepsAndLayout, + get_layout_scope, + assume, + NumpyIrrepsArray, + from_segments, + concatenate, + reduced_tensor_product_basis, + reduced_symmetric_tensor_product_basis, + reduced_antisymmetric_tensor_product_basis, +) + +from .equivariant_polynomial import EquivariantPolynomial +from .equivariant_tensor_product import EquivariantTensorProduct + + +__all__ = [ + "Rep", + "Irrep", + "clebsch_gordan", + "selection_rule_product", + "selection_rule_power", + "SU2", + "SO3", + "O3", + "get_irrep_scope", + "MulIrrep", + "Irreps", + "IrrepsLayout", + "mul_ir", + "ir_mul", + "IrrepsAndLayout", + "get_layout_scope", + "assume", + "NumpyIrrepsArray", + "from_segments", + "concatenate", + "reduced_tensor_product_basis", + "reduced_symmetric_tensor_product_basis", + "reduced_antisymmetric_tensor_product_basis", + "EquivariantPolynomial", + "EquivariantTensorProduct", +] diff --git a/cuequivariance/cuequivariance/descriptors/__init__.py b/cuequivariance/cuequivariance/group_theory/descriptors/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/__init__.py rename to cuequivariance/cuequivariance/group_theory/descriptors/__init__.py diff --git a/cuequivariance/cuequivariance/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/irreps_tp.py rename to cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py diff --git a/cuequivariance/cuequivariance/descriptors/rotations.py b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/rotations.py rename to cuequivariance/cuequivariance/group_theory/descriptors/rotations.py diff --git a/cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/spherical_harmonics_.py rename to 
cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py diff --git a/cuequivariance/cuequivariance/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/symmetric_contractions.py rename to cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py diff --git a/cuequivariance/cuequivariance/descriptors/transposition.py b/cuequivariance/cuequivariance/group_theory/descriptors/transposition.py similarity index 100% rename from cuequivariance/cuequivariance/descriptors/transposition.py rename to cuequivariance/cuequivariance/group_theory/descriptors/transposition.py diff --git a/cuequivariance/cuequivariance/equivariant_polynomial.py b/cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py similarity index 100% rename from cuequivariance/cuequivariance/equivariant_polynomial.py rename to cuequivariance/cuequivariance/group_theory/equivariant_polynomial.py diff --git a/cuequivariance/cuequivariance/equivariant_tensor_product.py b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py similarity index 100% rename from cuequivariance/cuequivariance/equivariant_tensor_product.py rename to cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py diff --git a/cuequivariance/cuequivariance/experimental/__init__.py b/cuequivariance/cuequivariance/group_theory/experimental/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/__init__.py rename to cuequivariance/cuequivariance/group_theory/experimental/__init__.py diff --git a/cuequivariance/cuequivariance/experimental/e3nn.py b/cuequivariance/cuequivariance/group_theory/experimental/e3nn.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/e3nn.py rename to cuequivariance/cuequivariance/group_theory/experimental/e3nn.py diff --git a/cuequivariance/cuequivariance/experimental/escn.py b/cuequivariance/cuequivariance/group_theory/experimental/escn.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/escn.py rename to cuequivariance/cuequivariance/group_theory/experimental/escn.py diff --git a/cuequivariance/cuequivariance/experimental/gatr.py b/cuequivariance/cuequivariance/group_theory/experimental/gatr.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/gatr.py rename to cuequivariance/cuequivariance/group_theory/experimental/gatr.py diff --git a/cuequivariance/cuequivariance/experimental/mace/__init__.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/mace/__init__.py rename to cuequivariance/cuequivariance/group_theory/experimental/mace/__init__.py diff --git a/cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py similarity index 100% rename from cuequivariance/cuequivariance/experimental/mace/symmetric_contractions.py rename to cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py diff --git a/cuequivariance/cuequivariance/irreps_array/__init__.py b/cuequivariance/cuequivariance/group_theory/irreps_array/__init__.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/__init__.py rename to 
cuequivariance/cuequivariance/group_theory/irreps_array/__init__.py diff --git a/cuequivariance/cuequivariance/irreps_array/context_decorator.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/context_decorator.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py diff --git a/cuequivariance/cuequivariance/irreps_array/context_irrep_class.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/context_irrep_class.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py diff --git a/cuequivariance/cuequivariance/irreps_array/context_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/context_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py diff --git a/cuequivariance/cuequivariance/irreps_array/irrep_utils.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irrep_utils.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py diff --git a/cuequivariance/cuequivariance/irreps_array/irreps.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irreps.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps.py diff --git a/cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irreps_and_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py diff --git a/cuequivariance/cuequivariance/irreps_array/irreps_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_layout.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/irreps_layout.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/irreps_layout.py diff --git a/cuequivariance/cuequivariance/irreps_array/misc_ui.py b/cuequivariance/cuequivariance/group_theory/irreps_array/misc_ui.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/misc_ui.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/misc_ui.py diff --git a/cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py b/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/numpy_irreps_array.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py diff --git a/cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py b/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py similarity index 100% rename from cuequivariance/cuequivariance/irreps_array/reduced_tensor_product.py rename to cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py diff --git a/cuequivariance/cuequivariance/representation/__init__.py b/cuequivariance/cuequivariance/group_theory/representations/__init__.py similarity index 100% rename from 
cuequivariance/cuequivariance/representation/__init__.py rename to cuequivariance/cuequivariance/group_theory/representations/__init__.py diff --git a/cuequivariance/cuequivariance/representation/irrep.py b/cuequivariance/cuequivariance/group_theory/representations/irrep.py similarity index 100% rename from cuequivariance/cuequivariance/representation/irrep.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep.py diff --git a/cuequivariance/cuequivariance/representation/irrep_o3.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py similarity index 100% rename from cuequivariance/cuequivariance/representation/irrep_o3.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py diff --git a/cuequivariance/cuequivariance/representation/irrep_so3.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py similarity index 100% rename from cuequivariance/cuequivariance/representation/irrep_so3.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py diff --git a/cuequivariance/cuequivariance/representation/irrep_su2.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py similarity index 100% rename from cuequivariance/cuequivariance/representation/irrep_su2.py rename to cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py diff --git a/cuequivariance/cuequivariance/representation/rep.py b/cuequivariance/cuequivariance/group_theory/representations/rep.py similarity index 100% rename from cuequivariance/cuequivariance/representation/rep.py rename to cuequivariance/cuequivariance/group_theory/representations/rep.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py similarity index 79% rename from cuequivariance/cuequivariance/segmented_tensor_product/__init__.py rename to cuequivariance/cuequivariance/segmented_polynomials/__init__.py index 3ff358a5..42cadb24 100644 --- a/cuequivariance/cuequivariance/segmented_tensor_product/__init__.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,20 +14,27 @@ # limitations under the License. 
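# A hedged usage sketch of the consolidated namespace created by this rename
# (names are taken from the exports below; the constructor call itself is
# illustrative, not part of the patch):
#
#     import cuequivariance.segmented_polynomials as sp
#     d = sp.SegmentedTensorProduct.from_subscripts("ab,bc+ac")
#     # sp.Operation, sp.SegmentedPolynomial and sp.SegmentedOperand are now
#     # re-exported next to the existing Subscripts/Path/dot helpers.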
from .subscripts import Subscripts from .path import Path +from .segmented_operand import SegmentedOperand from .segmented_tensor_product import SegmentedTensorProduct from .dot import dot, trace from .evaluate import compute_last_operand, primitive_compute_last_operand from .dispatch import dispatch +from .operation import Operation +from .segmented_polynomial import SegmentedPolynomial + __all__ = [ "Subscripts", "Path", + "SegmentedOperand", "SegmentedTensorProduct", "dot", "trace", "compute_last_operand", "primitive_compute_last_operand", "dispatch", + "Operation", + "SegmentedPolynomial", ] diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dimensions_dict.py b/cuequivariance/cuequivariance/segmented_polynomials/dimensions_dict.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/dimensions_dict.py rename to cuequivariance/cuequivariance/segmented_polynomials/dimensions_dict.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/dispatch.py rename to cuequivariance/cuequivariance/segmented_polynomials/dispatch.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/dot.py b/cuequivariance/cuequivariance/segmented_polynomials/dot.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/dot.py rename to cuequivariance/cuequivariance/segmented_polynomials/dot.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py b/cuequivariance/cuequivariance/segmented_polynomials/evaluate.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/evaluate.py rename to cuequivariance/cuequivariance/segmented_polynomials/evaluate.py diff --git a/cuequivariance/cuequivariance/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py similarity index 100% rename from cuequivariance/cuequivariance/operation.py rename to cuequivariance/cuequivariance/segmented_polynomials/operation.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/path.py b/cuequivariance/cuequivariance/segmented_polynomials/path.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/path.py rename to cuequivariance/cuequivariance/segmented_polynomials/path.py diff --git a/cuequivariance/cuequivariance/segmented_operand.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_operand.py rename to cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py diff --git a/cuequivariance/cuequivariance/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py similarity index 99% rename from cuequivariance/cuequivariance/segmented_polynomial.py rename to cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index eb84eb1d..13e77cfa 100644 --- a/cuequivariance/cuequivariance/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -22,7 +22,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance.operation import IVARS, OVARS +from cuequivariance.segmented_polynomials.operation import IVARS, OVARS @dataclasses.dataclass(init=False, frozen=True) diff --git 
a/cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/segmented_tensor_product.py rename to cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py diff --git a/cuequivariance/cuequivariance/segmented_tensor_product/subscripts.py b/cuequivariance/cuequivariance/segmented_polynomials/subscripts.py similarity index 100% rename from cuequivariance/cuequivariance/segmented_tensor_product/subscripts.py rename to cuequivariance/cuequivariance/segmented_polynomials/subscripts.py diff --git a/cuequivariance/tests/linalg/round_to_rational_test.py b/cuequivariance/tests/etc/round_to_rational_test.py similarity index 100% rename from cuequivariance/tests/linalg/round_to_rational_test.py rename to cuequivariance/tests/etc/round_to_rational_test.py diff --git a/cuequivariance/tests/context_test.py b/cuequivariance/tests/group_theory/context_test.py similarity index 100% rename from cuequivariance/tests/context_test.py rename to cuequivariance/tests/group_theory/context_test.py diff --git a/cuequivariance/tests/equivariant_polynomial_test.py b/cuequivariance/tests/group_theory/equivariant_polynomial_test.py similarity index 100% rename from cuequivariance/tests/equivariant_polynomial_test.py rename to cuequivariance/tests/group_theory/equivariant_polynomial_test.py diff --git a/cuequivariance/tests/experimental/escn_test.py b/cuequivariance/tests/group_theory/experimental/escn_test.py similarity index 100% rename from cuequivariance/tests/experimental/escn_test.py rename to cuequivariance/tests/group_theory/experimental/escn_test.py diff --git a/cuequivariance/tests/experimental/gatr_test.py b/cuequivariance/tests/group_theory/experimental/gatr_test.py similarity index 100% rename from cuequivariance/tests/experimental/gatr_test.py rename to cuequivariance/tests/group_theory/experimental/gatr_test.py diff --git a/cuequivariance/tests/experimental/mace_test.py b/cuequivariance/tests/group_theory/experimental/mace_test.py similarity index 100% rename from cuequivariance/tests/experimental/mace_test.py rename to cuequivariance/tests/group_theory/experimental/mace_test.py diff --git a/cuequivariance/tests/irreps/irreps_test.py b/cuequivariance/tests/group_theory/irreps/irreps_test.py similarity index 100% rename from cuequivariance/tests/irreps/irreps_test.py rename to cuequivariance/tests/group_theory/irreps/irreps_test.py diff --git a/cuequivariance/tests/irreps_array_test.py b/cuequivariance/tests/group_theory/irreps_array_test.py similarity index 100% rename from cuequivariance/tests/irreps_array_test.py rename to cuequivariance/tests/group_theory/irreps_array_test.py diff --git a/cuequivariance/tests/reduced_tensor_product_test.py b/cuequivariance/tests/group_theory/reduced_tensor_product_test.py similarity index 100% rename from cuequivariance/tests/reduced_tensor_product_test.py rename to cuequivariance/tests/group_theory/reduced_tensor_product_test.py diff --git a/cuequivariance/tests/segmented_tensor_product/dot_test.py b/cuequivariance/tests/segmented_polynomials/dot_test.py similarity index 100% rename from cuequivariance/tests/segmented_tensor_product/dot_test.py rename to cuequivariance/tests/segmented_polynomials/dot_test.py diff --git a/cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py b/cuequivariance/tests/segmented_polynomials/evaluate_test.py 
similarity index 100% rename from cuequivariance/tests/segmented_tensor_product/compute_last_operand_test.py rename to cuequivariance/tests/segmented_polynomials/evaluate_test.py diff --git a/cuequivariance/tests/operation_test.py b/cuequivariance/tests/segmented_polynomials/operation_test.py similarity index 100% rename from cuequivariance/tests/operation_test.py rename to cuequivariance/tests/segmented_polynomials/operation_test.py diff --git a/cuequivariance/tests/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py similarity index 100% rename from cuequivariance/tests/segmented_polynomial_test.py rename to cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py diff --git a/cuequivariance/tests/segmented_tensor_product/descriptor_test.py b/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py similarity index 100% rename from cuequivariance/tests/segmented_tensor_product/descriptor_test.py rename to cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py diff --git a/cuequivariance/tests/segmented_tensor_product/subscripts_test.py b/cuequivariance/tests/segmented_polynomials/subscripts_test.py similarity index 100% rename from cuequivariance/tests/segmented_tensor_product/subscripts_test.py rename to cuequivariance/tests/segmented_polynomials/subscripts_test.py From cd7113af064e491f08f72beab4f1f857280d2271 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 12:49:29 +0100 Subject: [PATCH 074/107] quick import fix --- cuequivariance/cuequivariance/__init__.py | 6 +++++ .../group_theory/descriptors/irreps_tp.py | 4 +-- .../group_theory/descriptors/rotations.py | 2 +- .../descriptors/spherical_harmonics_.py | 4 +-- .../descriptors/symmetric_contractions.py | 2 +- .../equivariant_tensor_product.py | 2 +- .../mace/symmetric_contractions.py | 2 +- .../irreps_array/context_decorator.py | 10 +++---- .../irreps_array/context_irrep_class.py | 5 ++-- .../irreps_array/context_layout.py | 2 +- .../group_theory/irreps_array/irrep_utils.py | 14 +++++----- .../irreps_array/irreps_and_layout.py | 3 ++- .../irreps_array/numpy_irreps_array.py | 2 +- .../irreps_array/reduced_tensor_product.py | 26 +++++++++---------- .../group_theory/representations/irrep.py | 2 +- .../group_theory/representations/irrep_o3.py | 2 +- .../group_theory/representations/irrep_so3.py | 4 +-- .../group_theory/representations/irrep_su2.py | 2 +- .../segmented_polynomials/dispatch.py | 2 +- .../segmented_operand.py | 2 +- .../segmented_polynomial.py | 2 +- .../segmented_tensor_product.py | 6 ++--- .../tests/etc/round_to_rational_test.py | 2 +- .../group_theory/experimental/escn_test.py | 2 +- .../group_theory/experimental/gatr_test.py | 2 +- .../group_theory/experimental/mace_test.py | 2 +- .../tests/segmented_polynomials/dot_test.py | 4 +-- .../segmented_polynomials/evaluate_test.py | 10 +++---- .../segmented_tensor_product_test.py | 18 ++++++------- .../segmented_polynomials/subscripts_test.py | 2 +- .../cuequivariance_jax/rep_array/utils.py | 2 +- .../cuequivariance_torch/layers/batchnorm.py | 5 +++- .../layers/tp_conv_fully_connected.py | 2 +- .../cuequivariance_torch/operations/linear.py | 5 +++- .../operations/rotation.py | 2 +- .../operations/symmetric_contraction.py | 5 +++- .../operations/tp_channel_wise.py | 5 +++- .../operations/tp_fully_connected.py | 5 +++- .../primitives/equivariant_tensor_product.py | 2 +- 39 files changed, 101 insertions(+), 80 deletions(-) diff --git 
a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index bf8baec5..cfc39818 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -49,8 +49,11 @@ reduced_tensor_product_basis, reduced_symmetric_tensor_product_basis, reduced_antisymmetric_tensor_product_basis, + EquivariantPolynomial, + EquivariantTensorProduct, # deprecated ) +from cuequivariance.group_theory import descriptors from cuequivariance import segmented_polynomials, group_theory __all__ = [ @@ -82,6 +85,9 @@ "reduced_tensor_product_basis", "reduced_symmetric_tensor_product_basis", "reduced_antisymmetric_tensor_product_basis", + "EquivariantPolynomial", + "EquivariantTensorProduct", + "descriptors", "segmented_polynomials", "group_theory", ] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 1b7a42f6..e2768310 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -16,8 +16,8 @@ from typing import Optional, Sequence import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep +from cuequivariance import segmented_polynomials as stp +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep def fully_connected_tensor_product( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py index f42804fb..ebb4e875 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py @@ -17,7 +17,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp +from cuequivariance import segmented_polynomials as stp def fixed_axis_angle_rotation( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index 82cf9c13..fe7935b9 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -17,8 +17,8 @@ import sympy as sp import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.misc.sympy_utils import sqrtQarray_to_sympy +from cuequivariance import segmented_polynomials as stp +from cuequivariance.etc.sympy_utils import sqrtQarray_to_sympy def spherical_harmonics( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 64384e32..dff2873f 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
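# Note the pattern applied across all descriptor modules in this patch: the
# renamed package is re-bound to the old short alias,
#
#     from cuequivariance import segmented_polynomials as stp
#
# so existing call sites such as stp.SegmentedTensorProduct.from_subscripts(...)
# keep working unchanged; only the import line moves.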
import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp +from cuequivariance import segmented_polynomials as stp def symmetric_contraction( diff --git a/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py index 2bac0331..a1e8e426 100644 --- a/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py @@ -20,7 +20,7 @@ from typing import Optional, Sequence, Union import cuequivariance as cue -from cuequivariance import segmented_tensor_product as stp +from cuequivariance import segmented_polynomials as stp @dataclasses.dataclass(init=False, frozen=True) diff --git a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py index b4959842..c6fb322c 100644 --- a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py @@ -18,7 +18,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance.misc.linalg import round_to_sqrt_rational, triu_array +from cuequivariance.etc.linalg import round_to_sqrt_rational, triu_array def symmetric_contraction( diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py index 3ee8e902..78fda178 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py @@ -15,13 +15,13 @@ from functools import wraps from typing import Optional, Type, Union -import cuequivariance as cue -import cuequivariance.irreps_array as irreps_array -from cuequivariance.irreps_array.context_irrep_class import ( +import cuequivariance.group_theory.irreps_array as irreps_array +from cuequivariance.group_theory import Irrep +from cuequivariance.group_theory.irreps_array.context_irrep_class import ( pop_irrep_scope, push_irrep_scope, ) -from cuequivariance.irreps_array.context_layout import ( +from cuequivariance.group_theory.irreps_array.context_layout import ( pop_layout_scope, push_layout_scope, ) @@ -51,7 +51,7 @@ class assume: def __init__( self, - irrep_class: Optional[Union[str, Type[cue.Irrep]]] = None, + irrep_class: Optional[Union[str, Type[Irrep]]] = None, layout: Optional[irreps_array.IrrepsLayout] = None, ): if isinstance(irrep_class, irreps_array.IrrepsLayout) and layout is None: diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py index 84b4c370..8be62c8c 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_irrep_class.py @@ -15,11 +15,12 @@ from typing import Type, Union import cuequivariance as cue +from cuequivariance.group_theory.representations import Irrep -_irrep_class: Union[None, str, Type[cue.Irrep]] = None +_irrep_class: Union[None, str, Type[Irrep]] = None -def get_irrep_scope(raising: bool = True) -> Type[cue.Irrep]: +def get_irrep_scope(raising: bool = True) -> Type[Irrep]: if raising and _irrep_class is None: raise ValueError( "No irrep class set in the 
context. Please specify the irrep class explicitly or use ``with cue.assume(irrep):``." diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py index 766c5df6..62a33246 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_layout.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Union -from cuequivariance.irreps_array import IrrepsLayout +from cuequivariance.group_theory.irreps_array import IrrepsLayout _layout: Union[None, IrrepsLayout] = None diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py index 3ecb66a6..ec119d90 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/irrep_utils.py @@ -15,21 +15,21 @@ from typing import Iterable, Type, Union import cuequivariance as cue -from cuequivariance import irreps_array +from cuequivariance.group_theory import Irrep, irreps_array def into_list_of_irrep( - irrep_class: Type[cue.Irrep], + irrep_class: Type[Irrep], input: Union[ str, - cue.Irrep, + Irrep, irreps_array.MulIrrep, - Iterable[Union[str, cue.Irrep, irreps_array.MulIrrep]], + Iterable[Union[str, Irrep, irreps_array.MulIrrep]], ], -) -> list[cue.Irrep]: +) -> list[Irrep]: if isinstance(input, str): return [rep for _, rep in cue.Irreps(irrep_class, input)] - if isinstance(input, cue.Irrep): + if isinstance(input, Irrep): return [input] if isinstance(input, irreps_array.MulIrrep): return [input.ir] @@ -41,7 +41,7 @@ def into_list_of_irrep( output = [] for rep in input: - if isinstance(rep, cue.Irrep): + if isinstance(rep, Irrep): output.append(rep) elif isinstance(rep, irreps_array.MulIrrep): output.append(rep.ir) diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py index b92b7780..0abe4081 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/irreps_and_layout.py @@ -19,10 +19,11 @@ import numpy as np import cuequivariance as cue +from cuequivariance.group_theory.representations import Rep @dataclass(init=False, frozen=True) -class IrrepsAndLayout(cue.Rep): +class IrrepsAndLayout(Rep): r""" A group representation (:class:`Rep`) made from the combination of :class:`Irreps` and :class:`IrrepsLayout` into a single object. 
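The renames above reduce to a small set of module moves. A hedged summary of the new import paths, derived from the rename headers in this patch (the snippet's aliases are illustrative, not part of the patch itself):

    # old module path                          -> new module path
    # cuequivariance.segmented_tensor_product  -> cuequivariance.segmented_polynomials
    # cuequivariance.irreps_array              -> cuequivariance.group_theory.irreps_array
    # cuequivariance.representation            -> cuequivariance.group_theory.representations
    # cuequivariance.descriptors               -> cuequivariance.group_theory.descriptors
    # cuequivariance.experimental              -> cuequivariance.group_theory.experimental
    # cuequivariance.misc                      -> cuequivariance.etc
    import cuequivariance.segmented_polynomials as stp
    from cuequivariance.group_theory.irreps_array.misc_ui import default_irreps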
diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py b/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py index 2c51c629..ea1f34f3 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/numpy_irreps_array.py @@ -20,7 +20,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep # This class is inspired by https://github.com/e3nn/e3nn-jax/blob/245e17eb23deaccad9f2c9cfd40fe40515e3c074/e3nn_jax/_src/irreps_array.py diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py b/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py index 854ed16f..79147099 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/reduced_tensor_product.py @@ -21,9 +21,7 @@ import numpy as np import cuequivariance as cue -from cuequivariance import irreps_array -from cuequivariance.irreps_array.irrep_utils import into_list_of_irrep -from cuequivariance.misc.linalg import ( +from cuequivariance.etc.linalg import ( basis_intersection, perm_compose, perm_inverse, @@ -31,6 +29,8 @@ round_to_sqrt_rational, sparsify_matrix, ) +from cuequivariance.group_theory import Irrep, irreps_array +from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep logger = logging.getLogger(__name__) @@ -55,7 +55,7 @@ def reduced_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, **irreps_dict, ) -> irreps_array.NumpyIrrepsArray: @@ -150,7 +150,7 @@ def reduced_symmetric_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, ) -> irreps_array.NumpyIrrepsArray: r"""Reduce a symmetric tensor product, usually called for a single irrep. @@ -197,7 +197,7 @@ def reduced_antisymmetric_tensor_product_basis( *, layout: irreps_array.IrrepsLayout = irreps_array.mul_ir, epsilon: float = 1e-5, - keep_ir: Optional[Union[irreps_array.Irreps, List[cue.Irrep]]] = None, + keep_ir: Optional[Union[irreps_array.Irreps, List[Irrep]]] = None, _use_optimized_implementation: bool = True, ) -> irreps_array.NumpyIrrepsArray: r"""Reduce an antisymmetric tensor product. 
@@ -240,7 +240,7 @@ def reduced_antisymmetric_tensor_product_basis( def _entrypoint( irreps_tuple: Tuple[irreps_array.Irreps, ...], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - keep_ir: Optional[FrozenSet[cue.Irrep]], + keep_ir: Optional[FrozenSet[Irrep]], layout: irreps_array.IrrepsLayout, epsilon: float, _use_optimized_implementation: bool, @@ -276,7 +276,7 @@ def _entrypoint( def _main_cached_recursive( irreps_tuple: Tuple[irreps_array.Irreps, ...], perm_repr: FrozenSet[Tuple[int, Tuple[int, ...]]], - keep_ir: Optional[FrozenSet[cue.Irrep]], + keep_ir: Optional[FrozenSet[Irrep]], epsilon: float, _use_optimized_implementation: bool, ) -> irreps_array.NumpyIrrepsArray: @@ -469,7 +469,7 @@ def _optimized_reduced_symmetric_tensor_product_basis( degree: int, *, epsilon: float, - keep_ir: Optional[List[cue.Irrep]] = None, + keep_ir: Optional[List[Irrep]] = None, ): r"""Reduce a symmetric tensor product. @@ -682,7 +682,7 @@ def germinate_perm_repr( def reduce_basis_product( basis1: irreps_array.NumpyIrrepsArray, basis2: irreps_array.NumpyIrrepsArray, - keep_ir: Optional[List[cue.Irrep]] = None, + keep_ir: Optional[List[Irrep]] = None, ) -> irreps_array.NumpyIrrepsArray: """Reduce the product of two basis.""" basis1 = basis1.regroup() @@ -692,7 +692,7 @@ def reduce_basis_product( layout = basis1.layout if keep_ir is not None: - assert all(isinstance(ir, cue.Irrep) for ir in keep_ir) + assert all(isinstance(ir, Irrep) for ir in keep_ir) keep_ir = frozenset(keep_ir) assert basis1.array.dtype == np.float64 @@ -710,7 +710,7 @@ def reduce_basis_product( if cache_key in _cache_reduce_basis_product: return _cache_reduce_basis_product[cache_key] - new_irreps: List[Tuple[int, cue.Irrep]] = [] + new_irreps: List[Tuple[int, Irrep]] = [] new_list = [] for (mul1, ir1), x1 in zip(basis1.irreps, basis1.segments): @@ -764,7 +764,7 @@ def constrain_rotation_basis_by_permutation_basis( (permutation_basis.shape[0], prod(permutation_basis.shape[1:])), ) # (free, dim) - new_irreps: List[Tuple[int, cue.Irrep]] = [] + new_irreps: List[Tuple[int, Irrep]] = [] new_list: List[np.ndarray] = [] for ir in sorted({ir for mul, ir in rotation_basis.irreps}): diff --git a/cuequivariance/cuequivariance/group_theory/representations/irrep.py b/cuequivariance/cuequivariance/group_theory/representations/irrep.py index e4338596..99cbd940 100644 --- a/cuequivariance/cuequivariance/group_theory/representations/irrep.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep.py @@ -22,7 +22,7 @@ import numpy as np import cuequivariance as cue # noqa: F401 -from cuequivariance.representation import Rep +from cuequivariance.group_theory.representations import Rep # This class is inspired from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irrep.py diff --git a/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py index 524aa56e..cf88f992 100644 --- a/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_o3.py @@ -21,7 +21,7 @@ import numpy as np -from cuequivariance.representation import SO3, Irrep +from cuequivariance.group_theory.representations import SO3, Irrep # This class is an adaptation of https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/o3_real.py diff --git a/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py 
b/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py index 10d47d61..e7fcadca 100644 --- a/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_so3.py @@ -21,8 +21,8 @@ import numpy as np -from cuequivariance.misc.linalg import round_to_sqrt_rational -from cuequivariance.representation import SU2, Irrep +from cuequivariance.etc.linalg import round_to_sqrt_rational +from cuequivariance.group_theory.representations import SU2, Irrep # This function is copied from https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/so3_real.py diff --git a/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py b/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py index d378bbea..11e8dea6 100644 --- a/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py +++ b/cuequivariance/cuequivariance/group_theory/representations/irrep_su2.py @@ -22,7 +22,7 @@ import numpy as np -from cuequivariance.representation import Irrep +from cuequivariance.group_theory.representations import Irrep # This class is an adaptation of https://github.com/lie-nn/lie-nn/blob/70adebce44e3197ee17f780585c6570d836fc2fe/lie_nn/_src/irreps/su2.py diff --git a/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py index 6e5b3749..975b00e8 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py @@ -16,7 +16,7 @@ import math from typing import Generator, Tuple -import cuequivariance.segmented_tensor_product as stp # we cannot import cuequivariance as cue because of circular import +import cuequivariance.segmented_polynomials as stp # we cannot import cuequivariance as cue because of circular import def dispatch( diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py index 44f8e523..7e146139 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py @@ -17,7 +17,7 @@ import dataclasses import math -from .segmented_tensor_product.dimensions_dict import format_dimensions_dict +from cuequivariance.segmented_polynomials.dimensions_dict import format_dimensions_dict @dataclasses.dataclass(init=False, frozen=True) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 13e77cfa..3e67b8ee 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -182,7 +182,7 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: for ope, stp in self.tensor_products: oid, bid = ope.output_operand_buffer(self.num_inputs) outputs[bid - self.num_inputs] += ( - cue.segmented_tensor_product.compute_last_operand( + cue.segmented_polynomials.compute_last_operand( stp.move_operand_last(oid), *[inputs[bid] for bid in ope.input_buffers(self.num_inputs)], dtype=inferred_dtype, diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index 
1bfb1dc5..5e154e4b 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -31,9 +31,9 @@ import opt_einsum import cuequivariance as cue # noqa: F401 -from cuequivariance import segmented_tensor_product as stp -from cuequivariance.misc.linalg import round_to_rational, round_to_sqrt_rational -from cuequivariance.misc.permutations import ( +import cuequivariance.segmented_polynomials as stp +from cuequivariance.etc.linalg import round_to_rational, round_to_sqrt_rational +from cuequivariance.etc.permutations import ( generate_permutations_from, inverse_permutation, ) diff --git a/cuequivariance/tests/etc/round_to_rational_test.py b/cuequivariance/tests/etc/round_to_rational_test.py index 33dbf486..d488ac41 100644 --- a/cuequivariance/tests/etc/round_to_rational_test.py +++ b/cuequivariance/tests/etc/round_to_rational_test.py @@ -14,7 +14,7 @@ # limitations under the License. import numpy as np -from cuequivariance.misc import linalg +from cuequivariance.etc import linalg def test_as_approx_integer_ratio(): diff --git a/cuequivariance/tests/group_theory/experimental/escn_test.py b/cuequivariance/tests/group_theory/experimental/escn_test.py index 2b770f4f..bb5afb5d 100644 --- a/cuequivariance/tests/group_theory/experimental/escn_test.py +++ b/cuequivariance/tests/group_theory/experimental/escn_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import cuequivariance as cue -from cuequivariance.experimental.escn import escn_tp, escn_tp_compact +from cuequivariance.group_theory.experimental.escn import escn_tp, escn_tp_compact def test_escn(): diff --git a/cuequivariance/tests/group_theory/experimental/gatr_test.py b/cuequivariance/tests/group_theory/experimental/gatr_test.py index cef872de..dc55373f 100644 --- a/cuequivariance/tests/group_theory/experimental/gatr_test.py +++ b/cuequivariance/tests/group_theory/experimental/gatr_test.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from cuequivariance.experimental.gatr import ( +from cuequivariance.group_theory.experimental.gatr import ( gatr_geometric_product, gatr_linear, gatr_outer_product, diff --git a/cuequivariance/tests/group_theory/experimental/mace_test.py b/cuequivariance/tests/group_theory/experimental/mace_test.py index fec5d863..943c8059 100644 --- a/cuequivariance/tests/group_theory/experimental/mace_test.py +++ b/cuequivariance/tests/group_theory/experimental/mace_test.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
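# The experimental helpers move wholesale under group_theory; a hedged example
# of the updated test imports (paths taken from the neighbouring hunks):
#
#     from cuequivariance.group_theory.experimental.escn import escn_tp, escn_tp_compact
#     from cuequivariance.group_theory.experimental.mace import symmetric_contraction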
import cuequivariance as cue -from cuequivariance.experimental.mace import symmetric_contraction +from cuequivariance.group_theory.experimental.mace import symmetric_contraction def test_symmetric_contraction(): diff --git a/cuequivariance/tests/segmented_polynomials/dot_test.py b/cuequivariance/tests/segmented_polynomials/dot_test.py index 417570a0..1e1bf2bd 100644 --- a/cuequivariance/tests/segmented_polynomials/dot_test.py +++ b/cuequivariance/tests/segmented_polynomials/dot_test.py @@ -15,8 +15,8 @@ import numpy as np import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp -from cuequivariance import descriptors +import cuequivariance.segmented_polynomials as stp +from cuequivariance.group_theory import descriptors def test_dot1(): diff --git a/cuequivariance/tests/segmented_polynomials/evaluate_test.py b/cuequivariance/tests/segmented_polynomials/evaluate_test.py index f4ed6816..7f1ced10 100644 --- a/cuequivariance/tests/segmented_polynomials/evaluate_test.py +++ b/cuequivariance/tests/segmented_polynomials/evaluate_test.py @@ -23,7 +23,7 @@ def test_compute_last_operand_1(): x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(2, 3) @ x1.reshape(3, 4)).reshape(-1) np.testing.assert_allclose(x2_, x2) @@ -35,7 +35,7 @@ def test_compute_last_operand_2(): x0 = np.random.randn(10, d.operands[0].size) x1 = np.random.randn(10, d.operands[1].size) - x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = (x0.reshape(10, 2, 3) @ x1.reshape(10, 3, 4)).reshape(10, -1) np.testing.assert_allclose(x2_, x2) @@ -47,7 +47,7 @@ def test_compute_last_operand_3(): x0 = np.random.randn(1, d.operands[0].size, 5) x1 = np.random.randn(10, d.operands[1].size, 1) - x2 = cue.segmented_tensor_product.compute_last_operand( + x2 = cue.segmented_polynomials.compute_last_operand( d, x0, x1, segment_axes=[1, 1, 1] ) @@ -64,7 +64,7 @@ def test_compute_last_operand_4(): x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) x2_ = np.einsum( "ijk,iuv,jvw->kuw", c, x0.reshape(2, 2, 3), x1.reshape(3, 3, 4) @@ -86,7 +86,7 @@ def test_primitive_compute_last_operand(): d = d.to_dict(True) - x2 = cue.segmented_tensor_product.primitive_compute_last_operand( + x2 = cue.segmented_polynomials.primitive_compute_last_operand( [ope["subscripts"] for ope in d["operands"]], d["coefficient_subscripts"], [ope["segments"] for ope in d["operands"]], diff --git a/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py b/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py index 95b4c634..749ea75d 100644 --- a/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py +++ b/cuequivariance/tests/segmented_polynomials/segmented_tensor_product_test.py @@ -143,12 +143,12 @@ def test_flatten(): x0 = np.random.randn(d.operands[0].size) x1 = np.random.randn(d.operands[1].size) - x2 = cue.segmented_tensor_product.compute_last_operand(d, x0, x1) + x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1) for channels in ["i", "j", "ij", "ui", "iju", "uvij"]: np.testing.assert_allclose( x2, - cue.segmented_tensor_product.compute_last_operand( + 
cue.segmented_polynomials.compute_last_operand( d.flatten_modes(channels), x0, x1 ), ) @@ -352,9 +352,7 @@ def test_split_mode(): # Test computation on original descriptor x_input = np.random.randn(d_compute.operands[0].size) - result_original = cue.segmented_tensor_product.compute_last_operand( - d_compute, x_input - ) + result_original = cue.segmented_polynomials.compute_last_operand(d_compute, x_input) # Verify split_mode works by first flattening the results to remove 'u' mode indices d_ua = cue.SegmentedTensorProduct.from_subscripts("ua,b+ab") @@ -365,7 +363,7 @@ def test_split_mode(): # Input for the split descriptor - we need a tensor with the right shape x_input_split = np.random.randn(d_ua_split.operands[0].size) - result_split = cue.segmented_tensor_product.compute_last_operand( + result_split = cue.segmented_polynomials.compute_last_operand( d_ua_split, x_input_split ) @@ -445,14 +443,14 @@ def test_add_or_rename_modes(): d_id.assert_valid() assert d_id.subscripts == "i,j+ij" x_input = np.random.randn(d.operands[0].size) - res0 = cue.segmented_tensor_product.compute_last_operand(d, x_input) - res_id = cue.segmented_tensor_product.compute_last_operand(d_id, x_input) + res0 = cue.segmented_polynomials.compute_last_operand(d, x_input) + res_id = cue.segmented_polynomials.compute_last_operand(d_id, x_input) np.testing.assert_allclose(res0, res_id) d_ren = d.add_or_rename_modes("a,b+ab") d_ren.assert_valid() assert d_ren.subscripts == "a,b+ab" np.testing.assert_allclose( - res0, cue.segmented_tensor_product.compute_last_operand(d_ren, x_input) + res0, cue.segmented_polynomials.compute_last_operand(d_ren, x_input) ) with pytest.raises(ValueError): d.add_or_rename_modes("i+ij") @@ -460,7 +458,7 @@ def test_add_or_rename_modes(): d_sup.assert_valid() assert d_sup.subscripts == "bi,bj+ij" np.testing.assert_allclose( - res0, cue.segmented_tensor_product.compute_last_operand(d_sup, x_input) + res0, cue.segmented_polynomials.compute_last_operand(d_sup, x_input) ) diff --git a/cuequivariance/tests/segmented_polynomials/subscripts_test.py b/cuequivariance/tests/segmented_polynomials/subscripts_test.py index 80598381..54fdf480 100644 --- a/cuequivariance/tests/segmented_polynomials/subscripts_test.py +++ b/cuequivariance/tests/segmented_polynomials/subscripts_test.py @@ -14,7 +14,7 @@ # limitations under the License. 
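# A minimal sketch of the evaluation entry point these test hunks migrate to
# (descriptor construction elided; the import path and call signature are the
# ones exercised above):
#
#     import cuequivariance as cue
#     # given a SegmentedTensorProduct d and inputs x0, x1 of matching size:
#     # x2 = cue.segmented_polynomials.compute_last_operand(d, x0, x1)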
import pytest -import cuequivariance.segmented_tensor_product as stp +import cuequivariance.segmented_polynomials as stp def test_subscripts(): diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py b/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py index f7ce2e84..4c279ab2 100644 --- a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py +++ b/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py @@ -20,7 +20,7 @@ import cuequivariance as cue import cuequivariance_jax as cuex -from cuequivariance.irreps_array.misc_ui import assert_same_group +from cuequivariance.group_theory.irreps_array.misc_ui import assert_same_group def concatenate(arrays: list[cuex.RepArray]) -> cuex.RepArray: diff --git a/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py b/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py index d71fe644..98f0c766 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/batchnorm.py @@ -16,7 +16,10 @@ import torch import cuequivariance as cue -from cuequivariance.irreps_array.misc_ui import default_irreps, default_layout +from cuequivariance.group_theory.irreps_array.misc_ui import ( + default_irreps, + default_layout, +) # This implementation is an adaptation of https://github.com/e3nn/e3nn/blob/ef93f876c9985b3816aefb2982b3cf4325df6ba4/e3nn/nn/_batchnorm.py diff --git a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py index a19ba1a4..55e030ea 100644 --- a/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/layers/tp_conv_fully_connected.py @@ -19,7 +19,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.irreps_array.misc_ui import ( +from cuequivariance.group_theory.irreps_array.misc_ui import ( assert_same_group, default_irreps, default_layout, diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 8872f426..5f3c5ffe 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class Linear(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py index 69a88a07..f457e7b2 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/rotation.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/rotation.py @@ -19,7 +19,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import default_irreps class Rotation(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 2366c9b7..2ad630a9 100644 --- 
a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -21,7 +21,10 @@ from cuequivariance.experimental.mace.symmetric_contractions import ( symmetric_contraction, ) -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class SymmetricContraction(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index cf672f07..67c73ee1 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class ChannelWiseTensorProduct(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index 4c5b4277..f4d0c29b 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -19,7 +19,10 @@ import cuequivariance as cue import cuequivariance_torch as cuet from cuequivariance import descriptors -from cuequivariance.irreps_array.misc_ui import assert_same_group, default_irreps +from cuequivariance.group_theory.irreps_array.misc_ui import ( + assert_same_group, + default_irreps, +) class FullyConnectedTensorProduct(torch.nn.Module): diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 51e50bc7..34045bdc 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -18,7 +18,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.irreps_array.misc_ui import default_layout +from cuequivariance.group_theory.irreps_array.misc_ui import default_layout class Dispatcher(torch.nn.Module): From 2843bcb99c1ad824548e3cf1ee647414d85c01c0 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 12:52:18 +0100 Subject: [PATCH 075/107] fix --- .../group_theory/irreps_array/context_decorator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py index 78fda178..00e614d2 100644 --- a/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py +++ b/cuequivariance/cuequivariance/group_theory/irreps_array/context_decorator.py @@ -15,6 +15,7 @@ from functools import wraps from typing import Optional, Type, Union +import cuequivariance as cue # noqa: F401 import cuequivariance.group_theory.irreps_array as irreps_array from cuequivariance.group_theory import Irrep from cuequivariance.group_theory.irreps_array.context_irrep_class import ( From 
bbaa8fc4fdc2e6b99c4f150df6cbd18c9b0a269a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 12:52:52 +0100 Subject: [PATCH 076/107] fix --- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 0cc9985c..cf8b13fc 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -21,7 +21,7 @@ import torch import torch.fx -from cuequivariance import segmented_tensor_product as stp +from cuequivariance import segmented_polynomials as stp logger = logging.getLogger(__name__) From bfb1a856f75327a20258aa5642e8ea1b92f0ae98 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 12:59:51 +0100 Subject: [PATCH 077/107] fix --- .../cuequivariance_torch/operations/symmetric_contraction.py | 2 +- .../tests/operations/symmetric_contraction_test.py | 2 +- .../tests/primitives/symmetric_tensor_product_test.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py index 2ad630a9..d2e6b4f8 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/symmetric_contraction.py @@ -18,7 +18,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.experimental.mace.symmetric_contractions import ( +from cuequivariance.group_theory.experimental.mace.symmetric_contractions import ( symmetric_contraction, ) from cuequivariance.group_theory.irreps_array.misc_ui import ( diff --git a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py index 8a67026c..55b504bd 100644 --- a/cuequivariance_torch/tests/operations/symmetric_contraction_test.py +++ b/cuequivariance_torch/tests/operations/symmetric_contraction_test.py @@ -20,7 +20,7 @@ import cuequivariance as cue import cuequivariance_torch as cuet -from cuequivariance.experimental.e3nn import O3_e3nn +from cuequivariance.group_theory.experimental.e3nn import O3_e3nn from cuequivariance_torch._tests.utils import ( module_with_mode, ) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 72034ca6..272ab53e 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -16,7 +16,7 @@ import torch import cuequivariance as cue -import cuequivariance.segmented_tensor_product as stp +import cuequivariance.segmented_polynomials as stp import cuequivariance_torch as cuet from cuequivariance import descriptors from cuequivariance_torch._tests.utils import ( From 882e1173ce6052e7ed694b699382bb4f86010349 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 4 Mar 2025 13:05:47 +0100 Subject: [PATCH 078/107] fix --- .../primitives/symmetric_tensor_product.py | 2 +- .../cuequivariance_torch/primitives/tensor_product.py | 2 +- docs/api/cuequivariance.descriptors.rst | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py index d4fcaf31..5da21f2c 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/symmetric_tensor_product.py @@ -51,7 +51,7 @@ def __init__( operands_and_subscripts=[(cue.SegmentedOperand.empty_segments(1), "")] + list(d.operands_and_subscripts), paths=[ - cue.segmented_tensor_product.Path( + cue.segmented_polynomials.Path( (0,) + path.indices, path.coefficients ) for path in d.paths diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index cf8b13fc..257e4eaf 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -525,7 +525,7 @@ def __init__( import cuequivariance_ops_torch as ops self._f = ops.FusedTensorProductOp3( - operand_segment_modes=[ope.subscripts for ope in descriptor.operands], + operand_segment_modes=descriptor.subscripts.operands, operand_segment_offsets=[ [s.start for s in ope.segment_slices()] for ope in descriptor.operands ], diff --git a/docs/api/cuequivariance.descriptors.rst b/docs/api/cuequivariance.descriptors.rst index aed06c7c..aa90c7d8 100644 --- a/docs/api/cuequivariance.descriptors.rst +++ b/docs/api/cuequivariance.descriptors.rst @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. -.. module:: cuequivariance.descriptors -.. currentmodule:: cuequivariance.descriptors +.. module:: cuequivariance.group_theory.descriptors +.. 
currentmodule:: cuequivariance.group_theory.descriptors Descriptors =========== From da062a07b065e055207ab31e2f66252387d90365 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 14:41:58 +0100 Subject: [PATCH 079/107] re-organize files --- cuequivariance_jax/cuequivariance_jax/__init__.py | 8 ++++---- .../cuequivariance_jax/{operations => }/activation.py | 0 .../{primitives => }/equivariant_polynomial.py | 0 .../segmented_polynomial.py | 11 ++++------- .../segmented_polynomial_ops_impl.py | 3 ++- .../segmented_polynomial_vanilla_impl.py | 0 .../utils.py} | 0 .../{operations => }/spherical_harmonics.py | 0 .../{primitives => }/equivariant_polynomial_test.py | 0 ...jax_irreps_array_test.py => jax_rep_array_test.py} | 0 .../{primitives => }/segmented_polynomial_test.py | 0 .../{operations => }/spherical_harmonics_test.py | 0 12 files changed, 10 insertions(+), 12 deletions(-) rename cuequivariance_jax/cuequivariance_jax/{operations => }/activation.py (100%) rename cuequivariance_jax/cuequivariance_jax/{primitives => }/equivariant_polynomial.py (100%) rename cuequivariance_jax/cuequivariance_jax/{primitives => segmented_polynomials}/segmented_polynomial.py (98%) rename cuequivariance_jax/cuequivariance_jax/{primitives => segmented_polynomials}/segmented_polynomial_ops_impl.py (98%) rename cuequivariance_jax/cuequivariance_jax/{primitives => segmented_polynomials}/segmented_polynomial_vanilla_impl.py (100%) rename cuequivariance_jax/cuequivariance_jax/{primitives/primitives_utils.py => segmented_polynomials/utils.py} (100%) rename cuequivariance_jax/cuequivariance_jax/{operations => }/spherical_harmonics.py (100%) rename cuequivariance_jax/tests/{primitives => }/equivariant_polynomial_test.py (100%) rename cuequivariance_jax/tests/{irreps_array/jax_irreps_array_test.py => jax_rep_array_test.py} (100%) rename cuequivariance_jax/tests/{primitives => }/segmented_polynomial_test.py (100%) rename cuequivariance_jax/tests/{operations => }/spherical_harmonics_test.py (100%) diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index 83f30bed..ac148444 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -23,16 +23,16 @@ from .rep_array.vmap import vmap from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan -from .primitives.segmented_polynomial import segmented_polynomial -from .primitives.equivariant_polynomial import equivariant_polynomial +from .segmented_polynomials.segmented_polynomial import segmented_polynomial +from .equivariant_polynomial import equivariant_polynomial -from .operations.activation import ( +from .activation import ( normalspace, normalize_function, function_parity, scalar_activation, ) -from .operations.spherical_harmonics import spherical_harmonics, normalize, norm +from .spherical_harmonics import spherical_harmonics, normalize, norm from cuequivariance_jax import flax_linen diff --git a/cuequivariance_jax/cuequivariance_jax/operations/activation.py b/cuequivariance_jax/cuequivariance_jax/activation.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/operations/activation.py rename to cuequivariance_jax/cuequivariance_jax/activation.py diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py similarity index 100% rename from 
cuequivariance_jax/cuequivariance_jax/primitives/equivariant_polynomial.py rename to cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py similarity index 98% rename from cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index dc4ba3a1..17e0ba43 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -23,13 +23,10 @@ from jax.interpreters import ad, batching, mlir, partial_eval, xla import cuequivariance as cue -from cuequivariance_jax.primitives.primitives_utils import reshape -from cuequivariance_jax.primitives.segmented_polynomial_ops_impl import ( - segmented_polynomial_ops_impl, -) -from cuequivariance_jax.primitives.segmented_polynomial_vanilla_impl import ( - segmented_polynomial_vanilla_impl, -) + +from .segmented_polynomial_ops_impl import segmented_polynomial_ops_impl +from .segmented_polynomial_vanilla_impl import segmented_polynomial_vanilla_impl +from .utils import reshape logger = logging.getLogger(__name__) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py similarity index 98% rename from cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py index c874d8b5..fa76f64a 100644 --- a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py @@ -19,7 +19,8 @@ import jax.numpy as jnp import cuequivariance as cue -from cuequivariance_jax.primitives.primitives_utils import reshape + +from .utils import reshape logger = logging.getLogger(__name__) diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/primitives/segmented_polynomial_vanilla_impl.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py diff --git a/cuequivariance_jax/cuequivariance_jax/primitives/primitives_utils.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/primitives/primitives_utils.py rename to cuequivariance_jax/cuequivariance_jax/segmented_polynomials/utils.py diff --git a/cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/operations/spherical_harmonics.py rename to cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py diff --git a/cuequivariance_jax/tests/primitives/equivariant_polynomial_test.py b/cuequivariance_jax/tests/equivariant_polynomial_test.py similarity index 100% rename from cuequivariance_jax/tests/primitives/equivariant_polynomial_test.py rename to 
cuequivariance_jax/tests/equivariant_polynomial_test.py diff --git a/cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py b/cuequivariance_jax/tests/jax_rep_array_test.py similarity index 100% rename from cuequivariance_jax/tests/irreps_array/jax_irreps_array_test.py rename to cuequivariance_jax/tests/jax_rep_array_test.py diff --git a/cuequivariance_jax/tests/primitives/segmented_polynomial_test.py b/cuequivariance_jax/tests/segmented_polynomial_test.py similarity index 100% rename from cuequivariance_jax/tests/primitives/segmented_polynomial_test.py rename to cuequivariance_jax/tests/segmented_polynomial_test.py diff --git a/cuequivariance_jax/tests/operations/spherical_harmonics_test.py b/cuequivariance_jax/tests/spherical_harmonics_test.py similarity index 100% rename from cuequivariance_jax/tests/operations/spherical_harmonics_test.py rename to cuequivariance_jax/tests/spherical_harmonics_test.py From 817644f118798df15d759e82a82ee8c50ba238e4 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 15:05:17 +0100 Subject: [PATCH 080/107] add tests --- .../cuequivariance_jax/activation.py | 4 + .../cuequivariance_jax/spherical_harmonics.py | 20 +++- cuequivariance_jax/tests/activation_test.py | 96 +++++++++++++++++++ .../tests/spherical_harmonics_test.py | 32 +++++++ 4 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 cuequivariance_jax/tests/activation_test.py diff --git a/cuequivariance_jax/cuequivariance_jax/activation.py b/cuequivariance_jax/cuequivariance_jax/activation.py index a171ce97..d919a1bb 100644 --- a/cuequivariance_jax/cuequivariance_jax/activation.py +++ b/cuequivariance_jax/cuequivariance_jax/activation.py @@ -74,6 +74,10 @@ def rho(x): def function_parity(phi: ActFn) -> int: + r"""Determine the parity of a function. + + For example, sin is odd, cos is even, and exp is neither. + """ with jax.ensure_compile_time_eval(): x = jnp.linspace(0.0, 10.0, 256) diff --git a/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py b/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py index 8b4f254d..3da4c16f 100644 --- a/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py +++ b/cuequivariance_jax/cuequivariance_jax/spherical_harmonics.py @@ -91,7 +91,25 @@ def f(x: jax.Array) -> jax.Array: def norm(array: cuex.RepArray, *, squared: bool = False) -> cuex.RepArray: - """Norm of `RepArray`.""" + """Compute the norm of a `RepArray`. + + This function calculates the norm for each element in the irreps array by summing + the squared magnitudes of the elements along the irrep dimension. By default, + the function returns the square root of this sum (the regular norm), but it can + also return the squared norm if requested. 
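+ + For example, the norm of the vector ``[3.0, 4.0, 0.0]`` is ``5.0``, or ``25.0`` when ``squared=True``.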
+ + When the squared norm is zero, the function handles this special case: + - If squared=True, it returns 0.0 + - If squared=False, it safely computes the square root and returns 0.0 + + Args: + array: The equivariant array (RepArray) whose norm should be calculated + squared: If True, returns the squared norm; if False (default), returns the regular norm + + Returns: + A new RepArray with trivial irreps where each element represents the norm + (or squared norm) of the corresponding element in the input array + """ assert array.is_irreps_array() match array.layout: diff --git a/cuequivariance_jax/tests/activation_test.py b/cuequivariance_jax/tests/activation_test.py new file mode 100644 index 00000000..c3ad9d8a --- /dev/null +++ b/cuequivariance_jax/tests/activation_test.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import jax +import jax.numpy as jnp +import pytest + +import cuequivariance as cue +import cuequivariance_jax as cuex +from cuequivariance_jax.activation import ( + function_parity, + normalize_function, + normalspace, + scalar_activation, +) + + +def test_normalspace(): + """Test the normalspace function.""" + n = 100 + points = normalspace(n) + + assert points.shape == (n,) + assert jnp.all(jnp.diff(points) > 0) # Points in ascending order + assert jnp.isclose(jnp.mean(points), 0.0, atol=1e-6) + + +def test_normalize_function(): + """Test the normalize_function.""" + # Test with constant and non-constant functions + norm_const = normalize_function(lambda x: jnp.ones_like(x)) + norm_linear = normalize_function(lambda x: x) + + test_points = normalspace(1001) + + # Check normalization + assert jnp.isclose(jnp.mean(norm_const(test_points) ** 2), 1.0, atol=1e-2) + assert jnp.isclose(jnp.mean(norm_linear(test_points) ** 2), 1.0, atol=5e-2) + + # Test zero function (should raise ValueError) + with pytest.raises(ValueError): + normalize_function(lambda x: jnp.zeros_like(x)) + + +def test_function_parity(): + """Test the function_parity function.""" + # Test even, odd, and neither functions + assert function_parity(jnp.cos) == 1 # Even + assert function_parity(jnp.sin) == -1 # Odd + assert function_parity(jnp.exp) == 0 # Neither + + +@cue.assume("SO3", cue.ir_mul) +def test_scalar_activation(): + """Test scalar_activation function.""" + # Create test data + irreps = cue.Irreps("SO3", "2x0 + 0") + x = cuex.randn(jax.random.key(42), irreps, (5,)) + + # Test with a single activation + y = scalar_activation(x, lambda x: 2 * x) + assert y.irreps == x.irreps + assert y.shape == x.shape + + # Test with multiple activations + y = scalar_activation(x, [jnp.sin, jnp.cos]) + assert y.irreps == x.irreps + + # Test with a dict of activations + y = scalar_activation(x, {cue.SO3(0): jnp.sin}) + assert y.shape == x.shape + + # Test with non-scalar irreps + irreps_with_vectors = cue.Irreps("SO3", "0 + 1") + x_with_vectors = 
cuex.randn(jax.random.key(43), irreps_with_vectors, (5,)) + + # Should assert when trying to apply activation to non-scalar + with pytest.raises(AssertionError): + scalar_activation(x_with_vectors, jnp.sin) + + # Should work with None for non-scalar components + y = scalar_activation(x_with_vectors, [jnp.sin, None]) + assert y.irreps == x_with_vectors.irreps diff --git a/cuequivariance_jax/tests/spherical_harmonics_test.py b/cuequivariance_jax/tests/spherical_harmonics_test.py index 68305557..416fe383 100644 --- a/cuequivariance_jax/tests/spherical_harmonics_test.py +++ b/cuequivariance_jax/tests/spherical_harmonics_test.py @@ -36,3 +36,35 @@ def test_spherical_harmonics(shape): y = cuex.spherical_harmonics([0, 1, 2], x) assert y.shape == shape + (9,) assert y.irreps == cue.Irreps(cue.O3, "0e + 1o + 2e") + + +@pytest.mark.parametrize("squared", [True, False]) +def test_norm(squared): + """Test the norm function with a batch of vectors.""" + # Create a RepArray with known values + irreps = cue.Irreps(cue.O3, "1o") + shape = (10,) + + # Create batch data with specific values in first two positions + data = np.zeros(shape + (3,)) + data[0] = [3.0, 4.0, 0.0] # norm = 5.0 + data[1] = [0.0, 0.0, 0.0] # norm = 0.0 + data[2:] = np.random.randn(shape[0] - 2, 3) # random values + + array = cuex.RepArray(irreps, data, cue.ir_mul) + + # Calculate norm + norm_result = cuex.norm(array, squared=squared) + + # Verify basic properties + assert norm_result.irreps.dim == 1 + for _, ir in norm_result.irreps: + assert ir.is_trivial() + assert norm_result.shape == shape + (1,) + + # Verify specific test values + expected_first = 25.0 if squared else 5.0 # norm of [3,4,0] + expected_second = 0.0 # norm of [0,0,0] + + np.testing.assert_allclose(norm_result.array[0, 0], expected_first, rtol=1e-6) + np.testing.assert_allclose(norm_result.array[1, 0], expected_second, rtol=1e-6) From 5626756f9a82ec381e01a410f17afe86cad01e2d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 15:39:53 -0800 Subject: [PATCH 081/107] fix --- .python-version | 1 + cuequivariance_jax/cuequivariance_jax/__init__.py | 4 ++-- .../rep_array/{jax_rep_array.py => rep_array_.py} | 0 .../rep_array/{utils.py => rep_array_utils.py} | 0 .../segmented_polynomials/segmented_polynomial.py | 11 +++++++---- .../segmented_polynomial_ops_impl.py | 3 +-- 6 files changed, 11 insertions(+), 8 deletions(-) create mode 100644 .python-version rename cuequivariance_jax/cuequivariance_jax/rep_array/{jax_rep_array.py => rep_array_.py} (100%) rename cuequivariance_jax/cuequivariance_jax/rep_array/{utils.py => rep_array_utils.py} (100%) diff --git a/.python-version b/.python-version new file mode 100644 index 00000000..24ee5b1b --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.13 diff --git a/cuequivariance_jax/cuequivariance_jax/__init__.py b/cuequivariance_jax/cuequivariance_jax/__init__.py index ac148444..8e577cd0 100644 --- a/cuequivariance_jax/cuequivariance_jax/__init__.py +++ b/cuequivariance_jax/cuequivariance_jax/__init__.py @@ -19,9 +19,9 @@ ) -from .rep_array.jax_rep_array import RepArray, from_segments +from .rep_array.rep_array_ import RepArray, from_segments from .rep_array.vmap import vmap -from .rep_array.utils import concatenate, randn, as_irreps_array, clebsch_gordan +from .rep_array.rep_array_utils import concatenate, randn, as_irreps_array, clebsch_gordan from .segmented_polynomials.segmented_polynomial import segmented_polynomial from .equivariant_polynomial import equivariant_polynomial diff --git 
a/cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/rep_array/jax_rep_array.py rename to cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_.py diff --git a/cuequivariance_jax/cuequivariance_jax/rep_array/utils.py b/cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_utils.py similarity index 100% rename from cuequivariance_jax/cuequivariance_jax/rep_array/utils.py rename to cuequivariance_jax/cuequivariance_jax/rep_array/rep_array_utils.py diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index 17e0ba43..2cb49de7 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -23,10 +23,13 @@ from jax.interpreters import ad, batching, mlir, partial_eval, xla import cuequivariance as cue - -from .segmented_polynomial_ops_impl import segmented_polynomial_ops_impl -from .segmented_polynomial_vanilla_impl import segmented_polynomial_vanilla_impl -from .utils import reshape +from cuequivariance_jax.segmented_polynomials.segmented_polynomial_ops_impl import ( + segmented_polynomial_ops_impl, +) +from cuequivariance_jax.segmented_polynomials.segmented_polynomial_vanilla_impl import ( + segmented_polynomial_vanilla_impl, +) +from cuequivariance_jax.segmented_polynomials.utils import reshape logger = logging.getLogger(__name__) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py index fa76f64a..19e69ec1 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py @@ -19,8 +19,7 @@ import jax.numpy as jnp import cuequivariance as cue - -from .utils import reshape +from cuequivariance_jax.segmented_polynomials.utils import reshape logger = logging.getLogger(__name__) From 250e20e16c5406afdb783b2b657ee08154c09df4 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 15:41:19 -0800 Subject: [PATCH 082/107] change changelog --- CHANGELOG.md | 59 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1cd7d28c..8fabaa16 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,23 +5,36 @@ - Rename `SegmentedTensorProduct.memory_cost` to `memory` - Removed `IrrepsArray` in favor of `RepArray` -## 0.3.0-rc1 + +## 0.3.0 (2025-03-05) + +The main changes are: +1. [JAX] New JIT Uniform 1d kernel with JAX bindings + 1. Computes any polynomial based on 1d uniform STPs + 2. Supports arbitrary derivatives + 3. Provides optional fused scatter/gather for the inputs and outputs + 4. 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉 +2. [Torch] Adds torch.compile support +3. [Torch] Adds limited beta Torch bindings to the new JIT Uniform 1d kernel (see tutorial in the documentation) +4. 
[Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (see tutorial in the documentation) ### Breaking Changes -- In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed `math_dtype` and `output_dtype` respectively. Adding consistency with the rest of the library. -- In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` are removed. This is to avoid a proliferation of arguments that are not used in all the implementations. An argument `impl: str` is added instead to select the implementation. -- Removed `cue.TensorProductExecution` and added instead `cue.Operation` that is more lightweight and aligned with the backend. -- Removed `cuex.symmetric_tensor_product`. `cuex.tensor_product` is now able to handle any non-homogeneous polynomials. +- In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library. +- In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation. +- Removed `cue.TensorProductExecution` and added `cue.Operation` which is more lightweight and better aligned with the backend. +- Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomial. - Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptor.linear` together with `cuex.equivariant_tensor_product` instead. -- The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is currently limited. -- The interface of `cuex.tensor_product` is changed. Now it takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This allows `cuex.tensor_product` to execute any sort of non-homogeneous polynomials. +- The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases. +- The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomial. ### Fixed -- Identified bug in CUDA kernel, disable CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`. -- `cue.descriptor.full_tensor_product` was ignoring the `irreps3_filter` argument. +- Identified a bug in the CUDA kernel and disabled the CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`. +- Fixed `cue.descriptor.full_tensor_product` which was ignoring the `irreps3_filter` argument. +- Fixed a rare bug with `np.bincount` when using an old version of numpy. The input is now flattened to make it work with all versions. ### Added -- JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) have only one index. It handles batched/shared/indexed input/output. The indexed input/output are handled by atomic operations. +- Added JAX bindings to the uniform 1d JIT kernel. 
This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations. +- Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion. - Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product. - Added a uniform 1d kernel with scatter/gather fusion under `cuet.primitives.tensor_product.TensorProductUniform4x1dIndexed` and `cuet.primitives.tensor_product.TensorProductUniform3x1dIndexed`. @@ -30,28 +43,28 @@ ### Breaking Changes -- Minimal python version is now 3.10 in all packages. -- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of dimension `(dim,)` are no more allowed. -- `cuex.IrrepsArray` is an alias for `cuex.RepArray`. -- `cuex.RepArray.irreps` and `cuex.RepArray.segments` are not functions anymore. They are now properties. -- `cuex.IrrepsArray.is_simple` is replaced by `cuex.RepArray.is_irreps_array`. -- The function `cuet.spherical_harmonics` is replaced by the Torch Module `cuet.SphericalHarmonics`. This was done to allow the use of `torch.jit.script` and `torch.compile`. +- Minimal Python version is now 3.10 in all packages. +- `cuet.TensorProduct` and `cuet.EquivariantTensorProduct` now require inputs to be of shape `(batch_size, dim)` or `(1, dim)`. Inputs of dimension `(dim,)` are no longer allowed. +- `cuex.IrrepsArray` is now an alias for `cuex.RepArray`. +- `cuex.RepArray.irreps` and `cuex.RepArray.segments` are no longer functions. They are now properties. +- `cuex.IrrepsArray.is_simple` has been replaced by `cuex.RepArray.is_irreps_array`. +- The function `cuet.spherical_harmonics` has been replaced by the Torch Module `cuet.SphericalHarmonics`. This change enables the use of `torch.jit.script` and `torch.compile`. ### Added -- Add an experimental support for `torch.compile`. Known issue: the export in c++ is not working. -- Add `cue.IrrepsAndLayout`: A simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`. -- Add `cuex.RepArray` for representing an array of any kind of representations (not only irreps like before with `cuex.IrrepsArray`). +- Added experimental support for `torch.compile`. Known issue: the export in C++ is not working. +- Added `cue.IrrepsAndLayout`: A simple class that inherits from `cue.Rep` and contains a `cue.Irreps` and a `cue.IrrepsLayout`. +- Added `cuex.RepArray` for representing an array of any kind of representations (not only irreps as was previously possible with `cuex.IrrepsArray`). ### Fixed -- Add support for empty batch dimension in `cuet` (`cuequivariance_torch`). -- Move `README.md` and `LICENSE` into the source distribution. -- Fix `cue.SegmentedTensorProduct.flop_cost` for the special case of 1 operand. +- Added support for empty batch dimension in `cuet` (`cuequivariance_torch`). +- Moved `README.md` and `LICENSE` into the source distribution. +- Fixed `cue.SegmentedTensorProduct.flop_cost` for the special case of 1 operand. ### Improved -- No more special case for degree 0 in `cuet.SymmetricTensorProduct`. +- Removed special case handling for degree 0 in `cuet.SymmetricTensorProduct`. 
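+
+For illustration, the 0.2.0 input-shape rule above at a call site (a minimal sketch; `irreps` stands for any `cue.Irreps` instance):
+
+```python
+x = torch.randn(32, irreps.dim)  # accepted: explicit batch axis, shape (batch_size, dim)
+x = torch.randn(1, irreps.dim)   # accepted: shared input, shape (1, dim)
+x = torch.randn(irreps.dim)      # no longer accepted: shape (dim,) without a batch axis
+```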
## 0.1.0 (2024-11-18) From dc5a260def23ea35565fd47edaf5bce4957b1297 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 15:48:03 -0800 Subject: [PATCH 083/107] test jax on self-hosted --- .github/workflows/tests.yml | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 48aaca35..21c95c25 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,12 +33,16 @@ jobs: pytest --doctest-modules -x cuequivariance cuequivariance-jax: - - runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.10", "3.12"] + include: + - runner: "ubuntu-latest" + python-version: "3.10" + - runner: "self-hosted" + python-version: "3.12" + + runs-on: ${{ matrix.runner }} steps: - uses: actions/checkout@v4 @@ -50,9 +54,9 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install --upgrade uv - python -m uv pip install pytest - python -m uv pip install ./cuequivariance - python -m uv pip install ./cuequivariance_jax + python -m uv pip install pytest "jax[cuda12]" + python -m uv pip install ./cuequivariance --force-reinstall + python -m uv pip install ./cuequivariance_jax --force-reinstall - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax From c9b0c85f1d3cea4c7888af7989d1d590b4870b6f Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 15:57:28 -0800 Subject: [PATCH 084/107] try --- .github/workflows/tests.yml | 2 ++ cuequivariance/cuequivariance/__init__.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 21c95c25..82b10993 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -55,6 +55,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest "jax[cuda12]" + python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch python -m uv pip install ./cuequivariance --force-reinstall python -m uv pip install ./cuequivariance_jax --force-reinstall - name: Test with pytest @@ -84,6 +85,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn + python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch python -m uv pip install ./cuequivariance --force-reinstall python -m uv pip install ./cuequivariance_torch --force-reinstall - name: Test with pytest diff --git a/cuequivariance/cuequivariance/__init__.py b/cuequivariance/cuequivariance/__init__.py index cfc39818..db540153 100644 --- a/cuequivariance/cuequivariance/__init__.py +++ b/cuequivariance/cuequivariance/__init__.py @@ -53,8 +53,9 @@ EquivariantTensorProduct, # deprecated ) -from cuequivariance.group_theory import descriptors -from cuequivariance import segmented_polynomials, group_theory +from cuequivariance import segmented_polynomials as segmented_polynomials +from cuequivariance import group_theory as group_theory +from cuequivariance.group_theory import descriptors as descriptors __all__ = [ "__version__", @@ -87,7 +88,7 @@ "reduced_antisymmetric_tensor_product_basis", "EquivariantPolynomial", "EquivariantTensorProduct", - "descriptors", "segmented_polynomials", "group_theory", + "descriptors", ] From e3738b83c1086ba84ab1922bb7916a3666a5e685 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 16:08:01 -0800 Subject: [PATCH 085/107] fix --- 
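Note: after the operand refactor, `SegmentedOperand` no longer carries its subscripts, so `__hash__` and `__eq__` must compare `operands_and_subscripts` rather than the bare operands; otherwise two products differing only in their subscripts could compare equal. A minimal sketch of the restored invariant (assuming `from_subscripts`, as used elsewhere in the library, builds an empty product):

    import cuequivariance as cue

    d1 = cue.SegmentedTensorProduct.from_subscripts("uv,uv,uv")
    d2 = cue.SegmentedTensorProduct.from_subscripts("uv,vu,uv")
    # same operand shapes (ndim=2, no segments), different mode wiring
    assert d1 != d2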
.../segmented_polynomials/segmented_tensor_product.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index 5e154e4b..776fa40f 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -222,12 +222,14 @@ def from_base64(cls, data: str) -> SegmentedTensorProduct: ################################ Properties ################################ def __hash__(self) -> int: - return hash((self.operands, self.paths, self.coefficient_subscripts)) + return hash( + (self.operands_and_subscripts, self.paths, self.coefficient_subscripts) + ) def __eq__(self, value: SegmentedTensorProduct) -> bool: assert isinstance(value, SegmentedTensorProduct) return ( - self.operands == value.operands + self.operands_and_subscripts == value.operands_and_subscripts and self.paths == value.paths and self.coefficient_subscripts == value.coefficient_subscripts ) From 03e912bced7f5c0332869134a2a234d0bef877a8 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Wed, 5 Mar 2025 16:11:29 -0800 Subject: [PATCH 086/107] fix --- .github/workflows/tests.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 82b10993..0db7b93a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,8 +56,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest "jax[cuda12]" python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance --force-reinstall - python -m uv pip install ./cuequivariance_jax --force-reinstall + python -m uv pip install ./cuequivariance + python -m uv pip install ./cuequivariance_jax - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax @@ -86,8 +86,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance --force-reinstall - python -m uv pip install ./cuequivariance_torch --force-reinstall + python -m uv pip install ./cuequivariance + python -m uv pip install ./cuequivariance_torch - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch From 493d011c4772d1f106ddabad2618b2b115c7e106 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 12:13:13 -0700 Subject: [PATCH 087/107] improve constructors of SegmentedOperand, SegmentedTensorProduct and SegmentedPolynomial --- .../segmented_operand.py | 11 +++++--- .../segmented_polynomial.py | 8 ++++-- .../segmented_tensor_product.py | 28 ++++++++++--------- 3 files changed, 27 insertions(+), 20 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py index 7e146139..1a6644df 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_operand.py @@ -30,17 +30,20 @@ class SegmentedOperand: def __init__( self, - *, - ndim: int, segments: list[tuple[int, ...]] | None = None, + *, + ndim: int | None = None, _dims: dict[int, set[int]] | None = None, ): - 
object.__setattr__(self, "ndim", ndim) - if segments is None: segments = [] object.__setattr__(self, "segments", tuple(segments)) + if ndim is None: + assert len(self.segments) > 0 + ndim = len(self.segments[0]) + object.__setattr__(self, "ndim", ndim) + if _dims is None: _dims = dict() for segment in self.segments: diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 3e67b8ee..81f6e4b0 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -50,7 +50,9 @@ def __init__( outputs: tuple[cue.SegmentedOperand, ...], tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], ): - buffers = list(inputs) + list(outputs) + inputs = tuple(inputs) + outputs = tuple(outputs) + buffers = inputs + outputs for ope, stp in tensor_products: assert isinstance(ope, cue.Operation) @@ -59,8 +61,8 @@ def __init__( for buffer_id, operand in zip(ope.buffers, stp.operands): assert operand == buffers[buffer_id] - object.__setattr__(self, "inputs", tuple(inputs)) - object.__setattr__(self, "outputs", tuple(outputs)) + object.__setattr__(self, "inputs", inputs) + object.__setattr__(self, "outputs", outputs) object.__setattr__(self, "tensor_products", tuple(sorted(tensor_products))) @classmethod diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index 776fa40f..e829d4f1 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -60,35 +60,37 @@ class SegmentedTensorProduct: .. rubric:: Methods """ - operands_and_subscripts: tuple[tuple[cue.SegmentedOperand, stp.Subscripts], ...] - paths: tuple[stp.Path, ...] + operands_and_subscripts: tuple[tuple[cue.SegmentedOperand, str], ...] coefficient_subscripts: str + paths: tuple[stp.Path, ...] 
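+
+    # Each SegmentedOperand is stored together with its subscripts string;
+    # coefficient_subscripts labels the coefficient tensor carried by each path.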
################################ Initializers ################################ # From here we can use object.__setattr__ to modify the attributes def __init__( self, - *, - operands_and_subscripts: Sequence[tuple[cue.SegmentedOperand, stp.Subscripts]] + operands_and_subscripts: Sequence[tuple[cue.SegmentedOperand | None, str]] | None = None, - paths: Sequence[stp.Path] | None = None, coefficient_subscripts: str = "", + *, + paths: Sequence[stp.Path] | None = None, ): if operands_and_subscripts is None: operands_and_subscripts = [] if paths is None: paths = [] - object.__setattr__( - self, - "operands_and_subscripts", - tuple( - (ope.copy(), stp.Subscripts(ss)) for ope, ss in operands_and_subscripts - ), + operands_and_subscripts = tuple( + ( + ope.copy() if ope is not None else cue.SegmentedOperand(ndim=len(ss)), + str(ss), + ) + for ope, ss in operands_and_subscripts ) - object.__setattr__(self, "paths", tuple(paths)) + + object.__setattr__(self, "operands_and_subscripts", operands_and_subscripts) object.__setattr__(self, "coefficient_subscripts", coefficient_subscripts) + object.__setattr__(self, "paths", tuple(paths)) def set_operand(self, oid: int, operand: cue.SegmentedOperand): assert oid < len(self.operands_and_subscripts) @@ -96,7 +98,7 @@ def set_operand(self, oid: int, operand: cue.SegmentedOperand): self, "operands_and_subscripts", self.operands_and_subscripts[:oid] - + ((copy.deepcopy(operand), self.operands_and_subscripts[oid][1]),) + + ((operand.copy(), self.operands_and_subscripts[oid][1]),) + self.operands_and_subscripts[oid + 1 :], ) From 2e41831b0626f99e224bcc7229d16d5aa112817d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 13:06:30 -0700 Subject: [PATCH 088/107] add SegmentedPolynomial.from_stps --- .../segmented_polynomials/operation.py | 4 +- .../segmented_polynomial.py | 52 +++++++++++++++---- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py index f9851c06..fcb22be5 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/operation.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/operation.py @@ -47,7 +47,9 @@ class Operation: buffers: tuple[int, ...] - def __init__(self, buffers: tuple[int, ...]): + def __init__(self, buffers: tuple[int, ...] 
| Operation): + if isinstance(buffers, Operation): + buffers = buffers.buffers assert len(buffers) > 0, buffers assert all(isinstance(b, int) for b in buffers), buffers assert all(i >= 0 for i in buffers), buffers diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 81f6e4b0..f7fa0224 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -46,24 +46,30 @@ class SegmentedPolynomial: def __init__( self, - inputs: tuple[cue.SegmentedOperand, ...], - outputs: tuple[cue.SegmentedOperand, ...], - tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], + inputs: Sequence[cue.SegmentedOperand], + outputs: Sequence[cue.SegmentedOperand], + tensor_products: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], ): inputs = tuple(inputs) outputs = tuple(outputs) buffers = inputs + outputs + _tensor_products = [] for ope, stp in tensor_products: + ope = cue.Operation(ope) assert isinstance(ope, cue.Operation) assert isinstance(stp, cue.SegmentedTensorProduct) assert len(ope.buffers) == stp.num_operands for buffer_id, operand in zip(ope.buffers, stp.operands): assert operand == buffers[buffer_id] + _tensor_products.append((ope, stp)) + _tensor_products = sorted(_tensor_products) object.__setattr__(self, "inputs", inputs) object.__setattr__(self, "outputs", outputs) - object.__setattr__(self, "tensor_products", tuple(sorted(tensor_products))) + object.__setattr__(self, "tensor_products", tuple(_tensor_products)) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): @@ -76,13 +82,15 @@ def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): @classmethod def from_default_buffers( cls, - inputs: tuple[cue.SegmentedOperand, ...], - outputs: tuple[cue.SegmentedOperand, ...], - tensor_products: Sequence[tuple[cue.Operation, cue.SegmentedTensorProduct]], + inputs: Sequence[cue.SegmentedOperand | None], + outputs: Sequence[cue.SegmentedOperand | None], + tensor_products: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], ): buffers = list(inputs) + list(outputs) for ope, stp in tensor_products: - assert isinstance(ope, cue.Operation) + ope = cue.Operation(ope) assert isinstance(stp, cue.SegmentedTensorProduct) assert len(ope.buffers) == stp.num_operands for buffer_id, operand in zip(ope.buffers, stp.operands): @@ -90,6 +98,25 @@ def from_default_buffers( return cls(buffers[: len(inputs)], buffers[len(inputs) :], tensor_products) + @classmethod + def from_stps( + cls, + inputs: Sequence[cue.SegmentedOperand | None], + outputs: Sequence[cue.SegmentedOperand | None], + tensor_products: Sequence[ + tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] + ], + ) -> SegmentedPolynomial: + """Stack segmented tensor products together.""" + inputs, outputs = list(inputs), list(outputs) + return cls.stack( + [ + cls.from_default_buffers(inputs, outputs, [(ope, stp)]) + for ope, stp in tensor_products + ], + [ope is None for ope in inputs + outputs], + ) + def __hash__(self) -> int: return hash((self.inputs, self.outputs, self.tensor_products)) @@ -364,7 +391,14 @@ def stack( for bid in range(num_inputs + num_outputs): if stacked[bid]: operands.append( - cue.SegmentedOperand.stack([pol.operands[bid] for pol in polys]) + cue.SegmentedOperand.stack( + [ + pol.operands[bid] + 
for pol in polys + if pol.operands[bid] + is not None # special case for .from_stps + ] + ) ) else: ope = polys[0].operands[bid] From 82ec484bc6d041a54018748721b07f86ca85acd5 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 13:36:13 -0700 Subject: [PATCH 089/107] improve repr of SegmentedPolynomial --- .../segmented_polynomials/operation.py | 12 +++-- .../segmented_polynomial.py | 50 ++++++++++--------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/operation.py b/cuequivariance/cuequivariance/segmented_polynomials/operation.py index fcb22be5..f26c724e 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/operation.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/operation.py @@ -58,8 +58,8 @@ def __init__(self, buffers: tuple[int, ...] | Operation): def __repr__(self): return f"Operation({self.buffers})" - def to_string(self, num_inputs: int) -> str: - return " ".join(IVARS[b] if b < num_inputs else OVARS[b] for b in self.buffers) + def to_letters(self, num_inputs: int) -> list[str]: + return [IVARS[b] if b < num_inputs else OVARS[b] for b in self.buffers] @staticmethod def list_to_string( @@ -69,7 +69,7 @@ def list_to_string( o = ", ".join(OVARS[num_inputs : num_inputs + num_outputs]) s = f"({i}) -> ({o})" for op in operations: - s += "\n " + op.to_string(num_inputs) + s += "\n " + " ".join(op.to_letters(num_inputs)) return s def __lt__(self, value): @@ -86,6 +86,12 @@ def __eq__(self, value): def permute_operands(self, permutation: tuple[int, ...]) -> Operation: return Operation(tuple(self.buffers[p] for p in permutation)) + def move_operand_last(self, operand: int) -> Operation: + buffers = list(self.buffers) + b = buffers.pop(operand) + buffers.append(b) + return Operation(tuple(buffers)) + def input_operands_buffers(self, num_inputs: int) -> list[tuple[int, int]]: return [(op, i) for op, i in enumerate(self.buffers) if i < num_inputs] diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index f7fa0224..fd09ec2e 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -24,6 +24,8 @@ import cuequivariance as cue from cuequivariance.segmented_polynomials.operation import IVARS, OVARS +from .dimensions_dict import format_dimensions_dict + @dataclasses.dataclass(init=False, frozen=True) class SegmentedPolynomial: @@ -54,7 +56,7 @@ def __init__( ): inputs = tuple(inputs) outputs = tuple(outputs) - buffers = inputs + outputs + operands = inputs + outputs _tensor_products = [] for ope, stp in tensor_products: @@ -63,8 +65,12 @@ def __init__( assert isinstance(stp, cue.SegmentedTensorProduct) assert len(ope.buffers) == stp.num_operands for buffer_id, operand in zip(ope.buffers, stp.operands): - assert operand == buffers[buffer_id] - _tensor_products.append((ope, stp)) + assert operand == operands[buffer_id] + + out_oid, _ = ope.output_operand_buffer(len(inputs)) + _tensor_products.append( + (ope.move_operand_last(out_oid), stp.move_operand_last(out_oid)) + ) _tensor_products = sorted(_tensor_products) object.__setattr__(self, "inputs", inputs) @@ -170,33 +176,29 @@ def to_string(self, buffer_names: list[str] | None = None) -> str: buffer_txts[self.num_inputs : self.num_inputs + self.num_outputs] ) ) - lines = [ - "│ " + ope.to_string(self.num_inputs) for 
ope, _ in self.tensor_products - ] + + def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: + items = [ + f"{buffer}[{ss}]" + for buffer, ss in zip( + ope.to_letters(self.num_inputs), stp.subscripts.operands + ) + ] + return "·".join(items[:-1]) + "➜" + items[-1] + + lines = ["│ " + f(ope, stp) for ope, stp in self.tensor_products] if len(lines) > 0: lines[-1] = "╰─" + lines[-1][2:] - n = max(len(line) for line in lines) + n = max(len(line) for line in lines) lines = [ - line + " " + "─" * (n - len(line)) + "─ " + str(stp) + line + + " " + + "─" * (n - len(line)) + + "─ " + + f"num_paths={stp.num_paths} {format_dimensions_dict(stp.get_dimensions_dict())}" for line, (_, stp) in zip(lines, self.tensor_products) ] - - modes = sorted( - {mode for _, stp in self.tensor_products for mode in stp.subscripts.modes()} - ) - if len(modes) > 1: - modes = [] - for a in ["sizes=", "num_segments=", "num_paths="] + [f"{m}=" for m in modes]: - if not all(line.count(a) == 1 for line in lines): - continue - - splits = [line.split(a) for line in lines] - n = max(len(before) for before, _ in splits) - lines = [ - before + " " * (n - len(before)) + a + after for before, after in splits - ] - lines = ["╭ " + header] + lines return "\n".join(lines) From 247ea9e26b228bf84e01519d94055d7c5159f019 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 13:43:22 -0700 Subject: [PATCH 090/107] fix --- .../segmented_polynomials/segmented_polynomial.py | 1 + .../cuequivariance_jax/equivariant_polynomial.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index fd09ec2e..15061006 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -201,6 +201,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: ] lines = ["╭ " + header] + lines + lines = [line.rstrip() for line in lines] return "\n".join(lines) def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py index ec23efe6..9d854c4f 100644 --- a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py @@ -56,9 +56,9 @@ def equivariant_polynomial( >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) >>> e ╭ a=1 -> B=0+1+2 - │ B ───── sizes=9 num_segments=9 num_paths=1 - │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 - ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=11 + │ ➜B[] ──────── num_paths=1 + │ a[]➜B[] ───── num_paths=3 + ╰─ a[]·a[]➜B[] ─ num_paths=11 Basic usage with single input: From 01f1c986d5aa341318d25cba2c4ba48e8d6fb2e3 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 13:49:18 -0700 Subject: [PATCH 091/107] fix --- .../cuequivariance/group_theory/descriptors/irreps_tp.py | 2 +- .../group_theory/descriptors/spherical_harmonics_.py | 6 +++--- .../group_theory/descriptors/symmetric_contractions.py | 6 +++--- .../segmented_polynomials/segmented_polynomial.py | 4 +++- 4 files changed, 10 insertions(+), 8 deletions(-) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 
e2768310..c6b3ba24 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -48,7 +48,7 @@ def fully_connected_tensor_product( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... ) ╭ a=61440x0 b=16x0+16x1+16x2 c=16x0+16x1+16x2 -> D=16x0+16x1+16x2 - ╰─ a b c D ─ uvw,iu,jv,kw+ijk sizes=61440,144,144,144 num_segments=15,3,3,3 num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 + ╰─ a[uvw]·b[iu]·c[jv]·[ijk]➜D[kw] ─ num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. """ diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index fe7935b9..bd53bc73 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -39,9 +39,9 @@ def spherical_harmonics( Example: >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) ╭ a=1 -> B=0+1+2 - │ B ───── sizes=9 num_segments=9 num_paths=1 - │ a B ─── , sizes=3,9 num_segments=3,9 num_paths=3 - ╰─ a a B ─ ,, sizes=3,3,9 num_segments=3,3,9 num_paths=11 + │ []➜B[] ───────── num_paths=1 + │ a[]·[]➜B[] ───── num_paths=3 + ╰─ a[]·a[]·[]➜B[] ─ num_paths=11 """ if len(ls) != 1: return cue.EquivariantPolynomial.stack( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index dff2873f..673421d5 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -43,9 +43,9 @@ def symmetric_contraction( ... [1, 2, 3] ... ) ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 - │ a b C ───── u,u,u sizes=288,144,64 num_segments=18,9,4 num_paths=4 u=16 - │ a b b C ─── u,u,u,u sizes=288,144,144,64 num_segments=18,9,9,4 num_paths=37 u=16 - ╰─ a b b b C ─ u,u,u,u,u sizes=288,144,144,144,64 num_segments=18,9,9,9,4 num_paths=437 u=16 + │ a[u]·b[u]·[]➜C[u] ─────────── num_paths=4 u=16 + │ a[u]·b[u]·b[u]·[]➜C[u] ────── num_paths=37 u=16 + ╰─ a[u]·b[u]·b[u]·b[u]·[]➜C[u] ─ num_paths=437 u=16 Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). 
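Each count is a multiple of the multiplicity ``u=16`` shown above: ``32 = 2*16``, ``80 = 5*16`` and ``176 = 11*16``.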
""" diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 15061006..6338eb10 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -184,7 +184,9 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: ope.to_letters(self.num_inputs), stp.subscripts.operands ) ] - return "·".join(items[:-1]) + "➜" + items[-1] + out = items[-1] + items = items[:-1] + [f"[{stp.coefficient_subscripts}]"] + return "·".join(items) + "➜" + out lines = ["│ " + f(ope, stp) for ope, stp in self.tensor_products] if len(lines) > 0: From 0da919749ce41938da6e30a77788ae6bbf724f6b Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:07:45 -0700 Subject: [PATCH 092/107] try --- .github/workflows/tests.yml | 6 +++--- .../cuequivariance_jax/equivariant_polynomial.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 0db7b93a..945be6a3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,8 +56,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest "jax[cuda12]" python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance - python -m uv pip install ./cuequivariance_jax + python -m uv pip install ./cuequivariance_jax + python -m uv pip install ./cuequivariance - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax @@ -86,8 +86,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance python -m uv pip install ./cuequivariance_torch + python -m uv pip install ./cuequivariance - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py index 9d854c4f..cb6abda9 100644 --- a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py @@ -56,9 +56,9 @@ def equivariant_polynomial( >>> e = cue.descriptors.spherical_harmonics(cue.SO3(1), [0, 1, 2]) >>> e ╭ a=1 -> B=0+1+2 - │ ➜B[] ──────── num_paths=1 - │ a[]➜B[] ───── num_paths=3 - ╰─ a[]·a[]➜B[] ─ num_paths=11 + │ []➜B[] ───────── num_paths=1 + │ a[]·[]➜B[] ───── num_paths=3 + ╰─ a[]·a[]·[]➜B[] ─ num_paths=11 Basic usage with single input: From dadd6d4738d990b0f88b7c81462a463438f7953a Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:12:01 -0700 Subject: [PATCH 093/107] try --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 945be6a3..ad0386be 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest - python -m uv pip install ./cuequivariance + python -m uv pip install ./cuequivariance --force-reinstall - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance @@ -56,8 +56,8 @@ jobs: python -m pip install 
--upgrade uv python -m uv pip install pytest "jax[cuda12]" python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance_jax - python -m uv pip install ./cuequivariance + python -m uv pip install ./cuequivariance --force-reinstall + python -m uv pip install ./cuequivariance_jax --force-reinstall - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax @@ -86,8 +86,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance_torch - python -m uv pip install ./cuequivariance + python -m uv pip install ./cuequivariance --force-reinstall + python -m uv pip install ./cuequivariance_torch --force-reinstall - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch From 6945c482d1e5e410f00d475b770f51435dfabc47 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:14:08 -0700 Subject: [PATCH 094/107] try --- .github/workflows/tests.yml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index ad0386be..fed0aeb5 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -58,6 +58,8 @@ jobs: python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch python -m uv pip install ./cuequivariance --force-reinstall python -m uv pip install ./cuequivariance_jax --force-reinstall + python -c "import cuequivariance; print('cue', cuequivariance.__version__)" + python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_jax @@ -88,6 +90,8 @@ jobs: python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch python -m uv pip install ./cuequivariance --force-reinstall python -m uv pip install ./cuequivariance_torch --force-reinstall + python -c "import cuequivariance; print('cue', cuequivariance.__version__)" + python -c "import cuequivariance_torch; print('cuet', cuequivariance_torch.__version__)" - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance_torch From ffceff64d8486b005d69b13191f5b019bae8e191 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:17:03 -0700 Subject: [PATCH 095/107] test --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 9325c3cc..a2268e2d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.0 \ No newline at end of file +0.3.1 \ No newline at end of file From 67f3f4eba01f23d83429d26d2d8e91e20ba526c6 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:19:49 -0700 Subject: [PATCH 096/107] fix --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fed0aeb5..e667de79 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest - python -m uv pip install ./cuequivariance --force-reinstall + python -m uv pip install ./cuequivariance - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance @@ -56,8 +56,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest "jax[cuda12]" python -m uv pip uninstall 
cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance --force-reinstall - python -m uv pip install ./cuequivariance_jax --force-reinstall + python -m uv pip install ./cuequivariance + python -m uv pip install --no-deps ./cuequivariance_jax python -c "import cuequivariance; print('cue', cuequivariance.__version__)" python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" - name: Test with pytest @@ -88,8 +88,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance --force-reinstall - python -m uv pip install ./cuequivariance_torch --force-reinstall + python -m uv pip install ./cuequivariance + python -m uv pip install --no-deps ./cuequivariance_torch python -c "import cuequivariance; print('cue', cuequivariance.__version__)" python -c "import cuequivariance_torch; print('cuet', cuequivariance_torch.__version__)" - name: Test with pytest From c1df76cd025924b7d6a6a8b6b28185d2d71abc12 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 10 Mar 2025 15:21:17 -0700 Subject: [PATCH 097/107] reset --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index a2268e2d..9325c3cc 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.3.1 \ No newline at end of file +0.3.0 \ No newline at end of file From 0cd795d8e778ccf6abeeabb539fc78a5d616b977 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 11 Mar 2025 08:49:39 -0700 Subject: [PATCH 098/107] SegmentedPolynomial: change operation ordering and repr (coeff first) --- .../group_theory/descriptors/irreps_tp.py | 2 +- .../descriptors/spherical_harmonics_.py | 4 ++-- .../descriptors/symmetric_contractions.py | 6 +++--- .../segmented_polynomial.py | 17 +++++++++-------- .../equivariant_polynomial.py | 4 ++-- .../segmented_polynomial.py | 2 +- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index c6b3ba24..5008efc0 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -48,7 +48,7 @@ def fully_connected_tensor_product( ... 16 * cue.Irreps("SO3", "0 + 1 + 2"), ... ) ╭ a=61440x0 b=16x0+16x1+16x2 c=16x0+16x1+16x2 -> D=16x0+16x1+16x2 - ╰─ a[uvw]·b[iu]·c[jv]·[ijk]➜D[kw] ─ num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 + ╰─ [ijk]·a[uvw]·b[iu]·c[jv]➜D[kw] ─ num_paths=15 i={1, 3, 5} j={1, 3, 5} k={1, 3, 5} u=16 v=16 w=16 Where ``61440x0`` are the 61440 weights needed to mix all the inputs with all the outputs. 
""" diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index bd53bc73..f54fd76c 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -40,8 +40,8 @@ def spherical_harmonics( >>> spherical_harmonics(cue.SO3(1), [0, 1, 2]) ╭ a=1 -> B=0+1+2 │ []➜B[] ───────── num_paths=1 - │ a[]·[]➜B[] ───── num_paths=3 - ╰─ a[]·a[]·[]➜B[] ─ num_paths=11 + │ []·a[]➜B[] ───── num_paths=3 + ╰─ []·a[]·a[]➜B[] ─ num_paths=11 """ if len(ls) != 1: return cue.EquivariantPolynomial.stack( diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index 673421d5..ea121ad0 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -43,9 +43,9 @@ def symmetric_contraction( ... [1, 2, 3] ... ) ╭ a=32x0+80x0+176x0 b=16x0+16x1+16x2 -> C=16x0+16x1 - │ a[u]·b[u]·[]➜C[u] ─────────── num_paths=4 u=16 - │ a[u]·b[u]·b[u]·[]➜C[u] ────── num_paths=37 u=16 - ╰─ a[u]·b[u]·b[u]·b[u]·[]➜C[u] ─ num_paths=437 u=16 + │ []·a[u]·b[u]➜C[u] ─────────── num_paths=4 u=16 + │ []·a[u]·b[u]·b[u]➜C[u] ────── num_paths=37 u=16 + ╰─ []·a[u]·b[u]·b[u]·b[u]➜C[u] ─ num_paths=437 u=16 Where ``32x0+80x0+176x0`` are the weights needed for each degree (32 for degree 1, 80 for degree 2, 176 for degree 3). """ diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 6338eb10..75a645c4 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -59,19 +59,20 @@ def __init__( operands = inputs + outputs _tensor_products = [] - for ope, stp in tensor_products: - ope = cue.Operation(ope) - assert isinstance(ope, cue.Operation) + for opt, stp in tensor_products: + opt = cue.Operation(opt) + assert isinstance(opt, cue.Operation) assert isinstance(stp, cue.SegmentedTensorProduct) - assert len(ope.buffers) == stp.num_operands - for buffer_id, operand in zip(ope.buffers, stp.operands): + assert len(opt.buffers) == stp.num_operands + for buffer_id, operand in zip(opt.buffers, stp.operands): assert operand == operands[buffer_id] - out_oid, _ = ope.output_operand_buffer(len(inputs)) + out_oid, bid = opt.output_operand_buffer(len(inputs)) _tensor_products.append( - (ope.move_operand_last(out_oid), stp.move_operand_last(out_oid)) + (bid, opt.move_operand_last(out_oid), stp.move_operand_last(out_oid)) ) _tensor_products = sorted(_tensor_products) + _tensor_products = [(opt, stp) for _, opt, stp in _tensor_products] object.__setattr__(self, "inputs", inputs) object.__setattr__(self, "outputs", outputs) @@ -185,7 +186,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: ) ] out = items[-1] - items = items[:-1] + [f"[{stp.coefficient_subscripts}]"] + items = [f"[{stp.coefficient_subscripts}]"] + items[:-1] return "·".join(items) + "➜" + out lines = ["│ " + f(ope, stp) for ope, stp in self.tensor_products] diff --git a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py 
index cb6abda9..26f72130 100644 --- a/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/equivariant_polynomial.py @@ -57,8 +57,8 @@ def equivariant_polynomial( >>> e ╭ a=1 -> B=0+1+2 │ []➜B[] ───────── num_paths=1 - │ a[]·[]➜B[] ───── num_paths=3 - ╰─ a[]·a[]·[]➜B[] ─ num_paths=11 + │ []·a[]➜B[] ───── num_paths=3 + ╰─ []·a[]·a[]➜B[] ─ num_paths=11 Basic usage with single input: diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index 2cb49de7..0a0b4bb4 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -69,7 +69,7 @@ def segmented_polynomial( Features: - CUDA kernel activation conditions: - STPs have a single mode which is a multiple of 32 (e.g. channelwise - tensor product with subscripts ``u,u,,u`` where u=128) + tensor product with subscripts ``u,u,,u`` where u=128) - Math data type is float32 or float64 - Input/output data types can be float32, float64, float16, or bfloat16 - Indices must be int32 From 4183fdd35f35da471d39a6c53d7b1c6bf00dc929 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 11 Mar 2025 08:54:04 -0700 Subject: [PATCH 099/107] test --- .github/workflows/tests.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e667de79..196060f4 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install --upgrade pip python -m pip install --upgrade uv python -m uv pip install pytest - python -m uv pip install ./cuequivariance + python -m uv pip install -e ./cuequivariance - name: Test with pytest run: | pytest --doctest-modules -x cuequivariance @@ -56,8 +56,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest "jax[cuda12]" python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance - python -m uv pip install --no-deps ./cuequivariance_jax + python -m uv pip install -e ./cuequivariance + python -m uv pip install --no-deps -e ./cuequivariance_jax python -c "import cuequivariance; print('cue', cuequivariance.__version__)" python -c "import cuequivariance_jax; print('cuex', cuequivariance_jax.__version__)" - name: Test with pytest @@ -88,8 +88,8 @@ jobs: python -m pip install --upgrade uv python -m uv pip install pytest torch e3nn python -m uv pip uninstall cuequivariance cuequivariance_jax cuequivariance_torch - python -m uv pip install ./cuequivariance - python -m uv pip install --no-deps ./cuequivariance_torch + python -m uv pip install -e ./cuequivariance + python -m uv pip install --no-deps -e ./cuequivariance_torch python -c "import cuequivariance; print('cue', cuequivariance.__version__)" python -c "import cuequivariance_torch; print('cuet', cuequivariance_torch.__version__)" - name: Test with pytest From 77ba13642d466392a527a14c2248490a873caf80 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 11 Mar 2025 08:58:31 -0700 Subject: [PATCH 100/107] remove python version --- .python-version | 1 - 1 file changed, 1 deletion(-) delete mode 100644 .python-version diff --git a/.python-version b/.python-version deleted file mode 100644 index 24ee5b1b..00000000 --- a/.python-version +++ /dev/null @@ -1 +0,0 @@ -3.13 
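
The CI changes in patches 092-099 above all chase the same problem: the local `cuequivariance` checkout must take precedence over the released version that `cuequivariance_jax` and `cuequivariance_torch` pull in as a dependency. The form that sticks installs the core package first and the wrappers editably with `--no-deps`, then prints the imported versions as a sanity check. A minimal sketch of that check (the `__version__` prints mirror the workflow; the `__file__` prints are an added assumption, used here only to confirm each import resolves to the local checkout rather than a wheel from PyPI):

    # Post-install sanity check, mirroring the version prints in tests.yml.
    # The __file__ prints are an assumption, not part of the workflow: they
    # show which installation each import actually resolved to.
    import cuequivariance
    import cuequivariance_jax

    print("cue", cuequivariance.__version__, cuequivariance.__file__)
    print("cuex", cuequivariance_jax.__version__, cuequivariance_jax.__file__)
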
From 03b211a966b2d39a02a1987669b715107bca383d Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Tue, 11 Mar 2025 09:16:24 -0700 Subject: [PATCH 101/107] remove import ... as stp --- .../group_theory/descriptors/irreps_tp.py | 11 ++- .../group_theory/descriptors/rotations.py | 19 +++-- .../descriptors/spherical_harmonics_.py | 33 ++++---- .../descriptors/symmetric_contractions.py | 5 +- .../equivariant_tensor_product.py | 17 +++-- .../segmented_polynomials/dispatch.py | 11 +-- .../segmented_tensor_product.py | 76 +++++++++---------- .../tests/segmented_polynomials/dot_test.py | 22 +++--- .../segmented_polynomials/subscripts_test.py | 16 ++-- .../primitives/tensor_product.py | 28 +++---- .../symmetric_tensor_product_test.py | 9 +-- 11 files changed, 126 insertions(+), 121 deletions(-) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py index 5008efc0..de2a7f2b 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/irreps_tp.py @@ -16,7 +16,6 @@ from typing import Optional, Sequence import cuequivariance as cue -from cuequivariance import segmented_polynomials as stp from cuequivariance.group_theory.irreps_array.irrep_utils import into_list_of_irrep @@ -54,7 +53,7 @@ def fully_connected_tensor_product( """ G = irreps1.irrep_class - d = stp.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("uvw,iu,jv,kw+ijk") for mul, ir in irreps1: d.add_segment(1, (ir.dim, mul)) @@ -110,7 +109,7 @@ def full_tensor_product( if irreps3_filter is not None: irreps3_filter = into_list_of_irrep(G, irreps3_filter) - d = stp.SegmentedTensorProduct.from_subscripts("iu,jv,kuv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iu,jv,kuv+ijk") for mul, ir in irreps1: d.add_segment(0, (ir.dim, mul)) @@ -174,7 +173,7 @@ def channelwise_tensor_product( if irreps3_filter is not None: irreps3_filter = into_list_of_irrep(G, irreps3_filter) - d = stp.SegmentedTensorProduct.from_subscripts("uv,iu,jv,kuv+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("uv,iu,jv,kuv+ijk") for mul, ir in irreps1: d.add_segment(1, (ir.dim, mul)) @@ -269,7 +268,7 @@ def elementwise_tensor_product( irreps1_cut, irreps2_cut = _align_two_irreps(irreps1, irreps2, cue.ir_mul) - d = stp.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("iu,ju,ku+ijk") irreps3 = [] for (mul, ir1), (_, ir2) in zip(irreps1_cut, irreps2_cut): @@ -310,7 +309,7 @@ def linear(irreps_in: cue.Irreps, irreps_out: cue.Irreps) -> cue.EquivariantPoly Returns: :class:`cue.EquivariantPolynomial `: Descriptor of the linear transformation. 
""" - d = stp.SegmentedTensorProduct.from_subscripts("uv_iu_iv") + d = cue.SegmentedTensorProduct.from_subscripts("uv_iu_iv") for mul, ir in irreps_in: d.add_segment(1, (ir.dim, mul)) for mul, ir in irreps_out: diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py index ebb4e875..fd4740fc 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py @@ -17,7 +17,6 @@ import numpy as np import cuequivariance as cue -from cuequivariance import segmented_polynomials as stp def fixed_axis_angle_rotation( @@ -30,7 +29,7 @@ def fixed_axis_angle_rotation( """ assert irreps.irrep_class in [cue.SO3, cue.O3] - d = stp.SegmentedTensorProduct.from_subscripts("iu_ju+ij") + d = cue.SegmentedTensorProduct.from_subscripts("iu_ju+ij") for mul, ir in irreps: # Note the transpose @@ -72,7 +71,9 @@ def yxy_rotation( # gamma, beta, input, A cbio = xy_rotation(irreps, lmax).polynomial.tensor_products[0][1] aio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # alpha, A, output - cbiao = stp.dot(cbio, aio, (3, 1)) # gamma, beta, input, alpha, output + cbiao = cue.segmented_polynomials.dot( + cbio, aio, (3, 1) + ) # gamma, beta, input, alpha, output cbaio = cbiao.move_operand(2, 3) # gamma, beta, alpha, input, output return cue.EquivariantPolynomial( [ @@ -96,7 +97,7 @@ def xy_rotation( """ cio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # gamma, input, A bio = x_rotation(irreps, lmax).polynomial.tensor_products[0][1] # beta, A, output - cibo = stp.dot(cio, bio, (2, 1)) # gamma, input, beta, output + cibo = cue.segmented_polynomials.dot(cio, bio, (2, 1)) # gamma, input, beta, output cbio = cibo.move_operand(1, 2) # gamma, beta, input, output return cue.EquivariantPolynomial( [ @@ -119,7 +120,7 @@ def yx_rotation( """ cio = x_rotation(irreps, lmax).d bio = y_rotation(irreps, lmax).d - cibo = stp.dot(cio, bio, (2, 1)) + cibo = cue.segmented_polynomials.dot(cio, bio, (2, 1)) cbio = cibo.move_operand(1, 2) return cue.EquivariantPolynomial( [ @@ -151,7 +152,7 @@ def y_rotation( if lmax is None: lmax = max(ir.l for _, ir in irreps) - d = stp.SegmentedTensorProduct.from_subscripts("i_ju_ku+ijk") + d = cue.SegmentedTensorProduct.from_subscripts("i_ju_ku+ijk") phc = d.add_segment( 0, (lmax,) ) # cos(th * lmax), cos(th * (lmax - 1)), ..., cos(th) @@ -221,7 +222,9 @@ def x_rotation( dz90 = fixed_axis_angle_rotation( irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0 ).polynomial.tensor_products[0][1] - d = stp.dot(stp.dot(dy, dz90, (1, 1)), dz90, (1, 1)) + d = cue.segmented_polynomials.dot( + cue.segmented_polynomials.dot(dy, dz90, (1, 1)), dz90, (1, 1) + ) return cue.EquivariantPolynomial( [ @@ -237,7 +240,7 @@ def inversion(irreps: cue.Irreps) -> cue.EquivariantPolynomial: """ subsrcipts: ``input[u],output[u]`` """ - d = stp.SegmentedTensorProduct.from_subscripts("iu_ju+ji") + d = cue.SegmentedTensorProduct.from_subscripts("iu_ju+ji") for mul, ir in irreps: assert len(ir.H) == 1 H = ir.H[0] diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py index f54fd76c..890bb159 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py @@ -14,10 +14,9 @@ # limitations under the License. 
from functools import cache -import sympy as sp +import sympy import cuequivariance as cue -from cuequivariance import segmented_polynomials as stp from cuequivariance.etc.sympy_utils import sqrtQarray_to_sympy @@ -52,9 +51,11 @@ def spherical_harmonics( ir, formula = sympy_spherical_harmonics(ir_vec, ell) assert ir_vec.dim == 3 - d = stp.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) + d = cue.SegmentedTensorProduct.empty_segments([3] * ell + [ir.dim]) for i in range(ir.dim): - for degrees, coeff in sp.Poly(formula[i], sp.symbols("x:3")).as_dict().items(): + for degrees, coeff in ( + sympy.Poly(formula[i], sympy.symbols("x:3")).as_dict().items() + ): indices = poly_degrees_to_path_indices(degrees) d.add_path(*indices, i, c=coeff) @@ -66,8 +67,8 @@ def spherical_harmonics( cue.IrrepsAndLayout(cue.Irreps(ir), cue.ir_mul), ], cue.SegmentedPolynomial( - [cue.SegmentedOperand(ndim=0, segments=[()] * 3)], - [cue.SegmentedOperand(ndim=0, segments=[()] * ir.dim)], + [cue.SegmentedOperand([()] * 3)], + [cue.SegmentedOperand([()] * ir.dim)], [(cue.Operation([0] * ell + [1]), d)], ), ) @@ -83,14 +84,14 @@ def poly_degrees_to_path_indices(degrees: tuple[int, ...]) -> tuple[int, ...]: @cache def sympy_spherical_harmonics( ir_vec: cue.Irrep, ell: int -) -> tuple[cue.Irrep, sp.Array]: +) -> tuple[cue.Irrep, sympy.Array]: if ell == 0: - return ir_vec.trivial(), sp.Array([1]) + return ir_vec.trivial(), sympy.Array([1]) if ell == 1: assert ir_vec.dim == 3 - x = sp.symbols("x:3") - return ir_vec, sp.sqrt(3) * sp.Array([x[0], x[1], x[2]]) + x = sympy.symbols("x:3") + return ir_vec, sympy.sqrt(3) * sympy.Array([x[0], x[1], x[2]]) l2 = ell // 2 l1 = ell - l2 @@ -98,11 +99,11 @@ def sympy_spherical_harmonics( ir2, yl2 = sympy_spherical_harmonics(ir_vec, l2) ir = sorted(cue.selection_rule_product(ir1, ir2))[-1] - def sh_var(ir: cue.Irrep, ell: int) -> list[sp.Symbol]: - return [sp.symbols(f"sh{ell}_{m}") for m in range(ir.dim)] + def sh_var(ir: cue.Irrep, ell: int) -> list[sympy.Symbol]: + return [sympy.symbols(f"sh{ell}_{m}") for m in range(ir.dim)] cg = sqrtQarray_to_sympy(ir_vec.clebsch_gordan(ir1, ir2, ir).squeeze(0)) - yl = sp.Array( + yl = sympy.Array( [ sum( sh_var(ir1, l1)[i] * sh_var(ir2, l2)[j] * cg[i, j, k] @@ -116,7 +117,7 @@ def sh_var(ir: cue.Irrep, ell: int) -> list[sp.Symbol]: y = yl.subs(zip(sh_var(ir1, l1), yl1)).subs(zip(sh_var(ir2, l2), yl2)) cst = y.subs({"x0": 0, "x1": 1, "x2": 0}) - norm = sp.sqrt(sum(cst.applyfunc(lambda x: x**2))) + norm = sympy.sqrt(sum(cst.applyfunc(lambda x: x**2))) - y = sp.sqrt(sp.Integer(ir.dim)) * y / norm - return ir, sp.simplify(y) + y = sympy.sqrt(sympy.Integer(ir.dim)) * y / norm + return ir, sympy.simplify(y) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py index ea121ad0..9a94bec4 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/symmetric_contractions.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import cuequivariance as cue -from cuequivariance import segmented_polynomials as stp def symmetric_contraction( @@ -73,7 +72,7 @@ def symmetric_contraction( input_operand = cue.SegmentedOperand(ndim=1, segments=[(mul,)] * irreps_in.dim) if degree == 0: - d = stp.SegmentedTensorProduct.from_subscripts("i_i") + d = cue.SegmentedTensorProduct.from_subscripts("i_i") for _, ir in irreps_out: if not ir.is_scalar(): d.add_segment(output_operand, {"i": ir.dim}) @@ -83,7 +82,7 @@ def symmetric_contraction( else: abc = "abcdefgh"[:degree] - d = stp.SegmentedTensorProduct.from_subscripts( + d = cue.SegmentedTensorProduct.from_subscripts( f"w_{'_'.join(f'{a}' for a in abc)}_i+{abc}iw" ) diff --git a/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py index a1e8e426..b5f5613f 100644 --- a/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py +++ b/cuequivariance/cuequivariance/group_theory/equivariant_tensor_product.py @@ -20,7 +20,6 @@ from typing import Optional, Sequence, Union import cuequivariance as cue -from cuequivariance import segmented_polynomials as stp @dataclasses.dataclass(init=False, frozen=True) @@ -55,11 +54,11 @@ class EquivariantTensorProduct: """ operands: tuple[cue.Rep, ...] - ds: list[stp.SegmentedTensorProduct] + ds: list[cue.SegmentedTensorProduct] def __init__( self, - d: Union[stp.SegmentedTensorProduct, Sequence[stp.SegmentedTensorProduct]], + d: Union[cue.SegmentedTensorProduct, Sequence[cue.SegmentedTensorProduct]], operands: list[cue.Rep], symmetrize: bool = True, ): @@ -70,7 +69,7 @@ def __init__( stacklevel=2, ) operands = tuple(operands) - if isinstance(d, stp.SegmentedTensorProduct): + if isinstance(d, cue.SegmentedTensorProduct): assert len(operands) == d.num_operands for oid in range(d.num_operands): assert operands[oid].dim == d.operands[oid].size @@ -129,7 +128,7 @@ def __rmul__(self, factor: float) -> EquivariantTensorProduct: return self.__mul__(factor) @property - def d(self) -> stp.SegmentedTensorProduct: + def d(self) -> cue.SegmentedTensorProduct: assert len(self.ds) == 1 return self.ds[0] @@ -269,7 +268,7 @@ def change_layout( del layout layouts = [cue.IrrepsLayout.as_layout(layout) for layout in layouts] - def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: + def f(d: cue.SegmentedTensorProduct) -> cue.SegmentedTensorProduct: ii = self.map_operands(d.num_operands) assert len(ii) == d.num_operands @@ -303,7 +302,9 @@ def f(d: stp.SegmentedTensorProduct) -> stp.SegmentedTensorProduct: continue raise NotImplementedError return d.add_or_transpose_modes( - stp.Subscripts.from_operands(new_subscripts, d.coefficient_subscripts) + cue.segmented_polynomials.Subscripts.from_operands( + new_subscripts, d.coefficient_subscripts + ) ) return EquivariantTensorProduct( @@ -405,7 +406,7 @@ def stack( assert all(e.operands[oid] == ope for e in es) new_operands.append(ope) - new_ds: dict[int, stp.SegmentedTensorProduct] = {} + new_ds: dict[int, cue.SegmentedTensorProduct] = {} for eid, e in enumerate(es): for d in e.ds: d = copy.deepcopy(d) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py index 975b00e8..40cf2ec3 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/dispatch.py @@ -16,14 +16,15 @@ import math from typing import Generator, Tuple -import 
cuequivariance.segmented_polynomials as stp # we cannot import cuequivariance as cue because of circular import +# we cannot import cuequivariance as cue because of circular import +from cuequivariance.segmented_polynomials import SegmentedTensorProduct, Subscripts def dispatch( - descriptor: stp.SegmentedTensorProduct, - targets: list[stp.Subscripts], + descriptor: SegmentedTensorProduct, + targets: list[Subscripts], permutation_mode: str, -) -> Generator[Tuple[stp.SegmentedTensorProduct, Tuple[int, ...]], None, None]: +) -> Generator[Tuple[SegmentedTensorProduct, Tuple[int, ...]], None, None]: """Dispatch a descriptor to a target subscripts. Args: @@ -41,7 +42,7 @@ def dispatch( and dispatch the flattened descriptor to the target subscripts. The function will yield all the possible dispatches found. """ - targets = [stp.Subscripts(subscripts) for subscripts in targets] + targets = [Subscripts(subscripts) for subscripts in targets] targets = [ subscripts for subscripts in targets diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index e829d4f1..05c18a76 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -31,12 +31,12 @@ import opt_einsum import cuequivariance as cue # noqa: F401 -import cuequivariance.segmented_polynomials as stp from cuequivariance.etc.linalg import round_to_rational, round_to_sqrt_rational from cuequivariance.etc.permutations import ( generate_permutations_from, inverse_permutation, ) +from cuequivariance.segmented_polynomials import Path, Subscripts from .dimensions_dict import format_dimensions_dict @@ -62,7 +62,7 @@ class SegmentedTensorProduct: operands_and_subscripts: tuple[tuple[cue.SegmentedOperand, str], ...] coefficient_subscripts: str - paths: tuple[stp.Path, ...] + paths: tuple[Path, ...] ################################ Initializers ################################ @@ -73,7 +73,7 @@ def __init__( | None = None, coefficient_subscripts: str = "", *, - paths: Sequence[stp.Path] | None = None, + paths: Sequence[Path] | None = None, ): if operands_and_subscripts is None: operands_and_subscripts = [] @@ -102,10 +102,10 @@ def set_operand(self, oid: int, operand: cue.SegmentedOperand): + self.operands_and_subscripts[oid + 1 :], ) - def set_paths(self, paths: list[stp.Path]): + def set_paths(self, paths: list[Path]): object.__setattr__(self, "paths", tuple(copy.deepcopy(path) for path in paths)) - def insert_path_(self, path_index: int, path: stp.Path): + def insert_path_(self, path_index: int, path: Path): object.__setattr__( self, "paths", @@ -119,7 +119,7 @@ def operands(self) -> tuple[cue.SegmentedOperand, ...]: return tuple(ope for ope, _ in self.operands_and_subscripts) def assert_valid(self): - assert stp.Subscripts.is_valid(self.subscripts) + assert Subscripts.is_valid(self.subscripts) for m in self.subscripts.modes(): if self.subscripts.count(m) == 1: @@ -152,7 +152,7 @@ def assert_valid(self): ) @classmethod - def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct: + def from_subscripts(cls, subscripts: Subscripts) -> SegmentedTensorProduct: r""" Create a descriptor from a subscripts string. 
@@ -166,7 +166,7 @@ def from_subscripts(cls, subscripts: stp.Subscripts) -> SegmentedTensorProduct: >>> print(d) uv,ui,vj+ij operands=[(2, 3)],[(2, 5)],[(3, 4)] paths=[op0[0]*op1[0]*op2[0]*c c.shape=(5, 4) c.nnz=20] """ - subscripts = stp.Subscripts(subscripts) + subscripts = Subscripts(subscripts) operands = [ cue.SegmentedOperand(ndim=len(operand)) for operand in subscripts.operands ] @@ -280,9 +280,9 @@ def num_paths(self) -> int: return len(self.paths) @property - def subscripts(self) -> stp.Subscripts: + def subscripts(self) -> Subscripts: """Subscripts of the tensor product.""" - return stp.Subscripts.from_operands( + return Subscripts.from_operands( [subscripts for _, subscripts in self.operands_and_subscripts], self.coefficient_subscripts, ) @@ -486,7 +486,7 @@ def get_dims(self, m: str) -> set[int]: return self.get_dimensions_dict().get(m, set()) def get_path_dimensions_dict( - self, path: Union[int, stp.Path], *, returns_sets: bool = False + self, path: Union[int, Path], *, returns_sets: bool = False ) -> dict[str, Union[int, set[int]]]: """Get the dimensions of a specific path.""" if isinstance(path, int): @@ -505,7 +505,7 @@ def get_path_dimensions_dict( return {m: next(iter(dd)) for m, dd in dims.items()} def get_path_dim( - self, path: Union[int, stp.Path], m: str, *, returns_set=False + self, path: Union[int, Path], m: str, *, returns_set=False ) -> Union[int, set[int]]: """Get the dimension of a specific mode in a specific path.""" if isinstance(path, int): @@ -514,14 +514,14 @@ def get_path_dim( m, set() if returns_set else 0 ) - def segment_slice(self, operand: int, path: Union[int, stp.Path]) -> slice: + def segment_slice(self, operand: int, path: Union[int, Path]) -> slice: """Get the slice of the segment in the given operand selected by the given path.""" if isinstance(path, int): path = self.paths[path] return self.operands[operand].segment_slices()[path.indices[operand]] def get_segment_shape( - self, operand: int, path: Union[int, stp.Path] + self, operand: int, path: Union[int, Path] ) -> tuple[int, ...]: """Get the shape of the segment in the given operand selected by the given path.""" if isinstance(path, int): @@ -721,7 +721,7 @@ def insert_path( dims = {m: next(iter(dd)) for m, dd in dims.items()} - path = stp.Path( + path = Path( [ ( (s + self.operands[oid].num_segments) @@ -795,7 +795,7 @@ def insert_segments(self, operand: int, sid: int, segments: list[tuple[int, ...] ) self.set_paths( [ - stp.Path( + Path( [ s if s < sid or oid != operand else s + n.num_segments for oid, s in enumerate(path.indices) @@ -832,7 +832,7 @@ def canonicalize_subscripts(self) -> SegmentedTensorProduct: This is useful to identify equivalent descriptors. """ - subscripts = stp.Subscripts.canonicalize(self.subscripts) + subscripts = Subscripts.canonicalize(self.subscripts) return self.add_or_rename_modes(subscripts) def add_or_rename_modes( @@ -849,7 +849,7 @@ def add_or_rename_modes( Returns: SegmentedTensorProduct: The new descriptor with the renamed modes. 
""" - subscripts = stp.Subscripts(subscripts) + subscripts = Subscripts(subscripts) if subscripts.is_equivalent(self.subscripts): d = SegmentedTensorProduct.from_subscripts(subscripts) @@ -858,7 +858,7 @@ def add_or_rename_modes( for path in self.paths: d.insert_path_( len(d.paths), - stp.Path(indices=path.indices, coefficients=path.coefficients), + Path(indices=path.indices, coefficients=path.coefficients), ) return d @@ -902,7 +902,7 @@ def add_or_rename_modes( D.insert_path_( len(D.paths), - stp.Path( + Path( indices=path.indices, coefficients=np.reshape( path.coefficients, @@ -926,7 +926,7 @@ def add_or_transpose_modes( Returns: SegmentedTensorProduct: The new descriptor with the transposed modes. """ - subscripts = stp.Subscripts.complete_wildcards(subscripts, self.subscripts) + subscripts = Subscripts.complete_wildcards(subscripts, self.subscripts) if dims is None: dims = dict() @@ -966,7 +966,7 @@ def add_or_transpose_modes( perm = [old.index(ch) for ch in new] d.set_paths( [ - stp.Path( + Path( indices=path.indices, coefficients=np.transpose(path.coefficients, perm), ) @@ -994,7 +994,7 @@ def append_modes_to_all_operands( ) if not all(ch in dims for ch in modes): raise ValueError(f"expected dimensions for all new modes {modes}.") - subscripts = stp.Subscripts.from_operands( + subscripts = Subscripts.from_operands( [ope + modes for ope in self.subscripts.operands], coefficients=self.subscripts.coefficients, ) @@ -1038,7 +1038,7 @@ def permute_segments( segments=[self.operands[operand][i] for i in perm], ) new_paths = [ - stp.Path( + Path( indices=tuple( sid if oid != operand else perm.index(sid) for oid, sid in enumerate(path.indices) @@ -1120,7 +1120,7 @@ def filter_shape(shape: tuple[int, ...], subscripts: str) -> tuple[int, ...]: for path in self.paths: d.insert_path_( len(d.paths), - stp.Path( + Path( indices=path.indices, coefficients=np.reshape( path.coefficients, @@ -1165,7 +1165,7 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: ): return ( self.add_or_transpose_modes( - stp.Subscripts.from_operands( + Subscripts.from_operands( self.subscripts.operands, mode + self.coefficient_subscripts.replace(mode, ""), ) @@ -1209,7 +1209,7 @@ def split_mode(self, mode: str, size: int) -> SegmentedTensorProduct: if self.coefficient_subscripts.startswith(mode): coefficients = np.split(coefficients, num_subdivisions, axis=0)[i] d.insert_path_( - len(d.paths), stp.Path(indices=indices, coefficients=coefficients) + len(d.paths), Path(indices=indices, coefficients=coefficients) ) logger.debug(f"Split {mode} in {self}: got {d}") @@ -1259,7 +1259,7 @@ def normalize_paths_for_operand(self, operand: int) -> SegmentedTensorProduct: c = path.coefficients / np.sqrt(total_variance) else: c = 0.0 * path.coefficients - new_paths.append(stp.Path(indices=path.indices, coefficients=c)) + new_paths.append(Path(indices=path.indices, coefficients=c)) d = dataclasses.replace(self, paths=new_paths) logger.debug(f"Normalized paths for operand {operand} in {self}: got {d}") @@ -1284,7 +1284,7 @@ def fuse_paths_with_same_indices(self) -> SegmentedTensorProduct: return dataclasses.replace( self, paths=[ - stp.Path(indices=indices, coefficients=coefficients) + Path(indices=indices, coefficients=coefficients) for indices, coefficients in paths.items() ], ) @@ -1300,7 +1300,7 @@ def consolidate_paths(self) -> SegmentedTensorProduct: else: paths[path.indices] = path.coefficients paths = [ - stp.Path(indices=indices, coefficients=coefficients) + Path(indices=indices, coefficients=coefficients) 
for indices, coefficients in paths.items() if not np.all(coefficients == 0) ] @@ -1329,7 +1329,7 @@ def f(ii: tuple[int, ...]) -> tuple[int, ...]: return dataclasses.replace( self, paths=[ - stp.Path( + Path( indices=f(path.indices), coefficients=path.coefficients, ) @@ -1367,7 +1367,7 @@ def make_global_perm(perm: tuple[int, ...]) -> tuple[int, ...]: for i, oid in enumerate(operands): new_indices[oid] = indices[operands[perm[i]]] paths.append( - stp.Path( + Path( indices=new_indices, coefficients=path.coefficients / len(permutations), ) @@ -1496,8 +1496,8 @@ def make_new_path( old_indices: tuple[int, ...], # old segment indices (one per operand) sub_indices: dict[str, int], coefficients: np.ndarray, - ) -> stp.Path: - return stp.Path( + ) -> Path: + return Path( [ offsets[sid] + ravel_multi_index( @@ -1680,7 +1680,7 @@ def _consolidate_pair_of_modes(self, m: str, n: str) -> SegmentedTensorProduct: c, c.shape[:i] + (c.shape[i] * c.shape[i + 1],) + c.shape[i + 2 :] ) d1.insert_path_( - len(d1.paths), stp.Path(indices=path.indices, coefficients=c) + len(d1.paths), Path(indices=path.indices, coefficients=c) ) else: d1.set_paths(d0.paths) @@ -1699,7 +1699,7 @@ def round_coefficients_to_rational( d = copy.deepcopy(self) d.set_paths( [ - stp.Path( + Path( indices=path.indices, coefficients=round_to_rational(path.coefficients, max_denominator), ) @@ -1720,7 +1720,7 @@ def round_coefficients_to_sqrt_rational( d = copy.deepcopy(self) d.set_paths( [ - stp.Path( + Path( indices=path.indices, coefficients=round_to_sqrt_rational( path.coefficients, max_denominator @@ -1743,7 +1743,7 @@ def modify_coefficients( d = copy.deepcopy(self) d.set_paths( [ - stp.Path(indices=path.indices, coefficients=f(path.coefficients)) + Path(indices=path.indices, coefficients=f(path.coefficients)) for path in d.paths ] ) diff --git a/cuequivariance/tests/segmented_polynomials/dot_test.py b/cuequivariance/tests/segmented_polynomials/dot_test.py index 1e1bf2bd..193645d1 100644 --- a/cuequivariance/tests/segmented_polynomials/dot_test.py +++ b/cuequivariance/tests/segmented_polynomials/dot_test.py @@ -15,7 +15,7 @@ import numpy as np import cuequivariance as cue -import cuequivariance.segmented_polynomials as stp +import cuequivariance.segmented_polynomials as sp from cuequivariance.group_theory import descriptors @@ -29,16 +29,16 @@ def test_dot1(): d2.add_path(None, None, None, c=np.random.randn(2), dims={"b": 3}) d2.add_path(None, 0, None, c=np.random.randn(3)) - d3 = stp.dot(d1, d2, (1, 0)) + d3 = sp.dot(d1, d2, (1, 0)) assert d3.subscripts == "iab,ak,b,+ijk" x0, x2 = np.random.randn(d1.operands[0].size), np.random.randn(d1.operands[2].size) y1 = np.random.randn(d2.operands[1].size) - tmp = stp.compute_last_operand(d1.move_operand_last(1), x0, x2) - z0 = stp.compute_last_operand(d2, tmp, y1) + tmp = sp.compute_last_operand(d1.move_operand_last(1), x0, x2) + z0 = sp.compute_last_operand(d2, tmp, y1) - z1 = stp.compute_last_operand(d3, x0, x2, y1) + z1 = sp.compute_last_operand(d3, x0, x2, y1) np.testing.assert_allclose(z0, z1) @@ -63,16 +63,16 @@ def make_examples(): def test_dot2(): dx, dy = make_examples() - dxy = stp.dot(dx, dy, (3, 1)) + dxy = sp.dot(dx, dy, (3, 1)) assert dxy.subscripts == "uvw,iu,jv,w,l,mw+ijklm" x0, x1, x2 = [np.random.randn(dx.operands[i].size) for i in range(3)] y0, y2 = [np.random.randn(dy.operands[i].size) for i in [0, 2]] - tmp = stp.compute_last_operand(dx, x0, x1, x2) - A = stp.compute_last_operand(dy, y0, tmp, y2) + tmp = sp.compute_last_operand(dx, x0, x1, x2) + A = 
sp.compute_last_operand(dy, y0, tmp, y2) - B = stp.compute_last_operand(dxy, x0, x1, x2, y0, y2) + B = sp.compute_last_operand(dxy, x0, x1, x2, y0, y2) np.testing.assert_allclose(A, B) @@ -80,13 +80,13 @@ def test_dot2(): def test_trace(): dx, dy = make_examples() - d1 = stp.dot(dx, dy, (3, 1)) + d1 = sp.dot(dx, dy, (3, 1)) d1 = d1.canonicalize_subscripts() d1 = d1.sort_paths() assert dy.subscripts == "w,kw,l,mw+klm" dy = dy.add_or_rename_modes("a_xa_y_za+xyz") - d2 = stp.trace(stp.dot(dx, dy), (3, 4 + 1)) + d2 = sp.trace(sp.dot(dx, dy), (3, 4 + 1)) d2 = d2.canonicalize_subscripts() d2 = d2.sort_paths() diff --git a/cuequivariance/tests/segmented_polynomials/subscripts_test.py b/cuequivariance/tests/segmented_polynomials/subscripts_test.py index 54fdf480..572455f6 100644 --- a/cuequivariance/tests/segmented_polynomials/subscripts_test.py +++ b/cuequivariance/tests/segmented_polynomials/subscripts_test.py @@ -14,23 +14,23 @@ # limitations under the License. import pytest -import cuequivariance.segmented_polynomials as stp +import cuequivariance.segmented_polynomials as sp def test_subscripts(): with pytest.raises(ValueError): - stp.Subscripts("#$%@") + sp.Subscripts("#$%@") with pytest.raises(ValueError): - stp.Subscripts("Zu") # uppercase not supported anymore + sp.Subscripts("Zu") # uppercase not supported anymore with pytest.raises(ValueError): - stp.Subscripts("uZ") # uppercase after lowercase + sp.Subscripts("uZ") # uppercase after lowercase with pytest.raises(ValueError): - stp.Subscripts("uZ+ij+kl") # multiple + signs + sp.Subscripts("uZ+ij+kl") # multiple + signs - subscripts = stp.Subscripts("ui,vj,uvk+ijk") + subscripts = sp.Subscripts("ui,vj,uvk+ijk") assert subscripts.canonicalize() == "ui,vj,uvk+ijk" assert subscripts.coefficients == "ijk" @@ -45,5 +45,5 @@ def test_subscripts(): def test_canonicalize(): - assert stp.Subscripts("ui").canonicalize() == "uv" - assert stp.Subscripts("ab,ad+bd").canonicalize() == "ui,uj+ij" + assert sp.Subscripts("ui").canonicalize() == "uv" + assert sp.Subscripts("ab,ad+bd").canonicalize() == "ui,uj+ij" diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py index 257e4eaf..38c33b13 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/tensor_product.py @@ -21,7 +21,7 @@ import torch import torch.fx -from cuequivariance import segmented_polynomials as stp +import cuequivariance as cue logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class TensorProduct(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, *, device: Optional[torch.device] = None, math_dtype: Optional[torch.dtype] = None, @@ -198,7 +198,7 @@ def disable_type_conv(t): def _tensor_product_fx( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, optimize_einsums: bool, @@ -313,7 +313,7 @@ def _sum(tensors, *, shape=None, like=None): elif num_inputs == 0: class _no_input(torch.nn.Module): - def __init__(self, descriptor: stp.SegmentedTensorProduct): + def __init__(self, descriptor: cue.SegmentedTensorProduct): super().__init__() for pid, path in enumerate(descriptor.paths): @@ -408,7 +408,7 @@ def forward(self, args: List[torch.Tensor]): class _Wrapper(torch.nn.Module): - def __init__(self, module: torch.nn.Module, descriptor: 
stp.SegmentedTensorProduct): + def __init__(self, module: torch.nn.Module, descriptor: cue.SegmentedTensorProduct): super().__init__() self.module = CALL_DISPATCHERS[descriptor.num_operands - 1](module) self.descriptor = descriptor @@ -418,7 +418,7 @@ def forward(self, args: List[torch.Tensor]): def _tensor_product_cuda( - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ) -> torch.nn.Module: @@ -463,7 +463,7 @@ def _tensor_product_cuda( return TensorProductUniform4x1d(d, device, math_dtype) supported_targets = [ - stp.Subscripts(subscripts) + cue.segmented_polynomials.Subscripts(subscripts) for subscripts in [ "u__uw_w", "_v_vw_w", @@ -481,7 +481,9 @@ def _tensor_product_cuda( try: descriptor, perm = next( - stp.dispatch(descriptor, supported_targets, "permute_all_but_last") + cue.segmented_polynomials.dispatch( + descriptor, supported_targets, "permute_all_but_last" + ) ) except StopIteration: raise NotImplementedError( @@ -505,7 +507,7 @@ def _permutation_module(permutation: Tuple[int, ...]): class FusedTensorProductOp3(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, perm: Tuple[int, int], device: Optional[torch.device], math_dtype: torch.dtype, @@ -560,7 +562,7 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: class FusedTensorProductOp4(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, perm: Tuple[int, int, int], device: Optional[torch.device], math_dtype: torch.dtype, @@ -618,7 +620,7 @@ def forward( class TensorProductUniform1d(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -692,7 +694,7 @@ def forward( class TensorProductUniform3x1dIndexed(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): @@ -749,7 +751,7 @@ def forward( class TensorProductUniform4x1dIndexed(torch.nn.Module): def __init__( self, - descriptor: stp.SegmentedTensorProduct, + descriptor: cue.SegmentedTensorProduct, device: Optional[torch.device], math_dtype: torch.dtype, ): diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index 272ab53e..a2ed0831 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -16,7 +16,6 @@ import torch import cuequivariance as cue -import cuequivariance.segmented_polynomials as stp import cuequivariance_torch as cuet from cuequivariance import descriptors from cuequivariance_torch._tests.utils import ( @@ -32,10 +31,10 @@ def make_descriptors(): ).polynomial.tensor_products yield [d1, d2, d3] - d1 = stp.SegmentedTensorProduct.from_subscripts(",,") + d1 = cue.SegmentedTensorProduct.from_subscripts(",,") d1.add_path(None, None, None, c=2.0) - d3 = stp.SegmentedTensorProduct.from_subscripts(",,,,") + d3 = cue.SegmentedTensorProduct.from_subscripts(",,,,") d3.add_path(None, None, None, None, None, c=3.0) yield [d1, d3] @@ -60,7 +59,7 @@ def make_descriptors(): @pytest.mark.parametrize("ds", make_descriptors()) @pytest.mark.parametrize("dtype, math_dtype, tol", 
settings1)
 def test_primitive_indexed_symmetric_tensor_product_cuda_vs_fx(
-    ds: list[stp.SegmentedTensorProduct], dtype, math_dtype, tol: float, batch_size: int
+    ds: list[cue.SegmentedTensorProduct], dtype, math_dtype, tol: float, batch_size: int
 ):
     use_fallback = not torch.cuda.is_available()
 
@@ -153,7 +152,7 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b
 @pytest.mark.parametrize("mode", export_modes)
 @pytest.mark.parametrize("use_fallback", [True, False])
 def test_export(
-    ds: list[stp.SegmentedTensorProduct],
+    ds: list[cue.SegmentedTensorProduct],
     mode: str,
     use_fallback: bool,
     tmp_path,

From a1a5c7ffd45dbed55c5b9d0cfbf46a222ea85404 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Tue, 11 Mar 2025 09:23:24 -0700
Subject: [PATCH 102/107] add content to changelog

---
 CHANGELOG.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 940f96ef..62f07f0f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,13 @@
 - Rename `SegmentedTensorProduct.flop_cost` in `flops`
 - Rename `SegmentedTensorProduct.memory_cost` in `memory`
 - Removed `IrrepsArray` in favor of `RepArray`
+- Change folder structure of cuequivariance and cuequivariance-jax. Now the main subfolders are `segmented_polynomials` and `group_theory`
+- Deprecate `cue.EquivariantTensorProduct` in favor of `cue.EquivariantPolynomial`
+- The descriptors return `cue.EquivariantPolynomial` instead of `cue.EquivariantTensorProduct`
+
+## Added
+- Classes `cue.SegmentedOperand` and `cue.SegmentedPolynomial`
+- Class `cue.EquivariantPolynomial` that contains a `cue.SegmentedPolynomial` and the `cue.Rep` of its inputs and outputs
 
 ## 0.3.0 (2025-03-05)

From a3da853bf43551e67919478ddc9df73389be9be2 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Tue, 11 Mar 2025 10:10:55 -0700
Subject: [PATCH 103/107] improve repr

---
 .../segmented_polynomials/segmented_polynomial.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
index 75a645c4..a94f8591 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
@@ -158,7 +158,14 @@ def __rmul__(self, factor: float) -> SegmentedPolynomial:
         return self.__mul__(factor)
 
     def __repr__(self):
-        return self.to_string([f"[{ope.size}]" for ope in self.operands])
+        buffer_names = []
+        for ope in self.operands:
+            if ope.all_same_segment_shape():
+                shape = ",".join(str(d) for d in ope.segment_shape)
+                buffer_names.append(f"[{ope.size}:{ope.num_segments}⨯({shape})]")
+            else:
+                buffer_names.append(f"[{ope.size}]")
+        return self.to_string(buffer_names)

From 8f4e9518fc7d666446b7e8e36cf5b6b78834e5de Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Tue, 11 Mar 2025 23:01:49 -0700
Subject: [PATCH 104/107] improve repr

---
 .../segmented_polynomial.py | 23 ++++++++++++++-----
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
index a94f8591..4495064b 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py
@@ -158,13 +158,26 @@ def __rmul__(self, factor: float) -> SegmentedPolynomial:
         return self.__mul__(factor)
 
     def __repr__(self):
+        def sfmt(shape: tuple[int, ...]) -> str:
+            return "(" + ",".join(str(d) for d in shape) + ")"
+
         buffer_names = []
         for ope in self.operands:
             if ope.all_same_segment_shape():
-                shape = ",".join(str(d) for d in ope.segment_shape)
-                buffer_names.append(f"[{ope.size}:{ope.num_segments}⨯({shape})]")
+                buffer_names.append(
+                    f"[{ope.size}:{ope.num_segments}⨯{sfmt(ope.segment_shape)}]"
+                )
             else:
-                buffer_names.append(f"[{ope.size}]")
+                txts = []
+                n = 20
+                for s in ope.segments:
+                    txts.append(sfmt(s))
+                    if len("+".join(txts)) > n:
+                        txts.pop()
+                        break
+                if len(txts) < len(ope.segments):
+                    txts.append("...")
+                buffer_names.append(f"[{ope.size}:{'+'.join(txts)}]")
         return self.to_string(buffer_names)
 
     def to_string(self, buffer_names: list[str] | None = None) -> str:
@@ -180,9 +193,7 @@ def to_string(self, buffer_names: list[str] | None = None) -> str:
         header = (
             " ".join(buffer_txts[: self.num_inputs])
             + " -> "
-            + " ".join(
-                buffer_txts[self.num_inputs : self.num_inputs + self.num_outputs]
-            )
+            + " ".join(buffer_txts[self.num_inputs :])
         )
 
         def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str:

From d3aeb2d2698b64fc4b278bd9d7b2cb67c9b0f895 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Sat, 15 Mar 2025 03:30:56 -0700
Subject: [PATCH 105/107] slightly optimize execution time of
 spherical_harmonic descriptors for large l

---
 .../descriptors/spherical_harmonics_.py       |  2 +-
 .../segmented_tensor_product.py               | 13 ++++++++-----
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py
index 890bb159..6f8b15b1 100644
--- a/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py
+++ b/cuequivariance/cuequivariance/group_theory/descriptors/spherical_harmonics_.py
@@ -59,7 +59,7 @@ def spherical_harmonics(
         indices = poly_degrees_to_path_indices(degrees)
         d.add_path(*indices, i, c=coeff)
 
-    d = d.symmetrize_operands(range(ell))
+    d = d.symmetrize_operands(range(ell), force=True)
 
     return cue.EquivariantPolynomial(
         [
diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
index 05c18a76..adfebb84 100644
--- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
+++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py
@@ -1337,7 +1337,9 @@ def f(ii: tuple[int, ...]) -> tuple[int, ...]:
             ],
         ).consolidate_paths()
 
-    def symmetrize_operands(self, operands: Sequence[int]) -> SegmentedTensorProduct:
+    def symmetrize_operands(
+        self, operands: Sequence[int], force: bool = False
+    ) -> SegmentedTensorProduct:
         """Symmetrize the specified operands permuting the indices."""
         operands = sorted(set(operands))
         if len(operands) < 2:
@@ -1345,16 +1347,17 @@ def symmetrize_operands(self, operands: Sequence[int]) -> SegmentedTensorProduct
 
         permutations = list(itertools.permutations(range(len(operands))))
 
-        # optimization: skip if already symmetric
        def make_global_perm(perm: tuple[int, ...]) -> tuple[int, ...]:
             p = list(range(self.num_operands))
             for i, j in enumerate(perm):
                 p[operands[i]] = operands[j]
             return tuple(p)
 
-        symmetries: list[tuple[int, ...]] = self.symmetries()
-        if all(make_global_perm(perm) in symmetries for perm in permutations):
-            
return self + if not force: + # check if the tensor product is already symmetric + symmetries: list[tuple[int, ...]] = self.symmetries() + if all(make_global_perm(perm) in symmetries for perm in permutations): + return self d = self.sort_indices_for_identical_operands(operands) From dd052f5a6314e86f9092fa7fb2815095fae1d4fc Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 15 Mar 2025 16:46:48 -0700 Subject: [PATCH 106/107] rename operations --- .../group_theory/descriptors/rotations.py | 12 ++-- .../mace/symmetric_contractions.py | 2 +- .../segmented_polynomial.py | 63 +++++++++---------- .../segmented_tensor_product.py | 6 +- .../equivariant_polynomial_test.py | 4 +- .../tests/segmented_polynomials/dot_test.py | 4 +- .../segmented_polynomial_test.py | 18 +++--- .../segmented_polynomial_ops_impl.py | 4 +- .../segmented_polynomial_vanilla_impl.py | 2 +- .../tests/segmented_polynomial_test.py | 2 +- .../cuequivariance_torch/operations/linear.py | 2 +- .../operations/tp_channel_wise.py | 2 +- .../operations/tp_fully_connected.py | 4 +- .../primitives/equivariant_tensor_product.py | 4 +- .../tests/operations/channel_wise_test.py | 2 +- .../tests/operations/fully_connected_test.py | 2 +- .../tests/primitives/primitive_export_test.py | 10 +-- .../symmetric_tensor_product_test.py | 4 +- .../tests/primitives/tensor_product_test.py | 14 ++--- docs/tutorials/beta.rst | 2 +- 20 files changed, 78 insertions(+), 85 deletions(-) diff --git a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py index fd4740fc..f6525970 100644 --- a/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py +++ b/cuequivariance/cuequivariance/group_theory/descriptors/rotations.py @@ -69,8 +69,8 @@ def yxy_rotation( where l is the maximum L in the input and output irreps. 
""" # gamma, beta, input, A - cbio = xy_rotation(irreps, lmax).polynomial.tensor_products[0][1] - aio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # alpha, A, output + cbio = xy_rotation(irreps, lmax).polynomial.operations[0][1] + aio = y_rotation(irreps, lmax).polynomial.operations[0][1] # alpha, A, output cbiao = cue.segmented_polynomials.dot( cbio, aio, (3, 1) ) # gamma, beta, input, alpha, output @@ -95,8 +95,8 @@ def xy_rotation( Rotation around the y-axis followed by rotation around the x-axis """ - cio = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] # gamma, input, A - bio = x_rotation(irreps, lmax).polynomial.tensor_products[0][1] # beta, A, output + cio = y_rotation(irreps, lmax).polynomial.operations[0][1] # gamma, input, A + bio = x_rotation(irreps, lmax).polynomial.operations[0][1] # beta, A, output cibo = cue.segmented_polynomials.dot(cio, bio, (2, 1)) # gamma, input, beta, output cbio = cibo.move_operand(1, 2) # gamma, beta, input, output return cue.EquivariantPolynomial( @@ -218,10 +218,10 @@ def x_rotation( """ assert irreps.irrep_class in [cue.SO3, cue.O3] - dy = y_rotation(irreps, lmax).polynomial.tensor_products[0][1] + dy = y_rotation(irreps, lmax).polynomial.operations[0][1] dz90 = fixed_axis_angle_rotation( irreps, np.array([0.0, 0.0, 1.0]), np.pi / 2.0 - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] d = cue.segmented_polynomials.dot( cue.segmented_polynomials.dot(dy, dz90, (1, 1)), dz90, (1, 1) ) diff --git a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py index c6fb322c..71dcb263 100644 --- a/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py +++ b/cuequivariance/cuequivariance/group_theory/experimental/mace/symmetric_contractions.py @@ -68,7 +68,7 @@ def symmetric_contraction( 1, None, ) - for _, d in pol.polynomial.tensor_products + for _, d in pol.polynomial.operations ], axis=1, ) diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py index 4495064b..c876a62a 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_polynomial.py @@ -44,13 +44,13 @@ class SegmentedPolynomial: inputs: tuple[cue.SegmentedOperand, ...] outputs: tuple[cue.SegmentedOperand, ...] - tensor_products: tuple[tuple[cue.Operation, cue.SegmentedTensorProduct], ...] + operations: tuple[tuple[cue.Operation, cue.SegmentedTensorProduct], ...] 
def __init__( self, inputs: Sequence[cue.SegmentedOperand], outputs: Sequence[cue.SegmentedOperand], - tensor_products: Sequence[ + operations: Sequence[ tuple[cue.Operation | Sequence[int], cue.SegmentedTensorProduct] ], ): @@ -58,8 +58,8 @@ def __init__( outputs = tuple(outputs) operands = inputs + outputs - _tensor_products = [] - for opt, stp in tensor_products: + tmp = [] + for opt, stp in operations: opt = cue.Operation(opt) assert isinstance(opt, cue.Operation) assert isinstance(stp, cue.SegmentedTensorProduct) @@ -68,15 +68,15 @@ def __init__( assert operand == operands[buffer_id] out_oid, bid = opt.output_operand_buffer(len(inputs)) - _tensor_products.append( + tmp.append( (bid, opt.move_operand_last(out_oid), stp.move_operand_last(out_oid)) ) - _tensor_products = sorted(_tensor_products) - _tensor_products = [(opt, stp) for _, opt, stp in _tensor_products] + tmp = sorted(tmp) + operations = [(opt, stp) for _, opt, stp in tmp] object.__setattr__(self, "inputs", inputs) object.__setattr__(self, "outputs", outputs) - object.__setattr__(self, "tensor_products", tuple(_tensor_products)) + object.__setattr__(self, "operations", tuple(operations)) @classmethod def eval_last_operand(cls, stp: cue.SegmentedTensorProduct): @@ -125,14 +125,14 @@ def from_stps( ) def __hash__(self) -> int: - return hash((self.inputs, self.outputs, self.tensor_products)) + return hash((self.inputs, self.outputs, self.operations)) def __eq__(self, value) -> bool: assert isinstance(value, SegmentedPolynomial) return ( self.inputs == value.inputs and self.outputs == value.outputs - and self.tensor_products == value.tensor_products + and self.operations == value.operations ) def __lt__(self, value) -> bool: @@ -140,18 +140,18 @@ def __lt__(self, value) -> bool: return ( self.inputs, self.outputs, - self.tensor_products, + self.operations, ) < ( value.inputs, value.outputs, - value.tensor_products, + value.operations, ) def __mul__(self, factor: float) -> SegmentedPolynomial: return SegmentedPolynomial( self.inputs, self.outputs, - tuple((ope, factor * stp) for ope, stp in self.tensor_products), + tuple((ope, factor * stp) for ope, stp in self.operations), ) def __rmul__(self, factor: float) -> SegmentedPolynomial: @@ -207,7 +207,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: items = [f"[{stp.coefficient_subscripts}]"] + items[:-1] return "·".join(items) + "➜" + out - lines = ["│ " + f(ope, stp) for ope, stp in self.tensor_products] + lines = ["│ " + f(ope, stp) for ope, stp in self.operations] if len(lines) > 0: lines[-1] = "╰─" + lines[-1][2:] @@ -218,7 +218,7 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct) -> str: + "─" * (n - len(line)) + "─ " + f"num_paths={stp.num_paths} {format_dimensions_dict(stp.get_dimensions_dict())}" - for line, (_, stp) in zip(lines, self.tensor_products) + for line, (_, stp) in zip(lines, self.operations) ] lines = ["╭ " + header] + lines @@ -232,7 +232,7 @@ def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]: np.zeros(inferred_shape + (ope.size,), dtype=inferred_dtype) for ope in self.outputs ] - for ope, stp in self.tensor_products: + for ope, stp in self.operations: oid, bid = ope.output_operand_buffer(self.num_inputs) outputs[bid - self.num_inputs] += ( cue.segmented_polynomials.compute_last_operand( @@ -267,7 +267,7 @@ def map_tensor_products( tuple[cue.Operation, cue.SegmentedTensorProduct] | None, ], ) -> SegmentedPolynomial: - new_tensor_products = [f(ope, stp) for ope, stp in self.tensor_products] + new_tensor_products = [f(ope, 
stp) for ope, stp in self.operations] new_tensor_products = tuple( ope_stp for ope_stp in new_tensor_products if ope_stp is not None ) @@ -282,7 +282,7 @@ def fuse_stps(self) -> SegmentedPolynomial: ) groups = itertools.groupby( - poly.tensor_products, + poly.operations, key=lambda x: ( x[0], x[1].operands_and_subscripts, @@ -325,14 +325,14 @@ def f(ope: cue.Operation, stp: cue.SegmentedTensorProduct): def used_inputs(self) -> list[bool]: """Inputs used in the polynomial. (List of boolean values)""" return [ - any(buffer in ope.buffers for ope, _ in self.tensor_products) + any(buffer in ope.buffers for ope, _ in self.operations) for buffer in range(self.num_inputs) ] def used_outputs(self) -> list[bool]: """Outputs used in the polynomial. (List of boolean values)""" return [ - any(buffer in ope.buffers for ope, _ in self.tensor_products) + any(buffer in ope.buffers for ope, _ in self.operations) for buffer in range(self.num_inputs, self.num_inputs + self.num_outputs) ] @@ -357,7 +357,7 @@ def select_buffers(self, keep: list[bool]) -> SegmentedPolynomial: # Filter tensor products that write to buffers we want to keep # and remap the buffer indices new_tensor_products = [] - for ope, stp in self.tensor_products: + for ope, stp in self.operations: # Check if the operation writes to a buffer we want to keep bid = ope.output_buffer(self.num_inputs) if keep[bid]: @@ -394,7 +394,7 @@ def compute_only(self, keep: list[bool]) -> SegmentedPolynomial: self.outputs, # on purpose, we keep all outputs [ (ope, stp) - for ope, stp in self.tensor_products + for ope, stp in self.operations if keep[ope.output_buffer(self.num_inputs) - self.num_inputs] ], ) @@ -431,7 +431,7 @@ def stack( tensor_products: list[tuple[cue.Operation, cue.SegmentedTensorProduct]] = [] for index, pol in enumerate(polys): - for ope, stp in pol.tensor_products: + for ope, stp in pol.operations: stp = copy.deepcopy(stp) for oid, buffer in enumerate(ope.buffers): if stacked[buffer]: @@ -450,7 +450,7 @@ def squeeze_modes(self) -> SegmentedPolynomial: return SegmentedPolynomial.from_default_buffers( self.inputs, self.outputs, - [(ope, stp.squeeze_modes()) for ope, stp in self.tensor_products], + [(ope, stp.squeeze_modes()) for ope, stp in self.operations], ) def flatten_coefficient_modes(self) -> SegmentedPolynomial: @@ -458,10 +458,7 @@ def flatten_coefficient_modes(self) -> SegmentedPolynomial: return SegmentedPolynomial.from_default_buffers( self.inputs, self.outputs, - [ - (ope, stp.flatten_coefficient_modes()) - for ope, stp in self.tensor_products - ], + [(ope, stp.flatten_coefficient_modes()) for ope, stp in self.operations], ) def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: @@ -472,7 +469,7 @@ def jvp(self, has_tangent: list[bool]) -> SegmentedPolynomial: sym_poly = self.symmetrize_for_identical_operands() new_tps = [] - for ope, stp in sym_poly.tensor_products: + for ope, stp in sym_poly.operations: jvps = ope.jvp(has_tangent) permutations: list[tuple[int, ...]] = stp.symmetries() for multiplicator, ope in cue.Operation.group_by_operational_symmetries( @@ -495,7 +492,7 @@ def transpose( assert len(has_cotangent) == self.num_outputs new_tps = [] - for ope, stp in self.tensor_products: + for ope, stp in self.operations: ope = ope.transpose(is_undefined_primal, has_cotangent) if ope is not None: new_tps.append((ope, stp)) @@ -522,7 +519,7 @@ def backward( def flops(self, batch_size: int = 1) -> int: """Compute the number of floating point operations in the polynomial.""" n = 0 - for ope, stp in self.tensor_products: 
+ for ope, stp in self.operations: oid, _ = ope.output_operand_buffer(self.num_inputs) n += stp.flops(oid) return batch_size * n @@ -534,7 +531,7 @@ def memory(self, batch_sizes: list[int]) -> int: def buffer_segments(self, buffer: int) -> list[tuple[int, ...]]: segments = None - for ope, stp in self.tensor_products: + for ope, stp in self.operations: if buffer in ope.buffers: ope = stp.operands[ope.buffers.index(buffer)] if segments is None: @@ -554,7 +551,7 @@ def symmetrize_for_identical_operands(self) -> SegmentedPolynomial: """ symmetrized_tensor_products = [] - for ope, stp in self.tensor_products: + for ope, stp in self.operations: for set_of_operands in ope.operands_with_identical_buffers(): stp = stp.symmetrize_operands(set_of_operands) stp = stp.sort_paths() diff --git a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py index adfebb84..7682c239 100644 --- a/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py +++ b/cuequivariance/cuequivariance/segmented_polynomials/segmented_tensor_product.py @@ -334,7 +334,7 @@ def to_text(self, coefficient_formatter=lambda x: f"{x}") -> str: ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1") - ... ).polynomial.tensor_products[0][1] + ... ).polynomial.operations[0][1] >>> d = d.flatten_coefficient_modes() >>> print(d.to_text()) uvw,u,v,w sizes=320,16,16,16 num_segments=5,4,4,4 num_paths=16 u=4 v=4 w=4 @@ -450,7 +450,7 @@ def to_base64(self, extended: bool = False) -> str: ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1"), ... cue.Irreps("SO3", "4x0+4x1") - ... ).polynomial.tensor_products[0][1] + ... ).polynomial.operations[0][1] >>> print(d.to_base64()) eJytkstuwjAQRX/F8r...lTF2zlX91/fHyvj2Z4= """ @@ -475,7 +475,7 @@ def get_dims(self, m: str) -> set[int]: ... cue.Irreps("SO3", "4x0+8x1"), ... cue.Irreps("SO3", "3x0+3x1"), ... cue.Irreps("SO3", "5x0+7x1") - ... ).polynomial.tensor_products[0][1] + ... 
).polynomial.operations[0][1] >>> d.get_dims("u") {8, 4} >>> d.get_dims("v") diff --git a/cuequivariance/tests/group_theory/equivariant_polynomial_test.py b/cuequivariance/tests/group_theory/equivariant_polynomial_test.py index 72d5cbfb..392b8091 100644 --- a/cuequivariance/tests/group_theory/equivariant_polynomial_test.py +++ b/cuequivariance/tests/group_theory/equivariant_polynomial_test.py @@ -118,9 +118,9 @@ def test_symmetric_contraction(): assert poly.num_inputs == 2 assert poly.num_outputs == 1 - [_, _, _, (_, d)] = poly.polynomial.tensor_products + [_, _, _, (_, d)] = poly.polynomial.operations assert d.num_paths == 437 poly = poly.polynomial.unsymmetrize_for_identical_operands() - [_, _, _, (_, d)] = poly.tensor_products + [_, _, _, (_, d)] = poly.operations assert d.num_paths == 105 diff --git a/cuequivariance/tests/segmented_polynomials/dot_test.py b/cuequivariance/tests/segmented_polynomials/dot_test.py index 193645d1..4d5a68d1 100644 --- a/cuequivariance/tests/segmented_polynomials/dot_test.py +++ b/cuequivariance/tests/segmented_polynomials/dot_test.py @@ -49,11 +49,11 @@ def make_examples(): cue.Irreps("SO3", "4x0 + 3x1"), cue.Irreps("SO3", "3x0 + 5x1"), irreps_middle, - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] assert dx.subscripts == "uvw,iu,jv,kw+ijk" dy = descriptors.channelwise_tensor_product( irreps_middle, cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0 + 1") - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] dy = dy.squeeze_modes("v") assert dy.subscripts == "u,iu,j,ku+ijk" dy = dy.add_or_rename_modes("w_kw_l_mw+klm") diff --git a/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py index 01f200d8..57ca9d93 100644 --- a/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py +++ b/cuequivariance/tests/segmented_polynomials/segmented_polynomial_test.py @@ -41,8 +41,8 @@ def test_init_segmented_polynomial(): assert poly.num_inputs == 2 assert poly.num_outputs == 1 assert poly.num_operands == 3 - assert len(poly.tensor_products) == 1 - assert poly.tensor_products[0] == (cue.Operation((0, 1, 2)), stp) + assert len(poly.operations) == 1 + assert poly.operations[0] == (cue.Operation((0, 1, 2)), stp) def test_polynomial_equality(): @@ -157,12 +157,12 @@ def test_consolidate(): consolidated = poly.consolidate() # Should have fused the two tensor products - assert len(consolidated.tensor_products) == 1 + assert len(consolidated.operations) == 1 # Coefficients should have been combined for each path - assert len(consolidated.tensor_products[0][1].paths) == 2 + assert len(consolidated.operations[0][1].paths) == 2 # The coefficients should have been added - assert consolidated.tensor_products[0][1].paths[0].coefficients == 2.0 - assert consolidated.tensor_products[0][1].paths[1].coefficients == -4.0 + assert consolidated.operations[0][1].paths[0].coefficients == 2.0 + assert consolidated.operations[0][1].paths[1].coefficients == -4.0 def test_stack(): @@ -198,7 +198,7 @@ def test_stack(): assert [ope.size for ope in stacked.operands] == [2, 2, 4] - [(_, stp)] = stacked.tensor_products + [(_, stp)] = stacked.operations assert stp.operands[0].num_segments == 2 assert stp.operands[1].num_segments == 2 assert stp.operands[2].num_segments == 4 @@ -401,7 +401,7 @@ def test_symmetrize_identical_operands(): # Check that we get 0.5 x0*y1 + 0.5 x1*y0 # This means we should have two paths with coefficient 0.5 - [(_, sym_stp)] = 
sym_poly.tensor_products + [(_, sym_stp)] = sym_poly.operations assert len(sym_stp.paths) == 2 # Check that we get 0.5 x0*y1 + 0.5 x1*y0 assert sym_stp.paths[0].coefficients == 0.5 @@ -412,7 +412,7 @@ def test_symmetrize_identical_operands(): # Test that unsymmetrize returns to original form unsym_poly = sym_poly.unsymmetrize_for_identical_operands() - [(_, unsym_stp)] = unsym_poly.tensor_products + [(_, unsym_stp)] = unsym_poly.operations assert len(unsym_stp.paths) == 1 assert unsym_stp.paths[0].coefficients == 1.0 assert unsym_stp.paths[0].indices == (0, 1, 0) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py index 19e69ec1..c9d80248 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py @@ -52,7 +52,7 @@ def log(msg: str): assert b.ndim == 2, f"Buffer {b.shape} must be 2D" # Reshape buffers to 3D by using the STP informations - for ope, stp in polynomial.tensor_products: + for ope, stp in polynomial.operations: if len(stp.subscripts.modes()) != 1: return log(f"Unsupported STP: {stp}") if not stp.all_same_segment_shape(): @@ -117,7 +117,7 @@ def log(msg: str): operations = [] paths = [] - for ope, stp in polynomial.tensor_products: + for ope, stp in polynomial.operations: operations.append(Operation(ope.buffers, len(paths), stp.num_paths)) for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py index 4aadbca4..26eea25c 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_vanilla_impl.py @@ -52,7 +52,7 @@ def gather(i: int, x: jax.Array) -> jax.Array: return buffer.at[idx].add(x) return buffer + x - for operation, d in polynomial.tensor_products: + for operation, d in polynomial.operations: ope_out, b_out = operation.output_operand_buffer(num_inputs) out = outputs_shape_dtype[b_out - num_inputs] diff --git a/cuequivariance_jax/tests/segmented_polynomial_test.py b/cuequivariance_jax/tests/segmented_polynomial_test.py index df3571bb..a2b05aad 100644 --- a/cuequivariance_jax/tests/segmented_polynomial_test.py +++ b/cuequivariance_jax/tests/segmented_polynomial_test.py @@ -95,7 +95,7 @@ def test_vmap(): e = cue.descriptors.full_tensor_product( cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "1") ) - d = e.polynomial.tensor_products[0][1] + d = e.polynomial.operations[0][1] def f(x1, x2, i1): return cuex.segmented_polynomial( diff --git a/cuequivariance_torch/cuequivariance_torch/operations/linear.py b/cuequivariance_torch/cuequivariance_torch/operations/linear.py index 5f3c5ffe..fc124023 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/linear.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/linear.py @@ -62,7 +62,7 @@ def __init__( math_dtype = math_dtype or dtype e = descriptors.linear(irreps_in, irreps_out) - assert e.polynomial.tensor_products[0][1].subscripts == "uv,iu,iv" + assert e.polynomial.operations[0][1].subscripts == "uv,iu,iv" self.irreps_in = irreps_in 
self.irreps_out = irreps_out diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py index 67c73ee1..eb541bdc 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_channel_wise.py @@ -72,7 +72,7 @@ def __init__( irreps_in1, irreps_in2, filter_irreps_out ) descriptor, irreps_out = ( - e.polynomial.tensor_products[0][1], + e.polynomial.operations[0][1], e.operands[-1].irreps, ) assert descriptor.subscripts == "uv,iu,jv,kuv+ijk" diff --git a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py index f4d0c29b..7a4a724b 100644 --- a/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py +++ b/cuequivariance_torch/cuequivariance_torch/operations/tp_fully_connected.py @@ -73,13 +73,13 @@ def __init__( e = descriptors.fully_connected_tensor_product( irreps_in1, irreps_in2, irreps_out ) - assert e.polynomial.tensor_products[0][1].subscripts == "uvw,iu,jv,kw+ijk" + assert e.polynomial.operations[0][1].subscripts == "uvw,iu,jv,kw+ijk" self.irreps_in1 = irreps_in1 self.irreps_in2 = irreps_in2 self.irreps_out = irreps_out - self.weight_numel = e.polynomial.tensor_products[0][1].operands[0].size + self.weight_numel = e.polynomial.operations[0][1].operands[0].size self.shared_weights = shared_weights self.internal_weights = ( diff --git a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py index 34045bdc..b099c299 100644 --- a/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py +++ b/cuequivariance_torch/cuequivariance_torch/primitives/equivariant_tensor_product.py @@ -160,7 +160,7 @@ def __init__( # TODO: remove this when re-design if isinstance(e, cue.EquivariantPolynomial): assert e.num_outputs == 1 - for ope, stp in e.polynomial.tensor_products: + for ope, stp in e.polynomial.operations: inputs = list(range(e.num_inputs)) output = e.num_inputs expected = tuple( @@ -170,7 +170,7 @@ def __init__( ) assert ope.buffers == expected, f"{ope.buffers} != {expected}" e = cue.EquivariantTensorProduct( - [stp for _, stp in e.polynomial.tensor_products], e.inputs + e.outputs + [stp for _, stp in e.polynomial.operations], e.inputs + e.outputs ) if not isinstance(layout_in, tuple): diff --git a/cuequivariance_torch/tests/operations/channel_wise_test.py b/cuequivariance_torch/tests/operations/channel_wise_test.py index 3b150aba..f88591aa 100644 --- a/cuequivariance_torch/tests/operations/channel_wise_test.py +++ b/cuequivariance_torch/tests/operations/channel_wise_test.py @@ -71,7 +71,7 @@ def test_channel_wise_fwd( d = descriptors.channelwise_tensor_product( irreps1, irreps2, irreps3 - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] d = d.squeeze_modes("v") assert d.subscripts == "u,iu,j,ku+ijk" if layout == cue.mul_ir: diff --git a/cuequivariance_torch/tests/operations/fully_connected_test.py b/cuequivariance_torch/tests/operations/fully_connected_test.py index e62f5b53..fad1de08 100644 --- a/cuequivariance_torch/tests/operations/fully_connected_test.py +++ b/cuequivariance_torch/tests/operations/fully_connected_test.py @@ -72,7 +72,7 @@ def test_fully_connected( d = descriptors.fully_connected_tensor_product( irreps1, irreps2, irreps3 - 
).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] if layout == cue.mul_ir: d = d.add_or_transpose_modes("uvw,ui,vj,wk+ijk") m2 = cuet.TensorProduct(d, math_dtype=torch.float64, use_fallback=True).to(device) diff --git a/cuequivariance_torch/tests/primitives/primitive_export_test.py b/cuequivariance_torch/tests/primitives/primitive_export_test.py index 969d9543..9fa32228 100644 --- a/cuequivariance_torch/tests/primitives/primitive_export_test.py +++ b/cuequivariance_torch/tests/primitives/primitive_export_test.py @@ -42,7 +42,7 @@ def test_script_symmetric_contraction(mode, tmp_path): e = cue.descriptors.symmetric_contraction( 32 * cue.Irreps("SO3", "0 + 1"), 32 * cue.Irreps("SO3", "0 + 1"), [1, 2, 3] ) - ds = [stp for _, stp in e.polynomial.tensor_products] + ds = [stp for _, stp in e.polynomial.operations] batch = 12 x0 = torch.randn(3, ds[0].operands[0].size, device=device, dtype=torch.float32) @@ -66,7 +66,7 @@ def test_script_fused_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .polynomial.tensor_products[0][1] + .polynomial.operations[0][1] .flatten_coefficient_modes() .squeeze_modes("v") ) @@ -99,7 +99,7 @@ def test_script_fused_tp_4(mode, tmp_path): cue.descriptors.fully_connected_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .polynomial.tensor_products[0][1] + .polynomial.operations[0][1] .flatten_coefficient_modes() .squeeze_modes("v") .permute_operands([1, 2, 0, 3]) @@ -136,7 +136,7 @@ def test_script_uniform_tp_3(mode, tmp_path): cue.descriptors.full_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1") ) - .polynomial.tensor_products[0][1] + .polynomial.operations[0][1] .flatten_coefficient_modes() .squeeze_modes("v") ) @@ -170,7 +170,7 @@ def test_script_uniform_tp_4(mode, tmp_path): cue.descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1"), cue.Irreps("SO3", "1"), cue.Irreps("SO3", "32x1") ) - .polynomial.tensor_products[0][1] + .polynomial.operations[0][1] .flatten_coefficient_modes() .squeeze_modes("v") ) diff --git a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py index a2ed0831..42af7b66 100644 --- a/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/symmetric_tensor_product_test.py @@ -28,7 +28,7 @@ def make_descriptors(): [(_, d1), (_, d2), (_, d3)] = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] - ).polynomial.tensor_products + ).polynomial.operations yield [d1, d2, d3] d1 = cue.SegmentedTensorProduct.from_subscripts(",,") @@ -120,7 +120,7 @@ def test_math_dtype(dtype: torch.dtype, math_dtype: torch.dtype, use_fallback: b e = descriptors.symmetric_contraction( cue.Irreps("SO3", "0 + 1 + 2"), cue.Irreps("SO3", "0"), [1, 2, 3] ) - ds = [stp for _, stp in e.polynomial.tensor_products] + ds = [stp for _, stp in e.polynomial.operations] m = cuet.IWeightedSymmetricTensorProduct( ds, math_dtype=math_dtype, device=device, use_fallback=use_fallback ) diff --git a/cuequivariance_torch/tests/primitives/tensor_product_test.py b/cuequivariance_torch/tests/primitives/tensor_product_test.py index 337d9f71..3f9bc25f 100644 --- a/cuequivariance_torch/tests/primitives/tensor_product_test.py +++ b/cuequivariance_torch/tests/primitives/tensor_product_test.py @@ -31,27 +31,23 @@ def make_descriptors(): cue.Irreps("O3", 
"4x0e + 4x1o"), cue.Irreps("O3", "6x0e + 6x1o"), cue.Irreps("O3", "5x0e + 5x1o + 5x2e + 5x1e"), - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] - yield descriptors.spherical_harmonics(cue.SO3(1), [2]).polynomial.tensor_products[ - 0 - ][1] - yield descriptors.spherical_harmonics(cue.SO3(1), [3]).polynomial.tensor_products[ - 0 - ][1] + yield descriptors.spherical_harmonics(cue.SO3(1), [2]).polynomial.operations[0][1] + yield descriptors.spherical_harmonics(cue.SO3(1), [3]).polynomial.operations[0][1] d = descriptors.channelwise_tensor_product( cue.Irreps("SU2", "3x1/2 + 4x1"), cue.Irreps("SU2", "1/2 + 1 + 3/2"), cue.Irreps("SU2", "1/2 + 1"), - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] yield d d = descriptors.channelwise_tensor_product( cue.Irreps("SO3", "32x1 + 32x2"), cue.Irreps("SO3", "0 + 1"), cue.Irreps("SO3", "0 + 1"), - ).polynomial.tensor_products[0][1] + ).polynomial.operations[0][1] yield d for subscripts in [ diff --git a/docs/tutorials/beta.rst b/docs/tutorials/beta.rst index 8a1463ef..5dc5dcd2 100644 --- a/docs/tutorials/beta.rst +++ b/docs/tutorials/beta.rst @@ -63,7 +63,7 @@ Again for segmented tensor product with 3 or 4 operands with one mode, we can us ) if device.type == "cuda": - d = e.polynomial.tensor_products[0][1] + ((_, d),) = e.polynomial.operations m = TensorProductUniform4x1dIndexed(d, device, torch.float32) x0 = torch.randn(16, e.inputs[0].dim, device=device) From 505f37a0c5b91d4cfa4b8690581c3c4906e557b2 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Sat, 15 Mar 2025 17:07:22 -0700 Subject: [PATCH 107/107] results --- .../segmented_polynomial.py | 16 +++-- .../segmented_polynomial_ops_impl.py | 68 ++++++++++++++----- 2 files changed, 61 insertions(+), 23 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py index 0a0b4bb4..93ed9285 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py +++ b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial.py @@ -312,7 +312,6 @@ def segmented_polynomial_impl( assert all(polynomial.used_buffers()) polynomial = polynomial.unsymmetrize_for_identical_operands() - outputs = None kwargs = dict( inputs=inputs, outputs_shape_dtype=outputs_shape_dtype, @@ -325,13 +324,16 @@ def segmented_polynomial_impl( assert impl in ("auto", "cuda", "jax") - if platform == "cuda" and impl in ("auto", "cuda"): - outputs, msg = segmented_polynomial_ops_impl(**kwargs) + outputs = None + if platform == "cuda": + if impl in ("auto", "cuda"): + outputs = segmented_polynomial_ops_impl(**kwargs) + if impl == "cuda" and not outputs.is_ok(): + raise RuntimeError(f"Failed to use CUDA implementation: {outputs.msg}") + outputs = outputs.unwrap_or(None) else: - msg = f"{platform=}, {impl=}" - - if impl == "cuda" and outputs is None: - raise RuntimeError(f"Failed to use CUDA implementation: {msg}") + if impl == "cuda": + raise RuntimeError(f"{impl=} but platform is {platform}") if outputs is None: outputs = segmented_polynomial_vanilla_impl(**kwargs) diff --git a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py index c9d80248..1534c1fc 100644 --- a/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py +++ 
b/cuequivariance_jax/cuequivariance_jax/segmented_polynomials/segmented_polynomial_ops_impl.py @@ -14,6 +14,7 @@ # limitations under the License. import logging import re +from dataclasses import dataclass import jax import jax.numpy as jnp @@ -24,6 +25,37 @@ logger = logging.getLogger(__name__) +@dataclass +class Ok: + outputs: list[jax.Array] + + def unwrap(self) -> list[jax.Array]: + return self.outputs + + def unwrap_or(self, default): + return self.outputs + + def is_ok(self) -> bool: + return True + + +@dataclass +class Err: + msg: str + + def unwrap(self): + raise ValueError(self.msg) + + def unwrap_or(self, default): + return default + + def is_ok(self) -> bool: + return False + + +Result = Ok | Err + + def sanitize_string(s): s = re.sub(r"[^A-Za-z0-9_]", "", s) if s == "" or s[0].isdigit(): @@ -39,10 +71,10 @@ def segmented_polynomial_ops_impl( polynomial: cue.SegmentedPolynomial, math_dtype: jnp.dtype, name: str, -) -> tuple[list[jax.Array] | None, str]: - def log(msg: str): +) -> Result: + def make_error(msg: str) -> Err: logger.info(f"[{name}] {msg}") - return None, name + return Err(msg) assert polynomial.num_inputs == len(buffer_index) - len(outputs_shape_dtype) assert polynomial.num_outputs == len(outputs_shape_dtype) @@ -54,9 +86,9 @@ def log(msg: str): # Reshape buffers to 3D by using the STP informations for ope, stp in polynomial.operations: if len(stp.subscripts.modes()) != 1: - return log(f"Unsupported STP: {stp}") + return make_error(f"Unsupported STP: {stp}") if not stp.all_same_segment_shape(): - return log(f"Unsupported STP: {stp}") + return make_error(f"Unsupported STP: {stp}") for i, operand in zip(ope.buffers, stp.operands): b = buffers[i] @@ -64,30 +96,34 @@ def log(msg: str): if b.ndim == 2: b = buffers[i] = reshape(b, shape) if b.shape != shape: - return log(f"Shape mismatch: {b.shape} != {shape} for {i} {stp} {ope}") + return make_error( + f"Shape mismatch: {b.shape} != {shape} for {i} {stp} {ope}" + ) for b in buffers: if b.dtype.type not in {jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16}: - return log(f"Unsupported buffer type: {b.dtype}") + return make_error(f"Unsupported buffer type: {b.dtype}") for i in indices: # TODO: this restriction will be removed by MR109 if i.dtype.type != jnp.int32: - return log(f"Unsupported index type: {i.dtype}") + return make_error(f"Unsupported index type: {i.dtype}") if not all(b.ndim == 3 for b in buffers): - return log("All buffers must be used") + return make_error("All buffers must be used") if len({b.shape[2] for b in buffers}.union({1})) != 2: - return log(f"Buffer shapes not compatible {[b.shape for b in buffers]}") + return make_error(f"Buffer shapes not compatible {[b.shape for b in buffers]}") # TODO: this restriction will be removed by MR109 if max(b.shape[2] for b in buffers) % 32 != 0: - return log(f"Extend must be a multiple of 32, got {[b.shape for b in buffers]}") + return make_error( + f"Extend must be a multiple of 32, got {[b.shape for b in buffers]}" + ) math_dtype = jnp.dtype(math_dtype) if math_dtype.type not in {jnp.float32, jnp.float64}: - return log(f"Unsupported math_dtype: {math_dtype}") + return make_error(f"Unsupported math_dtype: {math_dtype}") batch_size = 1 for i, b in zip(buffer_index, buffers): @@ -102,7 +138,7 @@ def log(msg: str): ): if b.dtype.type not in {jnp.float32, jnp.float64}: if i >= 0 or b.shape[0] != batch_size: - return log( + return make_error( f"Output buffer {b.shape} of type {b.dtype} and buffer index {i} is not supported" ) @@ -113,7 +149,7 @@ def log(msg: str): 
tensor_product_uniform_1d_jit, ) except ImportError as e: - return log(f"cuequivariance_ops_jax is not installed: {e}") + return make_error(f"cuequivariance_ops_jax is not installed: {e}") operations = [] paths = [] @@ -122,7 +158,7 @@ def log(msg: str): for path in stp.paths: paths.append(Path(path.indices, path.coefficients.item())) - log("Using the uniform 1d kernel of cuequivariance_ops_jax 🚀\n" + str(polynomial)) + logger.info(f"Using the uniform 1d kernel for '{name}' 🚀\n" + str(polynomial)) outputs = tensor_product_uniform_1d_jit( buffers[: polynomial.num_inputs], buffers[polynomial.num_inputs :], @@ -133,4 +169,4 @@ def log(msg: str): math_dtype=math_dtype, name=sanitize_string(name), ) - return [jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2])) for x in outputs], "" + return Ok([jnp.reshape(x, (x.shape[0], x.shape[1] * x.shape[2])) for x in outputs])
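
For reference, the control flow established by these patches: segmented_polynomial_ops_impl now returns Ok(outputs) or Err(msg) in place of the old (outputs | None, msg) tuple, and segmented_polynomial_impl unwraps the result, raising only when impl="cuda" was explicitly requested and falling back to the vanilla implementation otherwise. The sketch below illustrates that pattern in isolation; Ok, Err and Result mirror the dataclasses added in PATCH 107, while try_cuda_impl, vanilla_impl and dispatch are hypothetical stand-ins for the library functions, not its actual API.

from dataclasses import dataclass


@dataclass
class Ok:
    outputs: list

    def is_ok(self) -> bool:
        return True

    def unwrap_or(self, default):
        return self.outputs


@dataclass
class Err:
    msg: str

    def is_ok(self) -> bool:
        return False

    def unwrap_or(self, default):
        return default


Result = Ok | Err


def try_cuda_impl(x: list[float]) -> Result:
    # Hypothetical stand-in for segmented_polynomial_ops_impl: instead of
    # returning (None, msg) on unsupported inputs, it returns Err(msg).
    if len(x) % 32 != 0:
        return Err(f"input length must be a multiple of 32, got {len(x)}")
    return Ok([2.0 * v for v in x])


def vanilla_impl(x: list[float]) -> list[float]:
    # Hypothetical stand-in for segmented_polynomial_vanilla_impl.
    return [2.0 * v for v in x]


def dispatch(x: list[float], impl: str = "auto") -> list[float]:
    # Mirrors the unwrapping logic in segmented_polynomial_impl: raise only
    # when the CUDA path was forced, otherwise fall back silently.
    result = try_cuda_impl(x)
    if impl == "cuda" and not result.is_ok():
        raise RuntimeError(f"Failed to use CUDA implementation: {result.msg}")
    outputs = result.unwrap_or(None)
    if outputs is None:
        outputs = vanilla_impl(x)
    return outputs


assert dispatch([1.0] * 32) == [2.0] * 32  # fast path accepted
assert dispatch([1.0] * 3) == [2.0] * 3    # Err -> vanilla fallback

Keeping the failure reason attached to the value, rather than threading a separate msg string alongside a possibly-None output, is what lets the forced-CUDA path report exactly why the kernel was rejected, as in the RuntimeError raised in segmented_polynomial_impl above.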