diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py index 1a42c3bdc9..5b657a0c1d 100644 --- a/frontend/catalyst/from_plxpr/control_flow.py +++ b/frontend/catalyst/from_plxpr/control_flow.py @@ -21,6 +21,7 @@ import jax from jax.extend.core import ClosedJaxpr from jax.interpreters.partial_eval import convert_constvars_jaxpr +from pennylane.capture.primitives import subroutine_prim from pennylane.capture.primitives import cond_prim as plxpr_cond_prim from pennylane.capture.primitives import for_loop_prim as plxpr_for_loop_prim from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim @@ -473,3 +474,35 @@ def handle_while_loop( # Return only the output values that match the plxpr output values return outvals + +@PLxPRToQuantumJaxprInterpreter.register_primitive(subroutine_prim) +def handle_subroutine(self, *invals, call_jaxpr, fn): + + self.init_qreg.insert_all_dangling_qubits() + + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + invals, self.qubit_index_recorder, self.init_qreg + ) + args_plus_qreg = [ + *invals, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + self.init_qreg.get(), + ] + closed_jaxpr = ClosedJaxpr(call_jaxpr, ()) + f = partial(_calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs) + converted_jaxpr = jax.make_jaxpr(f)(*args_plus_qreg) + consts = converted_jaxpr.consts + j = converted_jaxpr.jaxpr + no_consts_jaxpr = j.replace(constvars=(), invars=j.constvars + j.invars) + + outvals = subroutine_prim.bind(*consts, *args_plus_qreg, call_jaxpr=no_consts_jaxpr, fn=fn) + + # Output structure: + # First a list of dynamically allocated qregs, then the global qreg + # Update the current qreg and remove it from the output values. + self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) + + # Return only the output values that match the plxpr output values + return outvals \ No newline at end of file diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 0285a7a23a..f49ea1ed44 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -134,6 +134,7 @@ ) from pennylane.capture.primitives import jacobian_prim as pl_jac_prim +from pennylane.capture.primitives import subroutine_prim as pl_subroutine_prim from catalyst.compiler import get_lib_path from catalyst.jax_extras import ( @@ -584,6 +585,7 @@ def _quantum_kernel_lowering(ctx, *args, call_jaxpr, qnode, pipeline=None): Returns: List[mlir.Value] corresponding """ + print("in quantum kernel lowering") assert isinstance(qnode, qml.QNode), "This function expects qnodes" if pipeline is None: pipeline = tuple() @@ -614,7 +616,6 @@ def _func_lowering(ctx, *args, call_jaxpr, fn): call_op = create_call_op(ctx, func_op, *args) return call_op.results - # # Decomp rule # @@ -2844,6 +2845,7 @@ def subroutine_lowering(*args, **kwargs): (grad_p, _grad_lowering), (pl_jac_prim, _capture_grad_lowering), (func_p, _func_lowering), + (pl_subroutine_prim, _func_lowering), (jvp_p, _jvp_lowering), (vjp_p, _vjp_lowering), (adjoint_p, _adjoint_lowering), diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 078cafc422..0433cbb7d3 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -143,9 +143,13 @@ def get_or_create_funcop(ctx, callable_, call_jaxpr, pipeline, metadata=None, pu if metadata is None: metadata = tuple() key = (callable_, *metadata, *pipeline) + print(key) + print("cache: ", ctx.module_context.cached_primitive_lowerings) if callable_ is not None: if func_op := get_cached(ctx, key): + print("was cached") return func_op + print("not cached ") func_op = lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=public) cache(ctx, key, func_op) return func_op