From 8d4b786ec50bf3a543489173a18d2d7ef7bb0acc Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 10 Mar 2025 21:51:37 +0100
Subject: [PATCH 1/4] Integrate Release branch into main (#92)

* improve changelog

* add prefixes in the changelog
---
 CHANGELOG.md | 40 ++++++++++++++++++++++++----------------
 1 file changed, 24 insertions(+), 16 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b34de624..b2bed14d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,28 +9,36 @@ The main changes are:
    3. Provides optional fused scatter/gather for the inputs and outputs
    4. 🎉 We observed a ~3x speedup for MACE with cuEquivariance-JAX v0.3.0 compared to cuEquivariance-Torch v0.2.0 🎉
 2. [Torch] Adds torch.compile support
-3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (see tutorial in the documentation)
-4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (see tutorial in the documentation)
+3. [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel
+   1. enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`
+4. [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d
+   1. this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`

 ### Breaking Changes

-- In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library.
-- In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation.
-- Removed `cue.TensorProductExecution` and added `cue.Operation` which is more lightweight and better aligned with the backend.
-- Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomials.
-- Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptor.linear` together with `cuex.equivariant_tensor_product` instead.
-- The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases.
-- The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomials.
+- [Torch/JAX] Removed `cue.TensorProductExecution` and added `cue.Operation` which is more lightweight and better aligned with the backend.
+- [JAX] In `cuex.equivariant_tensor_product`, the arguments `dtype_math` and `dtype_output` are renamed to `math_dtype` and `output_dtype` respectively. This change adds consistency with the rest of the library.
+- [JAX] In `cuex.equivariant_tensor_product`, the arguments `algorithm`, `precision`, `use_custom_primitive` and `use_custom_kernels` have been removed. This change avoids a proliferation of arguments that are not used in all implementations. An argument `impl: str` has been added instead to select the implementation.
+- [JAX] Removed `cuex.symmetric_tensor_product`. The `cuex.tensor_product` function now handles any non-homogeneous polynomials.
+- [JAX] The batching support (`jax.vmap`) of `cuex.equivariant_tensor_product` is now limited to specific use cases.
+- [JAX] The interface of `cuex.tensor_product` has changed. It now takes a list of `tuple[cue.Operation, cue.SegmentedTensorProduct]` instead of a single `cue.SegmentedTensorProduct`. This change allows `cuex.tensor_product` to execute any type of non-homogeneous polynomials.
+- [JAX] Removed `cuex.flax_linen.Linear` to reduce maintenance burden. Use `cue.descriptor.linear` together with `cuex.equivariant_tensor_product` instead.
+```python
+e = cue.descriptors.linear(input.irreps, output_irreps)
+w = self.param(name, jax.random.normal, (e.inputs[0].dim,), input.dtype)
+output = cuex.equivariant_tensor_product(e, w, input)
+```

 ### Fixed

-- Identified a bug in the CUDA kernel and disabled CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`.
-- Fixed `cue.descriptor.full_tensor_product` which was ignoring the `irreps3_filter` argument.
-- Fixed a rare bug with `np.bincount` when using an old version of numpy. The input is now flattened to make it work with all versions.
+- [Torch/JAX] Fixed `cue.descriptor.full_tensor_product` which was ignoring the `irreps3_filter` argument.
+- [Torch/JAX] Fixed a rare bug with `np.bincount` when using an old version of numpy. The input is now flattened to make it work with all versions.
+- [Torch] Identified a bug in the CUDA kernel and disabled CUDA kernel for `cuet.TransposeSegments` and `cuet.TransposeIrrepsLayout`.

 ### Added

-- Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations.
-- Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion.
-- Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product.
-- Added a uniform 1d kernel with scatter/gather fusion under `cuet.primitives.tensor_product.TensorProductUniform4x1dIndexed` and `cuet.primitives.tensor_product.TensorProductUniform3x1dIndexed`.
+- [Torch/JAX] Added `__mul__` to `cue.EquivariantTensorProduct` to allow rescaling the equivariant tensor product.
+- [JAX] Added JAX Bindings to the uniform 1d JIT kernel. This kernel handles any kind of non-homogeneous polynomials as long as the contraction pattern (subscripts) has only one mode. It handles batched/shared/indexed input/output. The indexed input/output is processed through atomic operations.
+- [JAX] Added an `indices` argument to `cuex.equivariant_tensor_product` and `cuex.tensor_product` to handle the scatter/gather fusion.
+- [Torch] Beta limited Torch bindings to the new JIT Uniform 1d kernel (enable the new kernel by setting the environment variable `CUEQUIVARIANCE_OPS_USE_JIT=1`)
+- [Torch] Implements scatter/gather fusion through a beta API for Uniform 1d (this is a temporary API that will change, `cuequivariance_torch.primitives.tensor_product.TensorProductUniform4x1dIndexed`)

 ## 0.2.0 (2025-01-24)

From 31a7dceec0bc8de59614a60840b639041f471c4d Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Tue, 11 Mar 2025 09:59:39 -0700
Subject: [PATCH 2/4] add infos about the cuda packages
---
 README.md      | 2 +-
 docs/index.rst | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 20fb6c5c..6cbf76c2 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@
 pip install cuequivariance-jax
 pip install cuequivariance-torch
 pip install cuequivariance  # Installs only the core non-ML components

-# CUDA kernels for different CUDA versions
+# CUDA kernels
 pip install cuequivariance-ops-jax-cu12
 pip install cuequivariance-ops-torch-cu12  # or cu11
 ```

diff --git a/docs/index.rst b/docs/index.rst
index 6468a5b4..39e6a221 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -39,14 +39,16 @@ The easiest way to install cuEquivariance is from `PyPi `_ us
    pip install cuequivariance-torch
    pip install cuequivariance  # Installs only the core non-ML components

-   # CUDA kernels for different CUDA versions
+   # CUDA kernels
    pip install cuequivariance-ops-torch-cu11
    pip install cuequivariance-ops-torch-cu12
+   pip install cuequivariance-ops-jax-cu12

 Requirements
 ------------
-``cuequivariance-ops-torch-*`` packages are only available for Linux x86_64 and require PyTorch 2.4.0 or later.
+ - ``cuequivariance-ops-torch-*`` packages are available for Linux x86_64/aarch64 and require PyTorch 2.4.0 or later.
+ - ``cuequivariance-ops-jax-cu12`` package is only available for Linux x86_64 and requires JAX 0.5.0 or later.

 Organization
 ------------

From 8476920304935fb22cadf50d13e03b94ff71d995 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Thu, 13 Mar 2025 18:05:11 +0100
Subject: [PATCH 3/4] Update index.rst
---
 docs/index.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/index.rst b/docs/index.rst
index 39e6a221..2de84a60 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -47,7 +47,7 @@ The easiest way to install cuEquivariance is from `PyPi `_ us
 Requirements
 ------------

- - ``cuequivariance-ops-torch-*`` packages are available for Linux x86_64/aarch64 and require PyTorch 2.4.0 or later.
+ - ``cuequivariance-ops-torch-*`` packages are available for Linux x86_64/aarch64 and require PyTorch 2.4.0 or later. aarch64 is only available for Python 3.12.
  - ``cuequivariance-ops-jax-cu12`` package is only available for Linux x86_64 and requires JAX 0.5.0 or later.

 Organization
 ------------

From daae1bd3c3863f1a7b5857c7fa09118a570b7373 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 17 Mar 2025 01:25:56 -0700
Subject: [PATCH 4/4] rename changelog to release notes
---
 docs/changelog.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/changelog.md b/docs/changelog.md
index e1a84149..c38e409a 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -1,3 +1,3 @@
-# Changelog
+# Release Notes

 ```{include} ../CHANGELOG.md
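The scatter/gather fusion mentioned in the changelog entries of PATCH 1/4 can be sketched in plain NumPy. Everything below (the function name, shapes, and index arrays) is an illustrative assumption, not the cuEquivariance API; it only shows the gather → product → scatter-add pattern that the fused Uniform 1d kernels realize, where the scatter-add into indexed outputs is what the changelog describes as "processed through atomic operations" on the GPU.

```python
import numpy as np

def gathered_product_scatter_add(a, b, in_idx, out_idx, num_out):
    """Hypothetical sketch: gather rows of `a`, multiply elementwise by `b`,
    then scatter-add the products into `num_out` output rows."""
    prod = a[in_idx] * b                      # gather + elementwise product
    out = np.zeros((num_out, a.shape[1]))
    np.add.at(out, out_idx, prod)             # unbuffered scatter-add (atomic adds on GPU)
    return out

a = np.arange(6.0).reshape(3, 2)              # [[0, 1], [2, 3], [4, 5]]
b = np.ones((4, 2))
out = gathered_product_scatter_add(
    a, b,
    in_idx=np.array([0, 2, 1, 0]),            # which input rows to gather
    out_idx=np.array([0, 0, 1, 1]),           # which output rows receive each product
    num_out=2,
)
# out[0] = a[0] + a[2] = [4, 6]; out[1] = a[1] + a[0] = [2, 4]
```

The fused kernel performs this in one pass instead of materializing the gathered tensor, which is where the speedup for graph workloads like MACE comes from.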