From 82979ef4af46bc375fd6409445cc15d2f32f9285 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 9 Dec 2025 09:35:08 -0500 Subject: [PATCH 1/5] fix grad, translate and lower vjp --- .github/workflows/check-catalyst.yaml | 13 +++++++ frontend/catalyst/from_plxpr/from_plxpr.py | 28 ++++++++++++--- frontend/catalyst/jax_primitives.py | 40 +++++++++++++++++++--- 3 files changed, 72 insertions(+), 9 deletions(-) diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 999e5ad137..245632c1ad 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -475,6 +475,11 @@ jobs: python3 -m pip install oqc-qcaas-client make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -558,6 +563,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -620,6 +629,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index ff30a3d34c..0fba2356da 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -28,6 +28,7 @@ from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.expand_transforms import ExpandTransformsInterpreter from pennylane.capture.primitives import jacobian_prim as pl_jac_prim +from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from pennylane.capture.primitives import transform_prim from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose @@ -227,16 +228,33 @@ def __init__(self): @WorkflowInterpreter.register_primitive(pl_jac_prim) -def handle_grad(self, *args, jaxpr, n_consts, **kwargs): +def handle_grad(self, *args, jaxpr, **kwargs): """Translate a grad equation.""" - f = partial(copy(self).eval, jaxpr, args[:n_consts]) - new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]) + f = partial(copy(self).eval, jaxpr, []) + new_jaxpr = jax.make_jaxpr(f)(*args) - new_args = (*new_jaxpr.consts, *args[n_consts:]) + new_args = (*new_jaxpr.consts, *args) + j = new_jaxpr.jaxpr + new_j = j.replace(constvars=(), invars=j.constvars+j.invars) return pl_jac_prim.bind( - *new_args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **kwargs + *new_args, jaxpr=new_j, **kwargs ) +@WorkflowInterpreter.register_primitive(pl_vjp_prim) +def handle_vjp(self, *args, jaxpr, **kwargs): + """Translate a grad equation.""" + f = partial(copy(self).eval, jaxpr, []) + new_jaxpr = jax.make_jaxpr(f)(*args[:-len(jaxpr.outvars)]) + + new_args = (*new_jaxpr.consts, *args) + j = new_jaxpr.jaxpr + new_j = j.replace(constvars=(), invars=j.constvars+j.invars) + return pl_jac_prim.bind( + *new_args, jaxpr=new_j, **kwargs + ) + + + # pylint: disable=unused-argument, too-many-arguments @WorkflowInterpreter.register_primitive(qnode_prim) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index e14538cf06..685d1c67c5 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -133,6 +133,7 @@ ) from pennylane.capture.primitives import jacobian_prim as pl_jac_prim +from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -722,15 +723,14 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): # pylint: disable=too-many-arguments -def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn, scalar_out): +def _capture_grad_lowering(ctx, *args, argnums, jaxpr, method, h, fn, scalar_out): mlir_ctx = ctx.module_context.context f64 = ir.F64Type.get(mlir_ctx) finiteDiffParam = ir.FloatAttr.get(f64, h) - new_argnums = [num + n_consts for num in argnums] - argnum_numpy = np.array(new_argnums) + argnum_numpy = np.array(argnums) diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy) - func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn) + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn) symbol_ref = get_symbolref(ctx, func_op) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) @@ -926,6 +926,37 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): ).results +def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h): + """ + Returns: + MLIR results + """ + args = list(args) + mlir_ctx = ctx.module_context.context + n_params = len(jaxpr.invars) + n_consts + new_argnums = np.array(argnums) + + output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) + flat_output_types = util.flatten(output_types) + func_args = args[:n_params] + cotang_args = args[n_params:] + func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)] + vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :] + + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn) + + symbol_ref = get_symbolref(ctx, func_op) + return VJPOp( + func_result_types, + vjp_result_types, + ir.StringAttr.get(method), + symbol_ref, + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(cotang_args), + diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), + finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, + ).results + # # zne # @@ -2874,6 +2905,7 @@ def subroutine_lowering(*args, **kwargs): (for_p, _for_loop_lowering), (grad_p, _grad_lowering), (pl_jac_prim, _capture_grad_lowering), + (pl_vjp_prim, _capture_vjp_lowering), (func_p, _func_lowering), (jvp_p, _jvp_lowering), (vjp_p, _vjp_lowering), From cd314a404092f3d8166f0cdd28d72283b4761c2f Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 11 Dec 2025 09:54:43 -0500 Subject: [PATCH 2/5] [skip-ci] more fixeS --- frontend/catalyst/autograph/ag_primitives.py | 4 ++ frontend/catalyst/from_plxpr/from_plxpr.py | 2 +- frontend/catalyst/jax_primitives.py | 2 +- frontend/test/pytest/test_autograph.py | 6 ++- frontend/test/pytest/test_jvpvjp.py | 51 +++++++++++++------- 5 files changed, 43 insertions(+), 22 deletions(-) diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index bed3c5cbfd..62f5cac79d 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -532,6 +532,10 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None): qml.prod, catalyst.ctrl, qml.ctrl, + qml.grad, + qml.jacobian, + qml.vjp, + qml.jvp, catalyst.grad, catalyst.value_and_grad, catalyst.jacobian, diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 0fba2356da..f20c5d9001 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -249,7 +249,7 @@ def handle_vjp(self, *args, jaxpr, **kwargs): new_args = (*new_jaxpr.consts, *args) j = new_jaxpr.jaxpr new_j = j.replace(constvars=(), invars=j.constvars+j.invars) - return pl_jac_prim.bind( + return pl_vjp_prim.bind( *new_args, jaxpr=new_j, **kwargs ) diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 685d1c67c5..f80b256ca4 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -933,7 +933,7 @@ def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h): """ args = list(args) mlir_ctx = ctx.module_context.context - n_params = len(jaxpr.invars) + n_consts + n_params = len(jaxpr.invars) new_argnums = np.array(argnums) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index edd3394bd6..9097d77450 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -27,6 +27,7 @@ from jax.errors import TracerBoolConversionError from numpy.testing import assert_allclose from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER +from pennylane import vjp from catalyst import AutoGraphError, debug, passes, qjit from catalyst.api_extensions import ( @@ -38,7 +39,6 @@ jacobian, jvp, measure, - vjp, vmap, while_loop, ) @@ -396,6 +396,7 @@ def fn(x: float): assert check_cache(inner) assert fn(3) == tuple([jax.numpy.array(2.0), jax.numpy.array(6.0)]) + @pytest.mark.usefixtures("use_both_frontend") def test_vjp_wrapper(self): """Test conversion is happening succesfully on functions wrapped with 'vjp'.""" @@ -407,7 +408,8 @@ def fn(x: float): return vjp(inner, (x,), (1.0, 1.0)) assert hasattr(fn.user_function, "ag_unconverted") - assert check_cache(inner) + if not qml.capture.enabled(): + assert check_cache(inner) assert np.allclose(fn(3)[0], tuple([jnp.array(6.0), jnp.array(9.0)])) assert np.allclose(fn(3)[1], jnp.array(8.0)) diff --git a/frontend/test/pytest/test_jvpvjp.py b/frontend/test/pytest/test_jvpvjp.py index 5c9f406233..3c0f6fe06b 100644 --- a/frontend/test/pytest/test_jvpvjp.py +++ b/frontend/test/pytest/test_jvpvjp.py @@ -407,6 +407,7 @@ def workflow(): assert_allclose(catalyst_res_flatten, jax_res_flatten, rtol=1e-6) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_S_SS(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -419,7 +420,7 @@ def test_vjp_against_jax_full_argnum_case_S_SS(diff_method): @qjit def C_workflow(): f = qml.QNode(circuit_rx, device=qml.device("lightning.qubit", wires=1)) - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -436,6 +437,7 @@ def J_workflow(): assert_allclose(res_jax, res_cat) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_T_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -450,7 +452,7 @@ def f(x): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -467,6 +469,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_TT_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -486,7 +489,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -503,6 +506,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_T_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -517,7 +521,7 @@ def f(x): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -534,6 +538,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_TT_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -556,7 +561,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -573,6 +578,7 @@ def J_workflow(): assert_allclose(r_j, r_c) + @pytest.mark.parametrize("diff_method", diff_methods) def test_jvpvjp_argument_checks(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -617,13 +623,13 @@ def C_workflow_bad1(): @qjit def C_workflow_bad2(): - return C_vjp(f, list(x), 33, argnums=list(range(len(x)))) + return qml.vjp(f, list(x), 33, argnums=list(range(len(x)))) with pytest.raises(ValueError, match="argnums should be integer or a list of integers"): @qjit def C_workflow_bad3(): - return C_vjp(f, x, ct, argnums="invalid") + return qml.vjp(f, x, ct, argnums="invalid") @pytest.mark.parametrize("diff_method", diff_methods) @@ -680,6 +686,7 @@ def _f(a): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_argnum0_case_TT_TT(diff_method): """Numerically tests Catalyst's vjp against the JAX version, in case of empty or singular @@ -703,11 +710,11 @@ def f(x1, x2): @qjit def C_workflowA(): - return C_vjp(f, x, ct, method=diff_method) + return qml.vjp(f, x, ct, method=diff_method) @qjit def C_workflowB(): - return C_vjp(f, x, ct, method=diff_method, argnums=[0]) + return qml.vjp(f, x, ct, method=diff_method, argnums=[0]) @jax.jit def J_workflow(): @@ -736,6 +743,7 @@ def _f(a): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_pytrees_return(diff_method): """Test VJP with pytree return.""" @@ -746,7 +754,7 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, {"res": 1.0}, 1.0] - return C_vjp(f, [0.1, 0.2], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp(f, [0.1, 0.2], ct2, method=diff_method, argnums=[0, 1]) @jax.jit def J_workflow(): @@ -763,6 +771,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_pytrees_args(diff_method): """Test VJP with pytree args.""" @@ -773,7 +782,7 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, 1.0] - return C_vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) @jax.jit def J_workflow(): @@ -790,6 +799,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_VJP_pytrees_args_and_return(diff_method): """Test that a VJP with pytrees as args.""" @@ -800,7 +810,7 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, {"res": 1.0}, 1.0] - return C_vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) @jax.jit def J_workflow(): @@ -817,6 +827,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_multi_return(diff_method): """Test VJP with multiple returns.""" @@ -826,7 +837,7 @@ def f(x): @qjit def C_workflowA(): - return C_vjp(f, [0.1], [1.0, 1.0], method=diff_method, argnums=[0]) + return qml.vjp(f, [0.1], [1.0, 1.0], method=diff_method, argnums=[0]) @jax.jit def J_workflow(): @@ -918,6 +929,7 @@ def C_workflow(): return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_argument_type_checks_correct_inputs(diff_method): """Test that Catalyst's vjp can JIT compile when given the correct types.""" @@ -926,15 +938,16 @@ def test_vjp_argument_type_checks_correct_inputs(diff_method): def C_workflow_f(): x = (1.0,) cotangents = (1.0, 1.0) - return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) @qjit def C_workflow_g(): x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([2], dtype=float) - return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return qml.vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method): """Tests error handling of Catalyst's vjp when the number of function output params @@ -954,9 +967,10 @@ def C_workflow(): # If `f` returns two outputs, then `cotangents` must have length 2 x = (1.0,) cotangents = (1.0,) - return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_argument_type_checks_incompatible_input_types(diff_method): """Tests error handling of Catalyst's vjp when the types of the function output params @@ -973,9 +987,10 @@ def C_workflow(): # If `x` has type float, then `cotangents` should also have type float x = (1.0,) cotangents = (1, 1) - return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method): """Tests error handling of Catalyst's vjp when the shapes of the function output params @@ -993,7 +1008,7 @@ def C_workflow(): # shape (2,), but it has shape (3,) x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([3], dtype=float) - return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return qml.vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) if __name__ == "__main__": From 8f2f98e3adee1da1c3ac546bfb9365204fc54a69 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Thu, 11 Dec 2025 17:43:22 -0500 Subject: [PATCH 3/5] fixes --- .../api_extensions/differentiation.py | 4 ++- frontend/test/pytest/test_jvpvjp.py | 35 +++++++++---------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 632b7facf3..82b391c8b5 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -27,6 +27,7 @@ from jax._src.api import _dtype from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten from jax.api_util import debug_info +import pennylane as qml from pennylane import QNode import catalyst @@ -546,7 +547,8 @@ def f(x): (Array([0.09983342, 0.04 , 0.02 ], dtype=float64), (Array([-0.43750208, 0.07 ], dtype=float64),)) """ - + if qml.capture.enabled(): + return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums) def check_is_iterable(x, hint): if not isinstance(x, Iterable): raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}") diff --git a/frontend/test/pytest/test_jvpvjp.py b/frontend/test/pytest/test_jvpvjp.py index 3c0f6fe06b..755bdf798a 100644 --- a/frontend/test/pytest/test_jvpvjp.py +++ b/frontend/test/pytest/test_jvpvjp.py @@ -434,7 +434,7 @@ def J_workflow(): res_jax, tree_jax = jax.tree_util.tree_flatten(r1) res_cat, tree_cat = jax.tree_util.tree_flatten(r2) assert tree_jax == tree_cat - assert_allclose(res_jax, res_cat) + assert_allclose(res_jax, res_cat, atol=5e-7) @pytest.mark.usefixtures("use_both_frontend") @@ -850,7 +850,7 @@ def J_workflow(): res_cat, tree_cat = jax.tree_util.tree_flatten(r2) assert tree_jax == tree_cat for r_j, r_c in zip(res_jax, res_cat): - assert_allclose(r_j, r_c) + assert_allclose(r_j, r_c, atol=2e-6) @pytest.mark.parametrize("diff_method", diff_methods) @@ -879,8 +879,7 @@ def test_jvp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, match=( - "number of tangent and number of differentiable parameters in catalyst.jvp " - "do not match" + "number of tangent and number of differentiable parameters in" ), ): @@ -931,23 +930,23 @@ def C_workflow(): @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_correct_inputs(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_correct_inputs(diff_method, vjp_fn): """Test that Catalyst's vjp can JIT compile when given the correct types.""" @qjit def C_workflow_f(): x = (1.0,) cotangents = (1.0, 1.0) - return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return vjp_fn(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) @qjit def C_workflow_g(): x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([2], dtype=float) - return qml.vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return vjp_fn(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) -@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method): """Tests error handling of Catalyst's vjp when the number of function output params @@ -957,8 +956,7 @@ def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, match=( - "number of cotangent and number of function output parameters in catalyst.vjp " - "do not match" + "number of cotangent and number of function output parameters in" ), ): @@ -967,19 +965,20 @@ def C_workflow(): # If `f` returns two outputs, then `cotangents` must have length 2 x = (1.0,) cotangents = (1.0,) - return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_incompatible_input_types(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_incompatible_input_types(diff_method, vjp_fn): """Tests error handling of Catalyst's vjp when the types of the function output params and cotangent arguments are incompatible. """ with pytest.raises( TypeError, - match="function output params and cotangents arguments to catalyst.vjp do not match", + match="function output params and cotangents arguments to ", ): @qjit @@ -987,19 +986,19 @@ def C_workflow(): # If `x` has type float, then `cotangents` should also have type float x = (1.0,) cotangents = (1, 1) - return qml.vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return vjp_fn(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method, vjp_fn): """Tests error handling of Catalyst's vjp when the shapes of the function output params and cotangent arguments are incompatible. """ - with pytest.raises( ValueError, - match="catalyst.vjp called with different function output params and cotangent shapes", + match="vjp called with different function output params and cotangent shapes", ): @qjit @@ -1008,7 +1007,7 @@ def C_workflow(): # shape (2,), but it has shape (3,) x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([3], dtype=float) - return qml.vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return vjp_fn(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) if __name__ == "__main__": From 83b215f315770ed3098a297bc97ca91977a78776 Mon Sep 17 00:00:00 2001 From: albi3ro Date: Tue, 16 Dec 2025 15:08:48 -0500 Subject: [PATCH 4/5] black, isort, etc. [skip-ci] --- .../api_extensions/differentiation.py | 3 ++- frontend/catalyst/from_plxpr/from_plxpr.py | 19 +++++++------------ frontend/catalyst/jax_primitives.py | 1 + frontend/test/pytest/test_autograph.py | 2 +- frontend/test/pytest/test_jvpvjp.py | 19 +++++++++---------- 5 files changed, 20 insertions(+), 24 deletions(-) diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 82b391c8b5..734de44c1c 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -24,10 +24,10 @@ from typing import Callable, Iterable, List, Optional, Union import jax +import pennylane as qml from jax._src.api import _dtype from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten from jax.api_util import debug_info -import pennylane as qml from pennylane import QNode import catalyst @@ -549,6 +549,7 @@ def f(x): """ if qml.capture.enabled(): return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums) + def check_is_iterable(x, hint): if not isinstance(x, Iterable): raise ValueError(f"vjp '{hint}' argument must be an iterable, not {type(x)}") diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index f20c5d9001..a01ef24a16 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -28,8 +28,8 @@ from pennylane.capture import PlxprInterpreter, qnode_prim from pennylane.capture.expand_transforms import ExpandTransformsInterpreter from pennylane.capture.primitives import jacobian_prim as pl_jac_prim -from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from pennylane.capture.primitives import transform_prim +from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose from pennylane.transforms import merge_amplitude_embedding as pl_merge_amplitude_embedding @@ -235,25 +235,20 @@ def handle_grad(self, *args, jaxpr, **kwargs): new_args = (*new_jaxpr.consts, *args) j = new_jaxpr.jaxpr - new_j = j.replace(constvars=(), invars=j.constvars+j.invars) - return pl_jac_prim.bind( - *new_args, jaxpr=new_j, **kwargs - ) + new_j = j.replace(constvars=(), invars=j.constvars + j.invars) + return pl_jac_prim.bind(*new_args, jaxpr=new_j, **kwargs) + @WorkflowInterpreter.register_primitive(pl_vjp_prim) def handle_vjp(self, *args, jaxpr, **kwargs): """Translate a grad equation.""" f = partial(copy(self).eval, jaxpr, []) - new_jaxpr = jax.make_jaxpr(f)(*args[:-len(jaxpr.outvars)]) + new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)]) new_args = (*new_jaxpr.consts, *args) j = new_jaxpr.jaxpr - new_j = j.replace(constvars=(), invars=j.constvars+j.invars) - return pl_vjp_prim.bind( - *new_args, jaxpr=new_j, **kwargs - ) - - + new_j = j.replace(constvars=(), invars=j.constvars + j.invars) + return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs) # pylint: disable=unused-argument, too-many-arguments diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index f80b256ca4..3cc9758e45 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -957,6 +957,7 @@ def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h): finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results + # # zne # diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index 9097d77450..c8c825f2f5 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -26,8 +26,8 @@ import pytest from jax.errors import TracerBoolConversionError from numpy.testing import assert_allclose -from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER from pennylane import vjp +from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER from catalyst import AutoGraphError, debug, passes, qjit from catalyst.api_extensions import ( diff --git a/frontend/test/pytest/test_jvpvjp.py b/frontend/test/pytest/test_jvpvjp.py index 755bdf798a..627b89dfc8 100644 --- a/frontend/test/pytest/test_jvpvjp.py +++ b/frontend/test/pytest/test_jvpvjp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test JVP/VJP operation lowering""" - +# pylint: disable=too-many-lines from typing import TypeVar import jax @@ -578,7 +578,6 @@ def J_workflow(): assert_allclose(r_j, r_c) - @pytest.mark.parametrize("diff_method", diff_methods) def test_jvpvjp_argument_checks(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -782,7 +781,9 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, 1.0] - return qml.vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp( + f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1] + ) @jax.jit def J_workflow(): @@ -810,7 +811,9 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, {"res": 1.0}, 1.0] - return qml.vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp( + f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1] + ) @jax.jit def J_workflow(): @@ -878,9 +881,7 @@ def test_jvp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, - match=( - "number of tangent and number of differentiable parameters in" - ), + match=("number of tangent and number of differentiable parameters in"), ): @qjit @@ -955,9 +956,7 @@ def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, - match=( - "number of cotangent and number of function output parameters in" - ), + match=("number of cotangent and number of function output parameters in"), ): @qjit From 87c2afb0150a1b47f0908c9eaa36af33d02f29ce Mon Sep 17 00:00:00 2001 From: albi3ro Date: Wed, 17 Dec 2025 17:00:28 -0500 Subject: [PATCH 5/5] switch to using pl namespace --- frontend/test/pytest/test_autograph.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index c8c825f2f5..4696ee9517 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -26,21 +26,13 @@ import pytest from jax.errors import TracerBoolConversionError from numpy.testing import assert_allclose -from pennylane import vjp +from pennylane import adjoint, cond, ctrl, for_loop, grad, jacobian, jvp, while_loop, vjp from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER from catalyst import AutoGraphError, debug, passes, qjit from catalyst.api_extensions import ( - adjoint, - cond, - ctrl, - for_loop, - grad, - jacobian, - jvp, measure, vmap, - while_loop, ) from catalyst.autograph import autograph_source, disable_autograph, run_autograph from catalyst.autograph.transformer import TRANSFORMER @@ -401,7 +393,9 @@ def test_vjp_wrapper(self): """Test conversion is happening succesfully on functions wrapped with 'vjp'.""" def inner(x): - return 2 * x, x**2 + if x > 0: + return 2 * x, x**2 + return 4*x, x**8 @qjit(autograph=True) def fn(x: float):