diff --git a/.github/workflows/check-catalyst.yaml b/.github/workflows/check-catalyst.yaml index 2c0397482e..6e2ca3243d 100644 --- a/.github/workflows/check-catalyst.yaml +++ b/.github/workflows/check-catalyst.yaml @@ -477,6 +477,10 @@ jobs: sudo apt-get install -y graphviz make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + - name: Verify Graphviz installation run: | dot -V @@ -565,6 +569,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 @@ -628,6 +636,10 @@ jobs: python3 -m pip install -r requirements.txt make frontend + - name: Install PennyLane branch + run: | + pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp + - name: Get Cached LLVM Build id: cache-llvm-build uses: actions/cache@v4 diff --git a/frontend/catalyst/api_extensions/differentiation.py b/frontend/catalyst/api_extensions/differentiation.py index 632b7facf3..734de44c1c 100644 --- a/frontend/catalyst/api_extensions/differentiation.py +++ b/frontend/catalyst/api_extensions/differentiation.py @@ -24,6 +24,7 @@ from typing import Callable, Iterable, List, Optional, Union import jax +import pennylane as qml from jax._src.api import _dtype from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten from jax.api_util import debug_info @@ -546,6 +547,8 @@ def f(x): (Array([0.09983342, 0.04 , 0.02 ], dtype=float64), (Array([-0.43750208, 0.07 ], dtype=float64),)) """ + if qml.capture.enabled(): + return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums) def check_is_iterable(x, hint): if not isinstance(x, Iterable): diff --git a/frontend/catalyst/autograph/ag_primitives.py b/frontend/catalyst/autograph/ag_primitives.py index be7099532d..2d3d90bee8 100644 --- a/frontend/catalyst/autograph/ag_primitives.py +++ b/frontend/catalyst/autograph/ag_primitives.py @@ -532,6 +532,10 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None): qml.prod, catalyst.ctrl, qml.ctrl, + qml.grad, + qml.jacobian, + qml.vjp, + qml.jvp, catalyst.grad, catalyst.value_and_grad, catalyst.jacobian, diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index e3f2acc238..327a6d7133 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -29,6 +29,7 @@ from pennylane.capture.expand_transforms import ExpandTransformsInterpreter from pennylane.capture.primitives import jacobian_prim as pl_jac_prim from pennylane.capture.primitives import transform_prim +from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from pennylane.transforms import commute_controlled as pl_commute_controlled from pennylane.transforms import decompose as pl_decompose from pennylane.transforms import gridsynth as pl_gridsynth @@ -228,15 +229,27 @@ def __init__(self): @WorkflowInterpreter.register_primitive(pl_jac_prim) -def handle_grad(self, *args, jaxpr, n_consts, **kwargs): +def handle_grad(self, *args, jaxpr, **kwargs): """Translate a grad equation.""" - f = partial(copy(self).eval, jaxpr, args[:n_consts]) - new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:]) + f = partial(copy(self).eval, jaxpr, []) + new_jaxpr = jax.make_jaxpr(f)(*args) - new_args = (*new_jaxpr.consts, *args[n_consts:]) - return pl_jac_prim.bind( - *new_args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **kwargs - ) + new_args = (*new_jaxpr.consts, *args) + j = new_jaxpr.jaxpr + new_j = j.replace(constvars=(), invars=j.constvars + j.invars) + return pl_jac_prim.bind(*new_args, jaxpr=new_j, **kwargs) + + +@WorkflowInterpreter.register_primitive(pl_vjp_prim) +def handle_vjp(self, *args, jaxpr, **kwargs): + """Translate a grad equation.""" + f = partial(copy(self).eval, jaxpr, []) + new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)]) + + new_args = (*new_jaxpr.consts, *args) + j = new_jaxpr.jaxpr + new_j = j.replace(constvars=(), invars=j.constvars + j.invars) + return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs) # pylint: disable=unused-argument, too-many-arguments diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 0285a7a23a..a7ae95dc11 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -134,6 +134,7 @@ ) from pennylane.capture.primitives import jacobian_prim as pl_jac_prim +from pennylane.capture.primitives import vjp_prim as pl_vjp_prim from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -727,15 +728,14 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): # pylint: disable=too-many-arguments -def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn, scalar_out): +def _capture_grad_lowering(ctx, *args, argnums, jaxpr, method, h, fn, scalar_out): mlir_ctx = ctx.module_context.context f64 = ir.F64Type.get(mlir_ctx) finiteDiffParam = ir.FloatAttr.get(f64, h) - new_argnums = [num + n_consts for num in argnums] - argnum_numpy = np.array(new_argnums) + argnum_numpy = np.array(argnums) diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy) - func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn) + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn) symbol_ref = get_symbolref(ctx, func_op) output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) @@ -931,6 +931,38 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): ).results +def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h): + """ + Returns: + MLIR results + """ + args = list(args) + mlir_ctx = ctx.module_context.context + n_params = len(jaxpr.invars) + new_argnums = np.array(argnums) + + output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) + flat_output_types = util.flatten(output_types) + func_args = args[:n_params] + cotang_args = args[n_params:] + func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)] + vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :] + + func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn) + + symbol_ref = get_symbolref(ctx, func_op) + return VJPOp( + func_result_types, + vjp_result_types, + ir.StringAttr.get(method), + symbol_ref, + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(cotang_args), + diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), + finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, + ).results + + # # zne # @@ -2843,6 +2875,7 @@ def subroutine_lowering(*args, **kwargs): (for_p, _for_loop_lowering), (grad_p, _grad_lowering), (pl_jac_prim, _capture_grad_lowering), + (pl_vjp_prim, _capture_vjp_lowering), (func_p, _func_lowering), (jvp_p, _jvp_lowering), (vjp_p, _vjp_lowering), diff --git a/frontend/test/pytest/test_autograph.py b/frontend/test/pytest/test_autograph.py index cbf6791edd..b589812b32 100644 --- a/frontend/test/pytest/test_autograph.py +++ b/frontend/test/pytest/test_autograph.py @@ -26,21 +26,13 @@ import pytest from jax.errors import TracerBoolConversionError from numpy.testing import assert_allclose +from pennylane import adjoint, cond, ctrl, for_loop, grad, jacobian, jvp, while_loop, vjp from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER from catalyst import AutoGraphError, debug, passes, qjit from catalyst.api_extensions import ( - adjoint, - cond, - ctrl, - for_loop, - grad, - jacobian, - jvp, measure, - vjp, vmap, - while_loop, ) from catalyst.autograph import autograph_source, disable_autograph, run_autograph from catalyst.autograph.transformer import TRANSFORMER @@ -396,19 +388,26 @@ def fn(x: float): assert check_cache(inner) assert fn(3) == tuple([jax.numpy.array(2.0), jax.numpy.array(6.0)]) + @pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("vjp_func", [vjp, qml.vjp]) def test_vjp_wrapper(self, vjp_func): """Test conversion is happening succesfully on functions wrapped with 'vjp'.""" + if qml.capture.enabled() and vjp_func == vjp: + pytest.skip("program capture needs qml.vjp") + def inner(x): - return 2 * x, x**2 + if x > 0: + return 2 * x, x**2 + return 4*x, x**8 @qjit(autograph=True) def fn(x: float): return vjp_func(inner, (x,), (1.0, 1.0)) assert hasattr(fn.user_function, "ag_unconverted") - assert check_cache(inner) + if not qml.capture.enabled(): + assert check_cache(inner) assert np.allclose(fn(3)[0], tuple([jnp.array(6.0), jnp.array(9.0)])) assert np.allclose(fn(3)[1], jnp.array(8.0)) diff --git a/frontend/test/pytest/test_jvpvjp.py b/frontend/test/pytest/test_jvpvjp.py index 19686a828d..65e448aa2a 100644 --- a/frontend/test/pytest/test_jvpvjp.py +++ b/frontend/test/pytest/test_jvpvjp.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test JVP/VJP operation lowering""" - +# pylint: disable=too-many-lines from typing import TypeVar import jax @@ -407,6 +407,7 @@ def workflow(): assert_allclose(catalyst_res_flatten, jax_res_flatten, rtol=1e-6) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_S_SS(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -419,7 +420,7 @@ def test_vjp_against_jax_full_argnum_case_S_SS(diff_method): @qjit def C_workflow(): f = qml.QNode(circuit_rx, device=qml.device("lightning.qubit", wires=1)) - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -433,9 +434,10 @@ def J_workflow(): res_jax, tree_jax = jax.tree_util.tree_flatten(r1) res_cat, tree_cat = jax.tree_util.tree_flatten(r2) assert tree_jax == tree_cat - assert_allclose(res_jax, res_cat) + assert_allclose(res_jax, res_cat, atol=5e-7) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_T_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -450,7 +452,7 @@ def f(x): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -467,6 +469,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_TT_T(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -486,7 +489,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -503,6 +506,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_T_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -517,7 +521,7 @@ def f(x): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -534,6 +538,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_full_argnum_case_TT_TT(diff_method): """Numerically tests Catalyst's jvp against the JAX version.""" @@ -556,7 +561,7 @@ def f(x1, x2): @qjit def C_workflow(): - return C_vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) + return qml.vjp(f, x, ct, method=diff_method, argnums=list(range(len(x)))) @jax.jit def J_workflow(): @@ -617,13 +622,13 @@ def C_workflow_bad1(): @qjit def C_workflow_bad2(): - return C_vjp(f, list(x), 33, argnums=list(range(len(x)))) + return qml.vjp(f, list(x), 33, argnums=list(range(len(x)))) with pytest.raises(ValueError, match="argnums should be integer or a list of integers"): @qjit def C_workflow_bad3(): - return C_vjp(f, x, ct, argnums="invalid") + return qml.vjp(f, x, ct, argnums="invalid") @pytest.mark.parametrize("diff_method", diff_methods) @@ -680,6 +685,7 @@ def _f(a): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_against_jax_argnum0_case_TT_TT(diff_method): """Numerically tests Catalyst's vjp against the JAX version, in case of empty or singular @@ -703,11 +709,11 @@ def f(x1, x2): @qjit def C_workflowA(): - return C_vjp(f, x, ct, method=diff_method) + return qml.vjp(f, x, ct, method=diff_method) @qjit def C_workflowB(): - return C_vjp(f, x, ct, method=diff_method, argnums=[0]) + return qml.vjp(f, x, ct, method=diff_method, argnums=[0]) @jax.jit def J_workflow(): @@ -736,6 +742,7 @@ def _f(a): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_pytrees_return(diff_method): """Test VJP with pytree return.""" @@ -746,7 +753,7 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, {"res": 1.0}, 1.0] - return C_vjp(f, [0.1, 0.2], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp(f, [0.1, 0.2], ct2, method=diff_method, argnums=[0, 1]) @jax.jit def J_workflow(): @@ -763,6 +770,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_pytrees_args(diff_method): """Test VJP with pytree args.""" @@ -773,7 +781,9 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, 1.0] - return C_vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp( + f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1] + ) @jax.jit def J_workflow(): @@ -790,6 +800,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_VJP_pytrees_args_and_return(diff_method): """Test that a VJP with pytrees as args.""" @@ -800,7 +811,9 @@ def f(x, y): @qjit def C_workflowA(): ct2 = [1.0, {"res": 1.0}, 1.0] - return C_vjp(f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1]) + return qml.vjp( + f, [{"res1": 0.1, "res2": 0.2}, 0.3], ct2, method=diff_method, argnums=[0, 1] + ) @jax.jit def J_workflow(): @@ -817,6 +830,7 @@ def J_workflow(): assert_allclose(r_j, r_c) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) def test_vjp_multi_return(diff_method): """Test VJP with multiple returns.""" @@ -826,7 +840,7 @@ def f(x): @qjit def C_workflowA(): - return C_vjp(f, [0.1], [1.0, 1.0], method=diff_method, argnums=[0]) + return qml.vjp(f, [0.1], [1.0, 1.0], method=diff_method, argnums=[0]) @jax.jit def J_workflow(): @@ -839,7 +853,7 @@ def J_workflow(): res_cat, tree_cat = jax.tree_util.tree_flatten(r2) assert tree_jax == tree_cat for r_j, r_c in zip(res_jax, res_cat): - assert_allclose(r_j, r_c) + assert_allclose(r_j, r_c, atol=2e-6) @pytest.mark.parametrize("diff_method", diff_methods) @@ -867,10 +881,7 @@ def test_jvp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, - match=( - "number of tangent and number of differentiable parameters in catalyst.jvp " - "do not match" - ), + match=("number of tangent and number of differentiable parameters in"), ): @qjit @@ -918,21 +929,23 @@ def C_workflow(): return C_jvp(g_R3_to_R2, [1, x], [tangents], method=diff_method, argnums=[1]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_correct_inputs(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_correct_inputs(diff_method, vjp_fn): """Test that Catalyst's vjp can JIT compile when given the correct types.""" @qjit def C_workflow_f(): x = (1.0,) cotangents = (1.0, 1.0) - return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return vjp_fn(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) @qjit def C_workflow_g(): x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([2], dtype=float) - return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return vjp_fn(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) @pytest.mark.parametrize("diff_method", diff_methods) @@ -943,10 +956,7 @@ def test_vjp_argument_type_checks_incompatible_n_inputs(diff_method): with pytest.raises( TypeError, - match=( - "number of cotangent and number of function output parameters in catalyst.vjp " - "do not match" - ), + match=("number of cotangent and number of function output parameters in"), ): @qjit @@ -957,15 +967,17 @@ def C_workflow(): return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_incompatible_input_types(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_incompatible_input_types(diff_method, vjp_fn): """Tests error handling of Catalyst's vjp when the types of the function output params and cotangent arguments are incompatible. """ with pytest.raises( TypeError, - match="function output params and cotangents arguments to catalyst.vjp do not match", + match="function output params and cotangents arguments to ", ): @qjit @@ -973,18 +985,19 @@ def C_workflow(): # If `x` has type float, then `cotangents` should also have type float x = (1.0,) cotangents = (1, 1) - return C_vjp(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) + return vjp_fn(f_R1_to_R2, x, cotangents, method=diff_method, argnums=[0]) +@pytest.mark.usefixtures("use_both_frontend") @pytest.mark.parametrize("diff_method", diff_methods) -def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method): +@pytest.mark.parametrize("vjp_fn", (qml.vjp, C_vjp)) +def test_vjp_argument_type_checks_incompatible_input_shapes(diff_method, vjp_fn): """Tests error handling of Catalyst's vjp when the shapes of the function output params and cotangent arguments are incompatible. """ - with pytest.raises( ValueError, - match="catalyst.vjp called with different function output params and cotangent shapes", + match="vjp called with different function output params and cotangent shapes", ): @qjit @@ -993,7 +1006,7 @@ def C_workflow(): # shape (2,), but it has shape (3,) x = jnp.array([2.0, 3.0, 4.0]) cotangents = jnp.ones([3], dtype=float) - return C_vjp(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) + return vjp_fn(g_R3_to_R2, [1, x], [cotangents], method=diff_method, argnums=[1]) if __name__ == "__main__":