diff --git a/.dep-versions b/.dep-versions index f7e68582d9..88f6c9ebf3 100644 --- a/.dep-versions +++ b/.dep-versions @@ -10,7 +10,7 @@ enzyme=v0.0.203 # For a custom PL version, update the package version here and at # 'doc/requirements.txt' -pennylane=0.44.0-dev50 +pennylane=0.44.0.dev54 # For a custom LQ/LK version, update the package version here and at # 'doc/requirements.txt' diff --git a/frontend/catalyst/device/decomposition.py b/frontend/catalyst/device/decomposition.py index 22c66deac3..bddf79cf07 100644 --- a/frontend/catalyst/device/decomposition.py +++ b/frontend/catalyst/device/decomposition.py @@ -42,6 +42,7 @@ is_controllable, is_differentiable, is_invertible, + is_quantum_gate, is_supported, ) from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes @@ -227,7 +228,7 @@ def catalyst_acceptance( if match and is_controllable(op.base, capabilities): return match - elif is_supported(op, capabilities): + elif is_supported(op, capabilities) and is_quantum_gate(op): return op.name return None diff --git a/frontend/catalyst/device/op_support.py b/frontend/catalyst/device/op_support.py index f86f0cc2c9..099d19d7fa 100644 --- a/frontend/catalyst/device/op_support.py +++ b/frontend/catalyst/device/op_support.py @@ -40,6 +40,10 @@ def is_supported(op: Operator, capabilities: DeviceCapabilities) -> bool: """Check whether an operation is supported by the device.""" return op.name in capabilities.operations +def is_quantum_gate(op: Operator) -> bool: + """Check whether an operation is a quantum gate.""" + return isinstance(op, qml.operation.Gate) + # and isinstance(op, qml.operation.DynamicGate) def _is_grad_recipe_same_as_catalyst(op): """Checks that the grad_recipe for the op matches the hard coded one in Catalyst.""" diff --git a/frontend/catalyst/device/qjit_device.py b/frontend/catalyst/device/qjit_device.py index 77aa57e5aa..c4da627ce4 100644 --- a/frontend/catalyst/device/qjit_device.py +++ b/frontend/catalyst/device/qjit_device.py @@ -60,44 +60,6 @@ logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) -RUNTIME_OPERATIONS = [ - "CNOT", - "ControlledPhaseShift", - "CRot", - "CRX", - "CRY", - "CRZ", - "CSWAP", - "CY", - "CZ", - "Hadamard", - "Identity", - "IsingXX", - "IsingXY", - "IsingYY", - "IsingZZ", - "SingleExcitation", - "DoubleExcitation", - "ISWAP", - "MultiRZ", - "PauliX", - "PauliY", - "PauliZ", - "PCPhase", - "PhaseShift", - "PSWAP", - "QubitUnitary", - "Rot", - "RX", - "RY", - "RZ", - "S", - "SWAP", - "T", - "Toffoli", - "GlobalPhase", -] - RUNTIME_OBSERVABLES = [ "Identity", "PauliX", @@ -113,12 +75,6 @@ RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"] -# The runtime interface does not care about specific gate properties, so set them all to True. -RUNTIME_OPERATIONS = { - op: OperatorProperties(invertible=True, controllable=True, differentiable=True) - for op in RUNTIME_OPERATIONS -} - RUNTIME_OBSERVABLES = { obs: OperatorProperties(invertible=True, controllable=True, differentiable=True) for obs in RUNTIME_OBSERVABLES @@ -227,9 +183,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev qjit_capabilities = deepcopy(target_capabilities) # Intersection of gates and observables supported by the device and by Catalyst runtime. - qjit_capabilities.operations = intersect_operations( - target_capabilities.operations, RUNTIME_OPERATIONS - ) + qjit_capabilities.operations = target_capabilities.operations + qjit_capabilities.observables = intersect_operations( target_capabilities.observables, RUNTIME_OBSERVABLES ) diff --git a/frontend/catalyst/from_plxpr/qfunc_interpreter.py b/frontend/catalyst/from_plxpr/qfunc_interpreter.py index 59c6e516c7..9ef6f66224 100644 --- a/frontend/catalyst/from_plxpr/qfunc_interpreter.py +++ b/frontend/catalyst/from_plxpr/qfunc_interpreter.py @@ -36,6 +36,7 @@ from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim from pennylane.measurements import CountsMP +from catalyst.device.op_support import is_quantum_gate from catalyst.jax_extras import jaxpr_pad_consts from catalyst.jax_primitives import ( AbstractQbit, @@ -183,7 +184,11 @@ def interpret_operation(self, op, is_adjoint=False, control_values=(), control_w if (fn := _special_op_bind_call.get(type(op))) is not None: bind_fn = partial(fn, hyperparameters=op.hyperparameters) else: - bind_fn = qinst_p.bind + if not is_quantum_gate(op): + raise CompileError( + f"Operation {op.name} with hyperparameters {list(op.hyperparameters.keys())} " + "is not compatible with quantum instructions." + ) out_qubits = bind_fn( *[*in_qubits, *op.data, *in_ctrl_qubits, *control_values],