Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ enzyme=v0.0.203

# For a custom PL version, update the package version here and at
# 'doc/requirements.txt'
pennylane=0.44.0-dev50
pennylane=0.44.0.dev54

# For a custom LQ/LK version, update the package version here and at
# 'doc/requirements.txt'
Expand Down
3 changes: 2 additions & 1 deletion frontend/catalyst/device/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
is_controllable,
is_differentiable,
is_invertible,
is_quantum_gate,
is_supported,
)
from catalyst.jax_tracer import HybridOpRegion, has_nested_tapes
Expand Down Expand Up @@ -227,7 +228,7 @@ def catalyst_acceptance(
if match and is_controllable(op.base, capabilities):
return match

elif is_supported(op, capabilities):
elif is_supported(op, capabilities) and is_quantum_gate(op):
return op.name

return None
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/device/op_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def is_supported(op: Operator, capabilities: DeviceCapabilities) -> bool:
"""Check whether an operation is supported by the device."""
return op.name in capabilities.operations

def is_quantum_gate(op: Operator) -> bool:
"""Check whether an operation is a quantum gate."""
return isinstance(op, qml.operation.Gate)
# and isinstance(op, qml.operation.DynamicGate)

def _is_grad_recipe_same_as_catalyst(op):
"""Checks that the grad_recipe for the op matches the hard coded one in Catalyst."""
Expand Down
49 changes: 2 additions & 47 deletions frontend/catalyst/device/qjit_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,44 +60,6 @@
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

RUNTIME_OPERATIONS = [
"CNOT",
"ControlledPhaseShift",
"CRot",
"CRX",
"CRY",
"CRZ",
"CSWAP",
"CY",
"CZ",
"Hadamard",
"Identity",
"IsingXX",
"IsingXY",
"IsingYY",
"IsingZZ",
"SingleExcitation",
"DoubleExcitation",
"ISWAP",
"MultiRZ",
"PauliX",
"PauliY",
"PauliZ",
"PCPhase",
"PhaseShift",
"PSWAP",
"QubitUnitary",
"Rot",
"RX",
"RY",
"RZ",
"S",
"SWAP",
"T",
"Toffoli",
"GlobalPhase",
]

RUNTIME_OBSERVABLES = [
"Identity",
"PauliX",
Expand All @@ -113,12 +75,6 @@

RUNTIME_MPS = ["ExpectationMP", "SampleMP", "VarianceMP", "CountsMP", "StateMP", "ProbabilityMP"]

# The runtime interface does not care about specific gate properties, so set them all to True.
RUNTIME_OPERATIONS = {
op: OperatorProperties(invertible=True, controllable=True, differentiable=True)
for op in RUNTIME_OPERATIONS
}

RUNTIME_OBSERVABLES = {
obs: OperatorProperties(invertible=True, controllable=True, differentiable=True)
for obs in RUNTIME_OBSERVABLES
Expand Down Expand Up @@ -227,9 +183,8 @@ def get_qjit_device_capabilities(target_capabilities: DeviceCapabilities) -> Dev
qjit_capabilities = deepcopy(target_capabilities)

# Intersection of gates and observables supported by the device and by Catalyst runtime.
qjit_capabilities.operations = intersect_operations(
target_capabilities.operations, RUNTIME_OPERATIONS
)
qjit_capabilities.operations = target_capabilities.operations

qjit_capabilities.observables = intersect_operations(
target_capabilities.observables, RUNTIME_OBSERVABLES
)
Expand Down
7 changes: 6 additions & 1 deletion frontend/catalyst/from_plxpr/qfunc_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from pennylane.ftqc.primitives import measure_in_basis_prim as plxpr_measure_in_basis_prim
from pennylane.measurements import CountsMP

from catalyst.device.op_support import is_quantum_gate
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import (
AbstractQbit,
Expand Down Expand Up @@ -183,9 +184,13 @@
if (fn := _special_op_bind_call.get(type(op))) is not None:
bind_fn = partial(fn, hyperparameters=op.hyperparameters)
else:
bind_fn = qinst_p.bind
if not is_quantum_gate(op):
raise CompileError(
f"Operation {op.name} with hyperparameters {list(op.hyperparameters.keys())} "
"is not compatible with quantum instructions."
)

out_qubits = bind_fn(

Check notice on line 193 in frontend/catalyst/from_plxpr/qfunc_interpreter.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/from_plxpr/qfunc_interpreter.py#L193

Possibly using variable 'bind_fn' before assignment (possibly-used-before-assignment)
*[*in_qubits, *op.data, *in_ctrl_qubits, *control_values],
op=op.name,
qubits_len=len(op.wires),
Expand Down