Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/braket/aws/aws_emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
from braket.device_schema.ionq import IonqDeviceCapabilities
from braket.device_schema.iqm import IqmDeviceCapabilities
from braket.device_schema.rigetti import RigettiDeviceCapabilities
from braket.emulation.emulation_passes import ValidationPass
from braket.emulation.emulation_passes.gate_device_passes import (
ConnectivityValidator,
GateConnectivityValidator,
GateValidator,
QubitCountValidator,
RigettiRxArgsValidator,
)


Expand All @@ -33,26 +35,43 @@ def qubit_count_validator(properties: DeviceCapabilities) -> QubitCountValidator
return QubitCountValidator(qubit_count)


def gate_validator(properties: DeviceCapabilities) -> GateValidator:
def gate_validator(
properties: DeviceCapabilities,
) -> Union[GateValidator, Iterable[ValidationPass]]:
"""
Create a GateValidator pass which defines what supported and native gates are allowed in a
program based on the provided device properties.
Create a pass (or multiple passes) that checks that the gate operations used in a circuit
are supported by the device capabilities and gate operations used in verbatim circuits are
native to the device capabilities.

Args:
properties (DeviceCapabilities): QPU Device Capabilities object with a
QHP-specific schema.

Returns:
GateValidator: An emulator pass that checks that a circuit only uses supported gates and
verbatim circuits only use native gates.
Union[GateValidator, Iterable[ValidationPass]]: The resulting gate validation passes based
on the device capabilities
"""

return _gate_validator(properties)


@singledispatch
def _gate_validator(properties: DeviceCapabilities) -> GateValidator:
supported_gates = properties.action[DeviceActionType.OPENQASM].supportedOperations
native_gates = properties.paradigm.nativeGateSet

return GateValidator(supported_gates=supported_gates, native_gates=native_gates)


@_gate_validator.register(RigettiDeviceCapabilities)
def _(properties: RigettiDeviceCapabilities) -> Iterable[ValidationPass]:
supported_gates = properties.action[DeviceActionType.OPENQASM].supportedOperations
native_gates = properties.paradigm.nativeGateSet

gate_validator = GateValidator(supported_gates=supported_gates, native_gates=native_gates)
return [gate_validator, RigettiRxArgsValidator()]


def connectivity_validator(
properties: DeviceCapabilities, connectivity_graph: DiGraph
) -> ConnectivityValidator:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@
from braket.emulation.emulation_passes.gate_device_passes.qubit_count_validator import ( # noqa: F401 E501
QubitCountValidator,
)
from braket.emulation.emulation_passes.gate_device_passes.rigetti_rx_args_validator import ( # noqa: F401 E501
RigettiRxArgsValidator,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import math

from braket.circuits import Circuit, FreeParameter, gates
from braket.circuits.compiler_directives import EndVerbatimBox, StartVerbatimBox
from braket.emulation.emulation_passes import ValidationPass


class RigettiRxArgsValidator(ValidationPass[Circuit]):
_ALLOWED_ANGLES = [math.pi, -math.pi, math.pi / 2, -math.pi / 2]

def validate(self, program: Circuit) -> None:
"""
Validates that the only angles used in a verbatim Rx gate are -pi, pi, -pi/2, or pi/2.

Args:
program (Circuit): The braket circuit to validate.

Raises:
ValueError: If an Rx gate used in a verbatim subcircuit uses an unallowed angle.
"""
idx = 0
while idx < len(program.instructions):
instruction = program.instructions[idx]
if isinstance(instruction.operator, StartVerbatimBox):
idx += 1
while idx < len(program.instructions) and not isinstance(
program.instructions[idx].operator, EndVerbatimBox
):
instruction = program.instructions[idx]
if isinstance(instruction.operator, gates.Rx):
angle = instruction.operator.angle
if not isinstance(angle, FreeParameter):
if angle not in self._ALLOWED_ANGLES:
raise ValueError(
f"Invalid RX angle '{angle}' with verbatim usage on Rigetti"
"device. Valid angles are (-π, -π/2, π/2, π).",
)
idx += 1
idx += 1

def __eq__(self, other):
return isinstance(other, RigettiRxArgsValidator)
3 changes: 3 additions & 0 deletions test/unit_tests/braket/aws/test_aws_emulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
GateConnectivityValidator,
GateValidator,
QubitCountValidator,
RigettiRxArgsValidator,
)

REGION = "us-west-1"
Expand Down Expand Up @@ -603,6 +604,7 @@ def test_rigetti_emulator(rigetti_device, rigetti_target_noise_model):
supported_gates=["H", "X", "CNot", "CZ", "Rx", "Ry", "YY"],
native_gates=["cz", "prx", "cphaseshift"],
),
RigettiRxArgsValidator(),
ConnectivityValidator(
nx.from_edgelist([(0, 1), (0, 2), (1, 0), (2, 0)], create_using=nx.DiGraph())
),
Expand All @@ -620,6 +622,7 @@ def test_rigetti_emulator(rigetti_device, rigetti_target_noise_model):
)
),
]
print(emulator._emulator_passes)
assert emulator._emulator_passes == target_emulator_passes


Expand Down
39 changes: 39 additions & 0 deletions test/unit_tests/braket/emulation/test_rigetti_rx_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import pytest

from braket.circuits import Circuit
from braket.emulation.emulation_passes.gate_device_passes import RigettiRxArgsValidator


@pytest.mark.parametrize(
"circuit",
[
Circuit(),
Circuit().rx(0, np.pi / 2),
Circuit().rx(0, np.pi / 7),
Circuit()
.add_verbatim_box(Circuit().rx(2, np.pi).rx(0, -np.pi).rx(3, np.pi / 2).rx(1, -np.pi / 2))
.ry(4, np.pi / 9),
],
)
def test_valid_circuits(circuit):
try:
RigettiRxArgsValidator().validate(circuit)
except ValueError as e:
pytest.fail("Failed Valid Rigetti RX Args Validation: " + repr(e))


@pytest.mark.parametrize(
"circuit",
[
Circuit().add_verbatim_box(Circuit().rx(0, np.pi / 2).rx(2, np.pi / 4)),
Circuit().ry(0, np.pi / 8).add_verbatim_box(Circuit().rx(4, 0)),
Circuit()
.i(range(5))
.add_verbatim_box(Circuit().rx(range(5), -np.pi / 2))
.add_verbatim_box(Circuit().rx(0, np.pi / 2 + 1e-4)),
],
)
def test_invalid_circuits(circuit):
with pytest.raises(ValueError):
RigettiRxArgsValidator().validate(circuit)