From 7bbd3196df883aab4c1da0855e9f80636f1234c0 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 26 Jan 2026 00:33:58 -0800
Subject: [PATCH 1/4] use nnx.Optimizer

---
 cuequivariance_jax/examples/mace_nnx.py      |  22 ++--
 cuequivariance_jax/examples/test_examples.py | 117 +++++++++++++++++++
 2 files changed, 125 insertions(+), 14 deletions(-)

diff --git a/cuequivariance_jax/examples/mace_nnx.py b/cuequivariance_jax/examples/mace_nnx.py
index 72a8f74..8d234e9 100644
--- a/cuequivariance_jax/examples/mace_nnx.py
+++ b/cuequivariance_jax/examples/mace_nnx.py
@@ -528,22 +528,16 @@ def benchmark(
         mask=mask,
     )
 
-    optimizer = optax.adam(1e-2)
-    opt_state = optimizer.init(nnx.state(model, nnx.Param))
+    optimizer = nnx.Optimizer(model, optax.adam(1e-2), wrt=nnx.Param)
 
     @nnx.jit
-    def step(model, opt_state, batch_dict, target_E, target_F):
+    def step(model, optimizer, batch_dict, target_E, target_F):
         def loss_fn(model):
             E, F = model(batch_dict)
             return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)
 
-        grad = nnx.grad(loss_fn)(model)
-        params = nnx.state(model, nnx.Param)
-        grad_state = nnx.state(grad, nnx.Param)
-        updates, opt_state_new = optimizer.update(grad_state, opt_state, params)
-        new_params = optax.apply_updates(params, updates)
-        nnx.update(model, new_params)
-        return opt_state_new
+        grads = nnx.grad(loss_fn)(model)
+        optimizer.update(model, grads)
 
     @nnx.jit
     def inference(model, batch_dict):
@@ -553,12 +547,12 @@ def inference(model, batch_dict):
     runtime_per_inference = 0
 
     if mode in ["train", "both"]:
-        opt_state = step(model, opt_state, batch_dict, target_E, target_F)
-        jax.block_until_ready(opt_state)
+        step(model, optimizer, batch_dict, target_E, target_F)
+        jax.block_until_ready(nnx.state(model))
         t0 = time.perf_counter()
         for _ in range(10):
-            opt_state = step(model, opt_state, batch_dict, target_E, target_F)
-            jax.block_until_ready(opt_state)
+            step(model, optimizer, batch_dict, target_E, target_F)
+            jax.block_until_ready(nnx.state(model))
         runtime_per_training_step = 1e3 * (time.perf_counter() - t0) / 10
 
     if mode in ["inference", "both"]:
diff --git a/cuequivariance_jax/examples/test_examples.py b/cuequivariance_jax/examples/test_examples.py
index 0faed08..d44710a 100644
--- a/cuequivariance_jax/examples/test_examples.py
+++ b/cuequivariance_jax/examples/test_examples.py
@@ -226,6 +226,123 @@ def convert_linear(w, irreps_in, irreps_out):
         layer.readout.w[ir][...] = w
 
 
+def test_mace_linen_nnx_training_equivalence():
+    """Test Linen and NNX models produce identical params after training."""
+    import optax
+
+    num_features, num_species = 16, 3
+    num_steps = 5
+    learning_rate = 1e-2
+
+    config = dict(
+        num_layers=1,
+        num_features=num_features,
+        num_species=num_species,
+        max_ell=2,
+        correlation=2,
+        num_radial_basis=4,
+        interaction_irreps=cue.Irreps(cue.O3, "0e+1o+2e"),
+        hidden_irreps=cue.Irreps(cue.O3, "0e+1o"),
+        offsets=np.zeros(num_species),
+        cutoff=3.0,
+        epsilon=0.1,
+        skip_connection_first_layer=True,
+    )
+    batch = _make_batch(
+        num_atoms=5, num_edges=10, num_species=num_species, num_graphs=1
+    )
+
+    key = jax.random.key(123)
+    target_E = jax.random.normal(key, (1,))
+    target_F = jax.random.normal(jax.random.split(key)[0], (5, 3))
+
+    # Initialize Linen model
+    linen_model = MACEModel(**config)
+    linen_params = linen_model.init(jax.random.key(0), batch)
+
+    # Initialize NNX model with converted weights
+    nnx_model = MACEModelNNX(**config, dtype=jnp.float32, rngs=nnx.Rngs(0))
+    _convert_linen_to_nnx(linen_params, nnx_model, config)
+
+    # Verify initial outputs match (use 1e-3 tolerance like existing test)
+    E_linen_init, F_linen_init = linen_model.apply(linen_params, batch)
+    E_nnx_init, F_nnx_init = nnx_model(batch)
+    np.testing.assert_allclose(E_linen_init, E_nnx_init, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(F_linen_init, F_nnx_init, atol=1e-3, rtol=1e-3)
+    print("\nInitial outputs match (within 1e-3):")
+    print(f"  E_linen: {E_linen_init}, E_nnx: {E_nnx_init}")
+
+    # Train Linen model
+    def loss_fn_linen(w):
+        E, F = linen_model.apply(w, batch)
+        return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)
+
+    tx = optax.adam(learning_rate)
+    opt_state_linen = tx.init(linen_params)
+    for _ in range(num_steps):
+        grad = jax.grad(loss_fn_linen)(linen_params)
+        updates, opt_state_linen = tx.update(grad, opt_state_linen, linen_params)
+        linen_params = optax.apply_updates(linen_params, updates)
+
+    # Train NNX model
+    def loss_fn_nnx(model):
+        E, F = model(batch)
+        return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)
+
+    optimizer_nnx = nnx.Optimizer(nnx_model, optax.adam(learning_rate), wrt=nnx.Param)
+    for _ in range(num_steps):
+        grads = nnx.grad(loss_fn_nnx)(nnx_model)
+        optimizer_nnx.update(nnx_model, grads)
+
+    # Compare outputs after training
+    E_linen, F_linen = linen_model.apply(linen_params, batch)
+    E_nnx, F_nnx = nnx_model(batch)
+
+    print(f"\nAfter {num_steps} training steps:")
+    print(f"  Linen loss: {loss_fn_linen(linen_params):.6f}")
+    print(f"  NNX loss:   {loss_fn_nnx(nnx_model):.6f}")
+    print(f"  E_linen: {E_linen}")
+    print(f"  E_nnx:   {E_nnx}")
+    print(f"  max|E diff|: {jnp.max(jnp.abs(E_linen - E_nnx)):.2e}")
+    print(f"  max|F diff|: {jnp.max(jnp.abs(F_linen - F_nnx)):.2e}")
+
+    # Compare multiple parameters
+    print("  Parameter comparisons:")
+
+    # 1. Embedding
+    linen_emb = linen_params["params"]["linear_embedding"]
+    nnx_emb = nnx_model.embedding[...]
+    emb_diff = jnp.max(jnp.abs(linen_emb - nnx_emb))
+    print(f"    embedding: max|diff| = {emb_diff:.2e}")
+
+    # 2. Skip connection weights (linZ_skip_tp -> skip.w)
+    linen_skip = linen_params["params"]["layer_0"]["linZ_skip_tp"]
+    nnx_skip = nnx_model.layers[0].skip.w[...]
+    skip_diff = jnp.max(jnp.abs(linen_skip - nnx_skip))
+    print(f"    layer_0/skip.w: max|diff| = {skip_diff:.2e}")
+
+    # 3. Symmetric contraction weights
+    linen_sc = linen_params["params"]["layer_0"]["symmetric_contraction"]
+    nnx_sc = nnx_model.layers[0].sc.w[...]
+    sc_diff = jnp.max(jnp.abs(linen_sc - nnx_sc))
+    print(f"    layer_0/sc.w: max|diff| = {sc_diff:.2e}")
+
+    # 4. Radial MLP weights (first layer)
+    linen_mlp = linen_params["params"]["layer_0"]["MultiLayerPerceptron_0"]["Dense_0"][
+        "kernel"
+    ]
+    nnx_mlp = nnx_model.layers[0].radial_mlp.linears[0][...]
+    mlp_diff = jnp.max(jnp.abs(linen_mlp - nnx_mlp))
+    print(f"    layer_0/radial_mlp[0]: max|diff| = {mlp_diff:.2e}")
+
+    np.testing.assert_allclose(E_linen, E_nnx, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(F_linen, F_nnx, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(linen_emb, nnx_emb, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(linen_skip, nnx_skip, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(linen_sc, nnx_sc, atol=1e-3, rtol=1e-3)
+    np.testing.assert_allclose(linen_mlp, nnx_mlp, atol=1e-3, rtol=1e-3)
+
+
 def test_nequip_model_basic():
     """Test NEQUIP model."""
     batch = _make_batch()

From d65edeebdb2301360461cf8a04947385c4e7fda9 Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 26 Jan 2026 01:42:21 -0800
Subject: [PATCH 2/4] require JAX 0.8.1+

---
 cuequivariance_jax/pyproject.toml | 2 +-
 docs/index.rst                    | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cuequivariance_jax/pyproject.toml b/cuequivariance_jax/pyproject.toml
index 86aa4b7..663e41f 100644
--- a/cuequivariance_jax/pyproject.toml
+++ b/cuequivariance_jax/pyproject.toml
@@ -27,7 +27,7 @@ authors = [{ name = "NVIDIA Corporation" }]
 requires-python = ">=3.10"
 dependencies = [
     "cuequivariance",
-    "jax",
+    "jax>=0.8.1",
     "packaging",
     "einops",
 ]
diff --git a/docs/index.rst b/docs/index.rst
index d8fbc01..bfdbbaa 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -63,7 +63,7 @@ Requirements
    - ``cuequivariance-ops-torch-*`` and ``cuequivariance-ops-jax-*`` packages are available for Linux x86_64/aarch64
    - Python 3.10-3.14 is supported
    - PyTorch 2.4.0+ is required for torch packages
-   - JAX 0.5.0+ is required for jax packages
+   - JAX 0.8.1+ is required for jax packages
 
 Organization
 ------------

From a6cee9731854fad314cc396b76fc9400d16c660a Mon Sep 17 00:00:00 2001
From: Mario Geiger
Date: Mon, 26 Jan 2026 08:34:49 -0800
Subject: [PATCH 3/4] remove file

---
 test_visualization.py | 49 -------------------------------------------
 1 file changed, 49 deletions(-)
 delete mode 100644 test_visualization.py

diff --git a/test_visualization.py b/test_visualization.py
deleted file mode 100644
index 5d714b2..0000000
--- a/test_visualization.py
+++ /dev/null
@@ -1,49 +0,0 @@
-#!/usr/bin/env python3
-"""Test script for the visualize_polynomial function (works without graphviz)."""
-
-import cuequivariance as cue
-
-
-def test_visualization_api():
-    """Test that the visualization function has the correct API without rendering."""
-    from cuequivariance.segmented_polynomials import visualize_polynomial
-
-    # Create a simple polynomial
-    sh_poly = cue.descriptors.spherical_harmonics(cue.SO3(1), [1, 2]).polynomial
-
-    print("Testing visualize_polynomial API...")
-    print(f"Polynomial: {sh_poly}")
-    print(f"  num_inputs: {sh_poly.num_inputs}")
-    print(f"  num_outputs: {sh_poly.num_outputs}")
-    print(f"  num_operations: {len(sh_poly.operations)}")
-    print()
-
-    # Test error handling for wrong number of names
-    try:
-        visualize_polynomial(sh_poly, ["x", "y"], ["Y"])  # Too many input names
-        print("❌ Should have raised ValueError for wrong number of inputs")
-    except ValueError as e:
-        print(f"✓ Correctly raised ValueError: {e}")
-
-    try:
-        visualize_polynomial(sh_poly, ["x"], ["Y", "Z"])  # Too many output names
print("❌ Should have raised ValueError for wrong number of outputs") - except ValueError as e: - print(f"✓ Correctly raised ValueError: {e}") - - # Test that it raises ImportError if graphviz is not installed - try: - graph = visualize_polynomial(sh_poly, ["x"], ["Y"]) - print("✓ graphviz is installed, graph created successfully") - print(f" Graph type: {type(graph)}") - # Print the DOT source - print("\nGenerated DOT source:") - print(graph.source) - except ImportError as e: - print(f"✓ Correctly raised ImportError when graphviz not installed: {e}") - - print("\n✓ All API tests passed!") - - -if __name__ == "__main__": - test_visualization_api() From 7045e1b1b06904d48d1adef0c27310700fe9ea32 Mon Sep 17 00:00:00 2001 From: Mario Geiger Date: Mon, 26 Jan 2026 08:42:32 -0800 Subject: [PATCH 4/4] rename functions --- cuequivariance_jax/cuequivariance_jax/ir_dict.py | 12 ++++++------ cuequivariance_jax/cuequivariance_jax/nnx.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/cuequivariance_jax/cuequivariance_jax/ir_dict.py b/cuequivariance_jax/cuequivariance_jax/ir_dict.py index 908664c..8ad9ec2 100644 --- a/cuequivariance_jax/cuequivariance_jax/ir_dict.py +++ b/cuequivariance_jax/cuequivariance_jax/ir_dict.py @@ -51,8 +51,8 @@ "segmented_polynomial_uniform_1d", "assert_mul_ir_dict", "mul_ir_dict", - "irreps_to_dict", - "dict_to_irreps", + "flat_to_dict", + "dict_to_flat", "irreps_add", "irreps_zeros_like", ] @@ -243,7 +243,7 @@ def mul_ir_dict(irreps: cue.Irreps, data: Any) -> dict[Irrep, Any]: return jax.tree.broadcast(data, {ir: None for _, ir in irreps}, lambda v: v is None) -def irreps_to_dict( +def flat_to_dict( irreps: cue.Irreps, data: Array, *, layout: str = "mul_ir" ) -> dict[Irrep, Array]: """Convert a flat array to dict[Irrep, Array] with shape (..., mul, ir.dim). @@ -266,7 +266,7 @@ def irreps_to_dict( >>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o") >>> batch = 32 >>> flat = jnp.ones((batch, irreps.dim)) - >>> d = irreps_to_dict(irreps, flat) + >>> d = flat_to_dict(irreps, flat) >>> d[cue.O3(0, 1)].shape (32, 128, 1) >>> d[cue.O3(1, -1)].shape @@ -287,7 +287,7 @@ def irreps_to_dict( return result -def dict_to_irreps(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array: +def dict_to_flat(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array: """Convert dict[Irrep, Array] back to a flat contiguous array. Flattens the (multiplicity, irrep_dim) dimensions and concatenates all irreps. @@ -305,7 +305,7 @@ def dict_to_irreps(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array: >>> batch = 32 >>> d = {cue.O3(0, 1): jnp.ones((batch, 128, 1)), ... cue.O3(1, -1): jnp.ones((batch, 64, 3))} - >>> flat = dict_to_irreps(irreps, d) + >>> flat = dict_to_flat(irreps, d) >>> flat.shape (32, 320) """ diff --git a/cuequivariance_jax/cuequivariance_jax/nnx.py b/cuequivariance_jax/cuequivariance_jax/nnx.py index 256497b..b1d0fba 100644 --- a/cuequivariance_jax/cuequivariance_jax/nnx.py +++ b/cuequivariance_jax/cuequivariance_jax/nnx.py @@ -26,8 +26,8 @@ import cuequivariance as cue from cuequivariance import Irrep +from . 
 from .activation import normalize_function
-from .ir_dict import assert_mul_ir_dict, dict_to_irreps, irreps_to_dict
 from .rep_array.rep_array_ import RepArray
 from .segmented_polynomials.segmented_polynomial import segmented_polynomial
 from .segmented_polynomials.utils import Repeats
@@ -114,7 +114,7 @@ def __init__(
         self.w = nnx.Dict(w)
 
     def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]:
-        assert_mul_ir_dict(self.irreps_in, x)
+        ir_dict.assert_mul_ir_dict(self.irreps_in, x)
 
         x0 = jax.tree.leaves(x)[0]
         shape = x0.shape[:-2]
@@ -129,7 +129,7 @@ def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]:
             / jnp.sqrt(w[...].shape[0])
         )
 
-        assert_mul_ir_dict(self.irreps_out, y)
+        ir_dict.assert_mul_ir_dict(self.irreps_out, y)
         return y
 
 
@@ -245,11 +245,11 @@ def __init__(
     def __call__(
         self, x: dict[Irrep, Array], num_index_counts: Array
     ) -> dict[Irrep, Array]:
-        assert_mul_ir_dict(self.irreps_in, x)
+        ir_dict.assert_mul_ir_dict(self.irreps_in, x)
 
         # Convert dict (batch, mul, ir.dim) -> ir_mul flat order
         x_ir_mul = jax.tree.map(lambda v: rearrange(v, "... m i -> ... i m"), x)
-        x_flat = dict_to_irreps(self.irreps_in, x_ir_mul)
+        x_flat = ir_dict.dict_to_flat(self.irreps_in, x_ir_mul)
         num_elements = x_flat.shape[0]
 
         p = self.e.polynomial
@@ -263,6 +263,6 @@ def __call__(
         )
 
         # Convert ir_mul flat -> dict (batch, mul, ir.dim)
-        y = irreps_to_dict(self.irreps_out, y_flat, layout="ir_mul")
-        assert_mul_ir_dict(self.irreps_out, y)
+        y = ir_dict.flat_to_dict(self.irreps_out, y_flat, layout="ir_mul")
+        ir_dict.assert_mul_ir_dict(self.irreps_out, y)
         return y
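
Note on PATCH 1/4: the nnx.Optimizer pattern it adopts can be exercised standalone. The sketch below is illustrative only; the toy nnx.Linear model and data are not from the repository, and it assumes a Flax version whose nnx.Optimizer accepts wrt= and whose update() takes (model, grads), matching the calls used in the patch.

    import jax
    import jax.numpy as jnp
    import optax
    from flax import nnx

    # Toy stand-in for MACEModelNNX; only the optimizer/step structure mirrors the patch.
    model = nnx.Linear(3, 1, rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-2), wrt=nnx.Param)

    @nnx.jit
    def step(model, optimizer, x, y):
        def loss_fn(model):
            return jnp.mean((model(x) - y) ** 2)

        grads = nnx.grad(loss_fn)(model)
        optimizer.update(model, grads)  # applies the update to the model params in place

    x = jnp.ones((8, 3))
    y = jnp.zeros((8, 1))
    step(model, optimizer, x, y)
    jax.block_until_ready(nnx.state(model))  # no opt_state to thread through the loop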