Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax
from jax.extend.core import ClosedJaxpr
from jax.interpreters.partial_eval import convert_constvars_jaxpr
from pennylane.capture.primitives import subroutine_prim
from pennylane.capture.primitives import cond_prim as plxpr_cond_prim
from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim
Expand Down Expand Up @@ -473,3 +474,35 @@

# Return only the output values that match the plxpr output values
return outvals

@PLxPRToQuantumJaxprInterpreter.register_primitive(subroutine_prim)
def handle_subroutine(self, *invals, call_jaxpr, fn):

Check notice on line 479 in frontend/catalyst/from_plxpr/control_flow.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/control_flow.py#L479

Missing function or method docstring (missing-function-docstring)

self.init_qreg.insert_all_dangling_qubits()

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(

Check notice on line 483 in frontend/catalyst/from_plxpr/control_flow.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/control_flow.py#L483

Unused variable 'dynalloced_wire_global_indices' (unused-variable)
invals, self.qubit_index_recorder, self.init_qreg
)
args_plus_qreg = [
*invals,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]
closed_jaxpr = ClosedJaxpr(call_jaxpr, ())
f = partial(_calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs)
converted_jaxpr = jax.make_jaxpr(f)(*args_plus_qreg)
consts = converted_jaxpr.consts
j = converted_jaxpr.jaxpr
no_consts_jaxpr = j.replace(constvars=(), invars=j.constvars + j.invars)

outvals = subroutine_prim.bind(*consts, *args_plus_qreg, call_jaxpr=no_consts_jaxpr, fn=fn)

# Output structure:
# First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals

Check notice on line 508 in frontend/catalyst/from_plxpr/control_flow.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/control_flow.py#L508

Final newline missing (missing-final-newline)
4 changes: 3 additions & 1 deletion frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@
)

from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import subroutine_prim as pl_subroutine_prim

from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
Expand Down Expand Up @@ -584,6 +585,7 @@ def _quantum_kernel_lowering(ctx, *args, call_jaxpr, qnode, pipeline=None):
Returns:
List[mlir.Value] corresponding
"""
print("in quantum kernel lowering")
assert isinstance(qnode, qml.QNode), "This function expects qnodes"
if pipeline is None:
pipeline = tuple()
Expand Down Expand Up @@ -614,7 +616,6 @@ def _func_lowering(ctx, *args, call_jaxpr, fn):
call_op = create_call_op(ctx, func_op, *args)
return call_op.results


#
# Decomp rule
#
Expand Down Expand Up @@ -2844,6 +2845,7 @@ def subroutine_lowering(*args, **kwargs):
(grad_p, _grad_lowering),
(pl_jac_prim, _capture_grad_lowering),
(func_p, _func_lowering),
(pl_subroutine_prim, _func_lowering),
(jvp_p, _jvp_lowering),
(vjp_p, _vjp_lowering),
(adjoint_p, _adjoint_lowering),
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, pu
if metadata is None:
metadata = tuple()
key = (callable_, *metadata, *pipeline)
print(key)
print("cache: ", ctx.module_context.cached_primitive_lowerings)
if callable_ is not None:
if func_op := get_cached(ctx, key):
print("was cached")
return func_op
print("not cached ")
func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public)
cache(ctx, key, func_op)
return func_op
Expand Down
Loading