diff --git a/src/braket/aws/aws_emulation.py b/src/braket/aws/aws_emulation.py index 75ff9ad58..dda9da6b6 100644 --- a/src/braket/aws/aws_emulation.py +++ b/src/braket/aws/aws_emulation.py @@ -8,11 +8,13 @@ from braket.device_schema.ionq import IonqDeviceCapabilities from braket.device_schema.iqm import IqmDeviceCapabilities from braket.device_schema.rigetti import RigettiDeviceCapabilities +from braket.emulation.emulation_passes import ValidationPass from braket.emulation.emulation_passes.gate_device_passes import ( ConnectivityValidator, GateConnectivityValidator, GateValidator, QubitCountValidator, + RigettiRxArgsValidator, ) @@ -33,26 +35,43 @@ def qubit_count_validator(properties: DeviceCapabilities) -> QubitCountValidator return QubitCountValidator(qubit_count) -def gate_validator(properties: DeviceCapabilities) -> GateValidator: +def gate_validator( + properties: DeviceCapabilities, +) -> Union[GateValidator, Iterable[ValidationPass]]: """ - Create a GateValidator pass which defines what supported and native gates are allowed in a - program based on the provided device properties. + Create a pass (or multiple passes) that checks that the gate operations used in a circuit + are supported by the device capabilities and gate operations used in verbatim circuits are + native to the device capabilities. Args: properties (DeviceCapabilities): QPU Device Capabilities object with a QHP-specific schema. Returns: - GateValidator: An emulator pass that checks that a circuit only uses supported gates and - verbatim circuits only use native gates. + Union[GateValidator, Iterable[ValidationPass]]: The resulting gate validation passes based + on the device capabilities """ + return _gate_validator(properties) + + +@singledispatch +def _gate_validator(properties: DeviceCapabilities) -> GateValidator: supported_gates = properties.action[DeviceActionType.OPENQASM].supportedOperations native_gates = properties.paradigm.nativeGateSet return GateValidator(supported_gates=supported_gates, native_gates=native_gates) +@_gate_validator.register(RigettiDeviceCapabilities) +def _(properties: RigettiDeviceCapabilities) -> Iterable[ValidationPass]: + supported_gates = properties.action[DeviceActionType.OPENQASM].supportedOperations + native_gates = properties.paradigm.nativeGateSet + + gate_validator = GateValidator(supported_gates=supported_gates, native_gates=native_gates) + return [gate_validator, RigettiRxArgsValidator()] + + def connectivity_validator( properties: DeviceCapabilities, connectivity_graph: DiGraph ) -> ConnectivityValidator: diff --git a/src/braket/emulation/emulation_passes/gate_device_passes/__init__.py b/src/braket/emulation/emulation_passes/gate_device_passes/__init__.py index 030791be6..65b2fb319 100644 --- a/src/braket/emulation/emulation_passes/gate_device_passes/__init__.py +++ b/src/braket/emulation/emulation_passes/gate_device_passes/__init__.py @@ -10,3 +10,6 @@ from braket.emulation.emulation_passes.gate_device_passes.qubit_count_validator import ( # noqa: F401 E501 QubitCountValidator, ) +from braket.emulation.emulation_passes.gate_device_passes.rigetti_rx_args_validator import ( # noqa: F401 E501 + RigettiRxArgsValidator, +) diff --git a/src/braket/emulation/emulation_passes/gate_device_passes/rigetti_rx_args_validator.py b/src/braket/emulation/emulation_passes/gate_device_passes/rigetti_rx_args_validator.py new file mode 100644 index 000000000..86a49a82d --- /dev/null +++ b/src/braket/emulation/emulation_passes/gate_device_passes/rigetti_rx_args_validator.py @@ -0,0 +1,42 @@ +import math + +from braket.circuits import Circuit, FreeParameter, gates +from braket.circuits.compiler_directives import EndVerbatimBox, StartVerbatimBox +from braket.emulation.emulation_passes import ValidationPass + + +class RigettiRxArgsValidator(ValidationPass[Circuit]): + _ALLOWED_ANGLES = [math.pi, -math.pi, math.pi / 2, -math.pi / 2] + + def validate(self, program: Circuit) -> None: + """ + Validates that the only angles used in a verbatim Rx gate are -pi, pi, -pi/2, or pi/2. + + Args: + program (Circuit): The braket circuit to validate. + + Raises: + ValueError: If an Rx gate used in a verbatim subcircuit uses an unallowed angle. + """ + idx = 0 + while idx < len(program.instructions): + instruction = program.instructions[idx] + if isinstance(instruction.operator, StartVerbatimBox): + idx += 1 + while idx < len(program.instructions) and not isinstance( + program.instructions[idx].operator, EndVerbatimBox + ): + instruction = program.instructions[idx] + if isinstance(instruction.operator, gates.Rx): + angle = instruction.operator.angle + if not isinstance(angle, FreeParameter): + if angle not in self._ALLOWED_ANGLES: + raise ValueError( + f"Invalid RX angle '{angle}' with verbatim usage on Rigetti" + "device. Valid angles are (-π, -π/2, π/2, π).", + ) + idx += 1 + idx += 1 + + def __eq__(self, other): + return isinstance(other, RigettiRxArgsValidator) diff --git a/test/unit_tests/braket/aws/test_aws_emulation.py b/test/unit_tests/braket/aws/test_aws_emulation.py index 53d5ae53a..59db853c4 100644 --- a/test/unit_tests/braket/aws/test_aws_emulation.py +++ b/test/unit_tests/braket/aws/test_aws_emulation.py @@ -36,6 +36,7 @@ GateConnectivityValidator, GateValidator, QubitCountValidator, + RigettiRxArgsValidator, ) REGION = "us-west-1" @@ -603,6 +604,7 @@ def test_rigetti_emulator(rigetti_device, rigetti_target_noise_model): supported_gates=["H", "X", "CNot", "CZ", "Rx", "Ry", "YY"], native_gates=["cz", "prx", "cphaseshift"], ), + RigettiRxArgsValidator(), ConnectivityValidator( nx.from_edgelist([(0, 1), (0, 2), (1, 0), (2, 0)], create_using=nx.DiGraph()) ), @@ -620,6 +622,7 @@ def test_rigetti_emulator(rigetti_device, rigetti_target_noise_model): ) ), ] + print(emulator._emulator_passes) assert emulator._emulator_passes == target_emulator_passes diff --git a/test/unit_tests/braket/emulation/test_rigetti_rx_validator.py b/test/unit_tests/braket/emulation/test_rigetti_rx_validator.py new file mode 100644 index 000000000..5365fac5f --- /dev/null +++ b/test/unit_tests/braket/emulation/test_rigetti_rx_validator.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from braket.circuits import Circuit +from braket.emulation.emulation_passes.gate_device_passes import RigettiRxArgsValidator + + +@pytest.mark.parametrize( + "circuit", + [ + Circuit(), + Circuit().rx(0, np.pi / 2), + Circuit().rx(0, np.pi / 7), + Circuit() + .add_verbatim_box(Circuit().rx(2, np.pi).rx(0, -np.pi).rx(3, np.pi / 2).rx(1, -np.pi / 2)) + .ry(4, np.pi / 9), + ], +) +def test_valid_circuits(circuit): + try: + RigettiRxArgsValidator().validate(circuit) + except ValueError as e: + pytest.fail("Failed Valid Rigetti RX Args Validation: " + repr(e)) + + +@pytest.mark.parametrize( + "circuit", + [ + Circuit().add_verbatim_box(Circuit().rx(0, np.pi / 2).rx(2, np.pi / 4)), + Circuit().ry(0, np.pi / 8).add_verbatim_box(Circuit().rx(4, 0)), + Circuit() + .i(range(5)) + .add_verbatim_box(Circuit().rx(range(5), -np.pi / 2)) + .add_verbatim_box(Circuit().rx(0, np.pi / 2 + 1e-4)), + ], +) +def test_invalid_circuits(circuit): + with pytest.raises(ValueError): + RigettiRxArgsValidator().validate(circuit)