From df5474ce79c5ce065ce7d9da95d248695f507d24 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Thu, 6 Mar 2025 13:20:41 -0800 Subject: [PATCH 1/2] improve changelog --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b34de624..376d3dc8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,8 @@ The main changes are: 3. Provides optional fused scatter/gather for the inputs and outputs 4. 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉 2. [Torch] Adds torch.compile support -3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (see tutorial in the documentation) -4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (see tutorial in the documentation) +3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environement variable `CUEQUIVARIANCE_OPS_USE_JIT=1`) +4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`) ### Breaking Changes - In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library. From 32cd7b629a270b0b1210022dc7f5cb5836b28610 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Fri, 7 Mar 2025 16:18:33 +0100 Subject: [PATCH 2/2] add prefixes in the changelog --- CHANGELOG.md | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 376d3dc8..b2bed14d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,28 +9,36 @@ The main changes are: 3. Provides optional fused scatter/gather for the inputs and outputs 4. 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉 2. [Torch] Adds torch.compile support -3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environement variable `CUEQUIVARIANCE_OPS_USE_JIT=1`) -4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`) +3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel + 1. enable the new kernel by setting the environement variable `CUEQUIVARIANCE_OPS_USE_JIT=1` +4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d + 1. this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed` ### Breaking Changes -- In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library. -- In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation. -- Removed `cue.TensorProductExecution` and added `cue.Operation` which is more lightweight and better aligned with the backend. -- Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomials. -- Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptor.linear` together with `cuex.equivariant_tensor_product` instead. -- The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases. -- The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomials. +- [Torch/JAX] Removed `cue.TensorProductExecution` and added `cue.Operation` which is more lightweight and better aligned with the backend. +- [JAX] In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library. +- [JAX] In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation. +- [JAX] Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomials. +- [JAX] The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases. +- [JAX] The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomials. +- [JAX] Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptor.linear` together with `cuex.equivariant_tensor_product` instead. +```python +e = cue.descriptors.linear(input.irreps, output_irreps) +w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype) +output = cuex.equivariant_tensor_product(e, w, input) +``` ### Fixed -- Identified a bug in the CUDA kernel and disabled CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`. -- Fixed `cue.descriptor.full_tensor_product` which was ignoring the `irreps3_filter` argument. -- Fixed a rare bug with `np.bincount` when using an old version of numpy. The input is now flattened to make it work with all versions. +- [Torch/JAX] Fixed `cue.descriptor.full_tensor_product` which was ignoring the `irreps3_filter` argument. +- [Torch/JAX] Fixed a rare bug with `np.bincount` when using an old version of numpy. The input is now flattened to make it work with all versions. +- [Torch] Identified a bug in the CUDA kernel and disabled CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`. ### Added -- Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations. -- Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion. -- Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product. -- Added a uniform 1d kernel with scatter/gather fusion under `cuet.primitives.tensor_product.TensorProductUniform4x1dIndexed` and `cuet.primitives.tensor_product.TensorProductUniform3x1dIndexed`. +- [Torch/JAX] Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product. +- [JAX] Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations. +- [JAX] Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion. +- [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environement variable `CUEQUIVARIANCE_OPS_USE_JIT=1`) +- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`) ## 0.2.0 (2025-01-24)