Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,11 @@ jobs:
python3 -m pip install oqc-qcaas-client
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp


- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -569,6 +574,10 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp

- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down Expand Up @@ -631,6 +640,10 @@ jobs:
python3 -m pip install -r requirements.txt
make frontend

- name: Install PennyLane branch
run: |
pip install --no-deps --force git+https://github.com/PennyLaneAI/pennylane@capture-vjp

- name: Get Cached LLVM Build
id: cache-llvm-build
uses: actions/cache@v4
Expand Down
3 changes: 3 additions & 0 deletions frontend/catalyst/api_extensions/differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing import Callable, Iterable, List, Optional, Union

import jax
import pennylane as qml
from jax._src.api import _dtype
from jax._src.tree_util import PyTreeDef, tree_flatten, tree_unflatten
from jax.api_util import debug_info
Expand Down Expand Up @@ -546,6 +547,8 @@ def f(x):
(Array([0.09983342, 0.04 , 0.02 ], dtype=float64),
(Array([-0.43750208, 0.07 ], dtype=float64),))
"""
if qml.capture.enabled():
return qml.vjp(f, params, cotangents, method=method, h=h, argnums=argnums)

def check_is_iterable(x, hint):
if not isinstance(x, Iterable):
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/autograph/ag_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,10 @@ def converted_call(fn, args, kwargs, caller_fn_scope=None, options=None):
qml.prod,
catalyst.ctrl,
qml.ctrl,
qml.grad,
qml.jacobian,
qml.vjp,
qml.jvp,
catalyst.grad,
catalyst.value_and_grad,
catalyst.jacobian,
Expand Down
27 changes: 20 additions & 7 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim
from pennylane.transforms import commute_controlled as pl_commute_controlled
from pennylane.transforms import decompose as pl_decompose
from pennylane.transforms import gridsynth as pl_gridsynth
Expand Down Expand Up @@ -228,15 +229,27 @@ def __init__(self):


@WorkflowInterpreter.register_primitive(pl_jac_prim)
def handle_grad(self, *args, jaxpr, n_consts, **kwargs):
def handle_grad(self, *args, jaxpr, **kwargs):
"""Translate a grad equation."""
f = partial(copy(self).eval, jaxpr, args[:n_consts])
new_jaxpr = jax.make_jaxpr(f)(*args[n_consts:])
f = partial(copy(self).eval, jaxpr, [])
new_jaxpr = jax.make_jaxpr(f)(*args)

new_args = (*new_jaxpr.consts, *args[n_consts:])
return pl_jac_prim.bind(
*new_args, jaxpr=new_jaxpr.jaxpr, n_consts=len(new_jaxpr.consts), **kwargs
)
new_args = (*new_jaxpr.consts, *args)
j = new_jaxpr.jaxpr
new_j = j.replace(constvars=(), invars=j.constvars + j.invars)
return pl_jac_prim.bind(*new_args, jaxpr=new_j, **kwargs)


@WorkflowInterpreter.register_primitive(pl_vjp_prim)
def handle_vjp(self, *args, jaxpr, **kwargs):
"""Translate a grad equation."""
f = partial(copy(self).eval, jaxpr, [])
new_jaxpr = jax.make_jaxpr(f)(*args[: -len(jaxpr.outvars)])

new_args = (*new_jaxpr.consts, *args)
j = new_jaxpr.jaxpr
new_j = j.replace(constvars=(), invars=j.constvars + j.invars)
return pl_vjp_prim.bind(*new_args, jaxpr=new_j, **kwargs)


# pylint: disable=unused-argument, too-many-arguments
Expand Down
41 changes: 37 additions & 4 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@
)

from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import vjp_prim as pl_vjp_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
Expand Down Expand Up @@ -722,15 +723,14 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):


# pylint: disable=too-many-arguments
def _capture_grad_lowering(ctx, *args, argnums, jaxpr, n_consts, method, h, fn, scalar_out):
def _capture_grad_lowering(ctx, *args, argnums, jaxpr, method, h, fn, scalar_out):
mlir_ctx = ctx.module_context.context
f64 = ir.F64Type.get(mlir_ctx)
finiteDiffParam = ir.FloatAttr.get(f64, h)

new_argnums = [num + n_consts for num in argnums]
argnum_numpy = np.array(new_argnums)
argnum_numpy = np.array(argnums)
diffArgIndices = ir.DenseIntElementsAttr.get(argnum_numpy)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *new_argnums), fn=fn)
func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn)
symbol_ref = get_symbolref(ctx, func_op)
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
Expand Down Expand Up @@ -926,6 +926,38 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
).results


def _capture_vjp_lowering(ctx, *args, jaxpr, fn, method, argnums, h):
"""
Returns:
MLIR results
"""
args = list(args)
mlir_ctx = ctx.module_context.context
n_params = len(jaxpr.invars)
new_argnums = np.array(argnums)

output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
func_args = args[:n_params]
cotang_args = args[n_params:]
func_result_types = flat_output_types[: len(flat_output_types) - len(argnums)]
vjp_result_types = flat_output_types[len(flat_output_types) - len(argnums) :]

func_op = lower_jaxpr(ctx, jaxpr, (method, h, *argnums), fn=fn)

symbol_ref = get_symbolref(ctx, func_op)
return VJPOp(
func_result_types,
vjp_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(cotang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results


#
# zne
#
Expand Down Expand Up @@ -2874,6 +2906,7 @@ def subroutine_lowering(*args, **kwargs):
(for_p, _for_loop_lowering),
(grad_p, _grad_lowering),
(pl_jac_prim, _capture_grad_lowering),
(pl_vjp_prim, _capture_vjp_lowering),
(func_p, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
Expand Down
18 changes: 7 additions & 11 deletions frontend/test/pytest/test_autograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,13 @@
import pytest
from jax.errors import TracerBoolConversionError
from numpy.testing import assert_allclose
from pennylane import adjoint, cond, ctrl, for_loop, grad, jacobian, jvp, while_loop, vjp
from pennylane.capture.autograph.transformer import TRANSFORMER as capture_TRANSFORMER

from catalyst import AutoGraphError, debug, passes, qjit
from catalyst.api_extensions import (
adjoint,
cond,
ctrl,
for_loop,
grad,
jacobian,
jvp,
measure,
vjp,
vmap,
while_loop,
)
from catalyst.autograph import autograph_source, disable_autograph, run_autograph
from catalyst.autograph.transformer import TRANSFORMER
Expand Down Expand Up @@ -396,18 +388,22 @@ def fn(x: float):
assert check_cache(inner)
assert fn(3) == tuple([jax.numpy.array(2.0), jax.numpy.array(6.0)])

@pytest.mark.usefixtures("use_both_frontend")
def test_vjp_wrapper(self):
"""Test conversion is happening succesfully on functions wrapped with 'vjp'."""

def inner(x):
return 2 * x, x**2
if x > 0:
return 2 * x, x**2
return 4*x, x**8

@qjit(autograph=True)
def fn(x: float):
return vjp(inner, (x,), (1.0, 1.0))

assert hasattr(fn.user_function, "ag_unconverted")
assert check_cache(inner)
if not qml.capture.enabled():
assert check_cache(inner)
assert np.allclose(fn(3)[0], tuple([jnp.array(6.0), jnp.array(9.0)]))
assert np.allclose(fn(3)[1], jnp.array(8.0))

Expand Down
Loading
Loading