12 changes: 6 additions & 6 deletions cuequivariance_jax/cuequivariance_jax/ir_dict.py
@@ -51,8 +51,8 @@
"segmented_polynomial_uniform_1d",
"assert_mul_ir_dict",
"mul_ir_dict",
"irreps_to_dict",
"dict_to_irreps",
"flat_to_dict",
"dict_to_flat",
"irreps_add",
"irreps_zeros_like",
]
@@ -243,7 +243,7 @@ def mul_ir_dict(irreps: cue.Irreps, data: Any) -> dict[Irrep, Any]:
return jax.tree.broadcast(data, {ir: None for _, ir in irreps}, lambda v: v is None)


def irreps_to_dict(
def flat_to_dict(
irreps: cue.Irreps, data: Array, *, layout: str = "mul_ir"
) -> dict[Irrep, Array]:
"""Convert a flat array to dict[Irrep, Array] with shape (..., mul, ir.dim).
@@ -266,7 +266,7 @@
>>> irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
>>> batch = 32
>>> flat = jnp.ones((batch, irreps.dim))
>>> d = irreps_to_dict(irreps, flat)
>>> d = flat_to_dict(irreps, flat)
>>> d[cue.O3(0, 1)].shape
(32, 128, 1)
>>> d[cue.O3(1, -1)].shape
Expand All @@ -287,7 +287,7 @@ def irreps_to_dict(
return result


def dict_to_irreps(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array:
def dict_to_flat(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array:
"""Convert dict[Irrep, Array] back to a flat contiguous array.

Flattens the (multiplicity, irrep_dim) dimensions and concatenates all irreps.
@@ -305,7 +305,7 @@ def dict_to_irreps(irreps: cue.Irreps, x: dict[Irrep, Array]) -> Array:
>>> batch = 32
>>> d = {cue.O3(0, 1): jnp.ones((batch, 128, 1)),
... cue.O3(1, -1): jnp.ones((batch, 64, 3))}
>>> flat = dict_to_irreps(irreps, d)
>>> flat = dict_to_flat(irreps, d)
>>> flat.shape
(32, 320)
"""
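For reference, a minimal round-trip sketch with the renamed helpers, mirroring the docstring examples above. It assumes the helpers are reached through the cuequivariance_jax.ir_dict module, as nnx.py does below; the exact top-level re-exports may differ.

import jax.numpy as jnp
import cuequivariance as cue
from cuequivariance_jax import ir_dict  # module path assumed from this PR

irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
flat = jnp.ones((32, irreps.dim))            # (batch, 320) flat array

d = ir_dict.flat_to_dict(irreps, flat)       # {O3(0,1): (32, 128, 1), O3(1,-1): (32, 64, 3)}
flat_back = ir_dict.dict_to_flat(irreps, d)  # back to (32, 320)
assert flat_back.shape == flat.shape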
14 changes: 7 additions & 7 deletions cuequivariance_jax/cuequivariance_jax/nnx.py
@@ -26,8 +26,8 @@
import cuequivariance as cue
from cuequivariance import Irrep

from . import ir_dict
from .activation import normalize_function
from .ir_dict import assert_mul_ir_dict, dict_to_irreps, irreps_to_dict
from .rep_array.rep_array_ import RepArray
from .segmented_polynomials.segmented_polynomial import segmented_polynomial
from .segmented_polynomials.utils import Repeats
@@ -114,7 +114,7 @@ def __init__(
self.w = nnx.Dict(w)

def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]:
assert_mul_ir_dict(self.irreps_in, x)
ir_dict.assert_mul_ir_dict(self.irreps_in, x)

x0 = jax.tree.leaves(x)[0]
shape = x0.shape[:-2]
@@ -129,7 +129,7 @@ def __call__(self, x: dict[Irrep, Array]) -> dict[Irrep, Array]:
/ jnp.sqrt(w[...].shape[0])
)

assert_mul_ir_dict(self.irreps_out, y)
ir_dict.assert_mul_ir_dict(self.irreps_out, y)
return y


@@ -245,11 +245,11 @@ def __init__(
def __call__(
self, x: dict[Irrep, Array], num_index_counts: Array
) -> dict[Irrep, Array]:
assert_mul_ir_dict(self.irreps_in, x)
ir_dict.assert_mul_ir_dict(self.irreps_in, x)

# Convert dict (batch, mul, ir.dim) -> ir_mul flat order
x_ir_mul = jax.tree.map(lambda v: rearrange(v, "... m i -> ... i m"), x)
x_flat = dict_to_irreps(self.irreps_in, x_ir_mul)
x_flat = ir_dict.dict_to_flat(self.irreps_in, x_ir_mul)
num_elements = x_flat.shape[0]

p = self.e.polynomial
@@ -263,6 +263,6 @@ def __call__(
)

# Convert ir_mul flat -> dict (batch, mul, ir.dim)
y = irreps_to_dict(self.irreps_out, y_flat, layout="ir_mul")
assert_mul_ir_dict(self.irreps_out, y)
y = ir_dict.flat_to_dict(self.irreps_out, y_flat, layout="ir_mul")
ir_dict.assert_mul_ir_dict(self.irreps_out, y)
return y
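
To make the layout handling in __call__ above concrete, here is a small sketch of the dict mul_ir -> flat ir_mul conversion and back, following the same calls as the module. The irreps and shapes are taken from the ir_dict.py docstrings, not from the actual module configuration.

import jax
import jax.numpy as jnp
import cuequivariance as cue
from einops import rearrange
from cuequivariance_jax import ir_dict

irreps = cue.Irreps(cue.O3, "128x0e + 64x1o")
x = {cue.O3(0, 1): jnp.ones((32, 128, 1)),   # entries stored as (batch, mul, ir.dim)
     cue.O3(1, -1): jnp.ones((32, 64, 3))}

# dict (batch, mul, ir.dim) -> flat array in ir_mul order, the layout the polynomial consumes
x_ir_mul = jax.tree.map(lambda v: rearrange(v, "... m i -> ... i m"), x)
x_flat = ir_dict.dict_to_flat(irreps, x_ir_mul)            # (32, 320)

# flat ir_mul output -> dict (batch, mul, ir.dim), as done for y above
y = ir_dict.flat_to_dict(irreps, x_flat, layout="ir_mul")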
22 changes: 8 additions & 14 deletions cuequivariance_jax/examples/mace_nnx.py
@@ -528,22 +528,16 @@ def benchmark(
mask=mask,
)

optimizer = optax.adam(1e-2)
opt_state = optimizer.init(nnx.state(model, nnx.Param))
optimizer = nnx.Optimizer(model, optax.adam(1e-2), wrt=nnx.Param)

@nnx.jit
def step(model, opt_state, batch_dict, target_E, target_F):
def step(model, optimizer, batch_dict, target_E, target_F):
def loss_fn(model):
E, F = model(batch_dict)
return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)

grad = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
grad_state = nnx.state(grad, nnx.Param)
updates, opt_state_new = optimizer.update(grad_state, opt_state, params)
new_params = optax.apply_updates(params, updates)
nnx.update(model, new_params)
return opt_state_new
grads = nnx.grad(loss_fn)(model)
optimizer.update(model, grads)

@nnx.jit
def inference(model, batch_dict):
@@ -553,12 +547,12 @@ def inference(model, batch_dict):
runtime_per_inference = 0

if mode in ["train", "both"]:
opt_state = step(model, opt_state, batch_dict, target_E, target_F)
jax.block_until_ready(opt_state)
step(model, optimizer, batch_dict, target_E, target_F)
jax.block_until_ready(nnx.state(model))
t0 = time.perf_counter()
for _ in range(10):
opt_state = step(model, opt_state, batch_dict, target_E, target_F)
jax.block_until_ready(opt_state)
step(model, optimizer, batch_dict, target_E, target_F)
jax.block_until_ready(nnx.state(model))
runtime_per_training_step = 1e3 * (time.perf_counter() - t0) / 10

if mode in ["inference", "both"]:
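The change above relies on the newer Flax NNX optimizer API: nnx.Optimizer constructed with wrt=nnx.Param, and optimizer.update(model, grads) mutating the model's parameters in place. A minimal standalone sketch of that pattern on a toy nnx.Linear (not the MACE model), assuming a Flax release that supports this update signature:

import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-2), wrt=nnx.Param)

@nnx.jit
def step(model, optimizer, x, y):
    def loss_fn(model):
        return jnp.mean((model(x) - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)  # updates the model's nnx.Param state in place
    return loss

loss = step(model, optimizer, jnp.ones((8, 4)), jnp.zeros((8, 2)))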
117 changes: 117 additions & 0 deletions cuequivariance_jax/examples/test_examples.py
@@ -226,6 +226,123 @@ def convert_linear(w, irreps_in, irreps_out):
layer.readout.w[ir][...] = w


def test_mace_linen_nnx_training_equivalence():
"""Test Linen and NNX models produce identical params after training."""
import optax

num_features, num_species = 16, 3
num_steps = 5
learning_rate = 1e-2

config = dict(
num_layers=1,
num_features=num_features,
num_species=num_species,
max_ell=2,
correlation=2,
num_radial_basis=4,
interaction_irreps=cue.Irreps(cue.O3, "0e+1o+2e"),
hidden_irreps=cue.Irreps(cue.O3, "0e+1o"),
offsets=np.zeros(num_species),
cutoff=3.0,
epsilon=0.1,
skip_connection_first_layer=True,
)
batch = _make_batch(
num_atoms=5, num_edges=10, num_species=num_species, num_graphs=1
)

key = jax.random.key(123)
target_E = jax.random.normal(key, (1,))
target_F = jax.random.normal(jax.random.split(key)[0], (5, 3))

# Initialize Linen model
linen_model = MACEModel(**config)
linen_params = linen_model.init(jax.random.key(0), batch)

# Initialize NNX model with converted weights
nnx_model = MACEModelNNX(**config, dtype=jnp.float32, rngs=nnx.Rngs(0))
_convert_linen_to_nnx(linen_params, nnx_model, config)

# Verify initial outputs match (use 1e-3 tolerance like existing test)
E_linen_init, F_linen_init = linen_model.apply(linen_params, batch)
E_nnx_init, F_nnx_init = nnx_model(batch)
np.testing.assert_allclose(E_linen_init, E_nnx_init, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(F_linen_init, F_nnx_init, atol=1e-3, rtol=1e-3)
print("\nInitial outputs match (within 1e-3):")
print(f" E_linen: {E_linen_init}, E_nnx: {E_nnx_init}")

# Train Linen model
def loss_fn_linen(w):
E, F = linen_model.apply(w, batch)
return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)

tx = optax.adam(learning_rate)
opt_state_linen = tx.init(linen_params)
for _ in range(num_steps):
grad = jax.grad(loss_fn_linen)(linen_params)
updates, opt_state_linen = tx.update(grad, opt_state_linen, linen_params)
linen_params = optax.apply_updates(linen_params, updates)

# Train NNX model
def loss_fn_nnx(model):
E, F = model(batch)
return jnp.mean((E - target_E) ** 2) + jnp.mean((F - target_F) ** 2)

optimizer_nnx = nnx.Optimizer(nnx_model, optax.adam(learning_rate), wrt=nnx.Param)
for _ in range(num_steps):
grads = nnx.grad(loss_fn_nnx)(nnx_model)
optimizer_nnx.update(nnx_model, grads)

# Compare outputs after training
E_linen, F_linen = linen_model.apply(linen_params, batch)
E_nnx, F_nnx = nnx_model(batch)

print(f"\nAfter {num_steps} training steps:")
print(f" Linen loss: {loss_fn_linen(linen_params):.6f}")
print(f" NNX loss: {loss_fn_nnx(nnx_model):.6f}")
print(f" E_linen: {E_linen}")
print(f" E_nnx: {E_nnx}")
print(f" max|E diff|: {jnp.max(jnp.abs(E_linen - E_nnx)):.2e}")
print(f" max|F diff|: {jnp.max(jnp.abs(F_linen - F_nnx)):.2e}")

# Compare multiple parameters
print(" Parameter comparisons:")

# 1. Embedding
linen_emb = linen_params["params"]["linear_embedding"]
nnx_emb = nnx_model.embedding[...]
emb_diff = jnp.max(jnp.abs(linen_emb - nnx_emb))
print(f" embedding: max|diff| = {emb_diff:.2e}")

# 2. Skip connection weights (linZ_skip_tp -> skip.w)
linen_skip = linen_params["params"]["layer_0"]["linZ_skip_tp"]
nnx_skip = nnx_model.layers[0].skip.w[...]
skip_diff = jnp.max(jnp.abs(linen_skip - nnx_skip))
print(f" layer_0/skip.w: max|diff| = {skip_diff:.2e}")

# 3. Symmetric contraction weights
linen_sc = linen_params["params"]["layer_0"]["symmetric_contraction"]
nnx_sc = nnx_model.layers[0].sc.w[...]
sc_diff = jnp.max(jnp.abs(linen_sc - nnx_sc))
print(f" layer_0/sc.w: max|diff| = {sc_diff:.2e}")

# 4. Radial MLP weights (first layer)
linen_mlp = linen_params["params"]["layer_0"]["MultiLayerPerceptron_0"]["Dense_0"][
"kernel"
]
nnx_mlp = nnx_model.layers[0].radial_mlp.linears[0][...]
mlp_diff = jnp.max(jnp.abs(linen_mlp - nnx_mlp))
print(f" layer_0/radial_mlp[0]: max|diff| = {mlp_diff:.2e}")

np.testing.assert_allclose(E_linen, E_nnx, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(F_linen, F_nnx, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(linen_emb, nnx_emb, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(linen_skip, nnx_skip, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(linen_sc, nnx_sc, atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(linen_mlp, nnx_mlp, atol=1e-3, rtol=1e-3)


def test_nequip_model_basic():
"""Test NEQUIP model."""
batch = _make_batch()
2 changes: 1 addition & 1 deletion cuequivariance_jax/pyproject.toml
@@ -27,7 +27,7 @@ authors = [{ name = "NVIDIA Corporation" }]
requires-python = ">=3.10"
dependencies = [
"cuequivariance",
"jax",
"jax>=0.8.1",
"packaging",
"einops",
]
2 changes: 1 addition & 1 deletion docs/index.rst
@@ -63,7 +63,7 @@ Requirements
- ``cuequivariance-ops-torch-*`` and ``cuequivariance-ops-jax-*`` packages are available for Linux x86_64/aarch64
- Python 3.10-3.14 is supported
- PyTorch 2.4.0+ is required for torch packages
- JAX 0.5.0+ is required for jax packages
- JAX 0.8.1+ is required for jax packages

Organization
------------
49 changes: 0 additions & 49 deletions test_visualization.py

This file was deleted.