diff --git a/mlir/include/CMakeLists.txt b/mlir/include/CMakeLists.txt index d211bbd77e..32e440502a 100644 --- a/mlir/include/CMakeLists.txt +++ b/mlir/include/CMakeLists.txt @@ -7,5 +7,6 @@ add_subdirectory(Mitigation) add_subdirectory(PauliFrame) add_subdirectory(QEC) add_subdirectory(Quantum) +add_subdirectory(RefQuantum) add_subdirectory(RTIO) add_subdirectory(Test) diff --git a/mlir/include/Quantum/IR/QuantumAttrDefs.td b/mlir/include/Quantum/IR/QuantumAttrDefs.td index 133462bae1..716c60b2aa 100644 --- a/mlir/include/Quantum/IR/QuantumAttrDefs.td +++ b/mlir/include/Quantum/IR/QuantumAttrDefs.td @@ -38,5 +38,6 @@ def NamedObservable : I32EnumAttr<"NamedObservable", def NamedObservableAttr : EnumAttr; +def PauliWord : TypedArrayAttrBase; #endif // QUANTUM_ATTR_DEFS diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 165aa880d1..8b6029207a 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -534,8 +534,6 @@ def CustomOp : UnitaryGate_Op<"custom", [DifferentiableGate, NoMemoryEffect, let hasVerifier = 1; } -def PauliWord : TypedArrayAttrBase; - def PauliRotOp : UnitaryGate_Op<"paulirot", [DifferentiableGate, NoMemoryEffect, AttrSizedOperandSegments, AttrSizedResultSegments]> { let summary = "Apply a Pauli Product Rotation"; diff --git a/mlir/include/RefQuantum/CMakeLists.txt b/mlir/include/RefQuantum/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/mlir/include/RefQuantum/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/RefQuantum/IR/CMakeLists.txt b/mlir/include/RefQuantum/IR/CMakeLists.txt new file mode 100644 index 0000000000..0e0146a8bf --- /dev/null +++ b/mlir/include/RefQuantum/IR/CMakeLists.txt @@ -0,0 +1,9 @@ +add_mlir_dialect(RefQuantumOps ref_quantum) +add_mlir_interface(RefQuantumInterfaces) +add_mlir_doc(RefQuantumDialect RefQuantumDialect RefQuantum/ -gen-dialect-doc -gen-op-doc) +add_mlir_doc(RefQuantumOps RefQuantumOps RefQuantum/ -gen-op-doc) +add_mlir_doc(RefQuantumInterfaces RefQuantumInterfaces RefQuantum/ -gen-op-interface-docs) + +set(LLVM_TARGET_DEFINITIONS RefQuantumOps.td) +mlir_tablegen(RefQuantumAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=ref_quantum) +mlir_tablegen(RefQuantumAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=ref_quantum) diff --git a/mlir/include/RefQuantum/IR/RefQuantumDialect.h b/mlir/include/RefQuantum/IR/RefQuantumDialect.h new file mode 100644 index 0000000000..bb0787e090 --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumDialect.h @@ -0,0 +1,23 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/Dialect.h" + +//===----------------------------------------------------------------------===// +// RefQuantum dialect declarations. +//===----------------------------------------------------------------------===// + +#include "RefQuantum/IR/RefQuantumOpsDialect.h.inc" diff --git a/mlir/include/RefQuantum/IR/RefQuantumDialect.td b/mlir/include/RefQuantum/IR/RefQuantumDialect.td new file mode 100644 index 0000000000..205e6f7fed --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumDialect.td @@ -0,0 +1,57 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REF_QUANTUM_DIALECT +#define REF_QUANTUM_DIALECT + +include "mlir/IR/DialectBase.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/AttrTypeBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// RefQuantum dialect definition. +//===----------------------------------------------------------------------===// + +def RefQuantumDialect : Dialect { + let summary = "Reference semantics quantum dialect."; + let description = [{ + A supplemental dialect to the core quantum dialect. + + Quantum operations in this dialect follow reference semantics (as opposed to qubit value + semantics in the core quantum dialect): the targets of quantum operations in this dialect + are all integer wire indices. + }]; + + /// This is the namespace of the dialect in MLIR, which is used as a prefix for types and ops. + let name = "ref_quantum"; + + /// This is the C++ namespace in which the dialect and all of its sub-components are placed. + let cppNamespace = "::catalyst::ref_quantum"; + + let dependentDialects = [ + "quantum::QuantumDialect" + ]; +} + + +//===----------------------------------------------------------------------===// +// RefQuantum dialect base operation. +//===----------------------------------------------------------------------===// + +class RefQuantum_Op traits = []> : + Op; + + +#endif // REF_QUANTUM_DIALECT diff --git a/mlir/include/RefQuantum/IR/RefQuantumInterfaces.h b/mlir/include/RefQuantum/IR/RefQuantumInterfaces.h new file mode 100644 index 0000000000..1b80f0edd9 --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumInterfaces.h @@ -0,0 +1,25 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "mlir/IR/OpDefinition.h" + +//===----------------------------------------------------------------------===// +// RefQuantum interface declarations. +//===----------------------------------------------------------------------===// + +#include "RefQuantum/IR/RefQuantumInterfaces.h.inc" diff --git a/mlir/include/RefQuantum/IR/RefQuantumInterfaces.td b/mlir/include/RefQuantum/IR/RefQuantumInterfaces.td new file mode 100644 index 0000000000..84604186f8 --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumInterfaces.td @@ -0,0 +1,136 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REFQUANTUM_INTERFACES +#define REFQUANTUM_INTERFACES + +include "mlir/IR/OpBase.td" + +def QuantumOperation : OpInterface<"QuantumOperation"> { + let description = [{ + A base class for all quantum operations that can be considered actions on wires. + The actions do not have to be unitary. For example, the SetState operations also falls + under this class. + }]; + + let cppNamespace = "::catalyst::ref_quantum"; + + let methods = [ + InterfaceMethod< + "Return all operands which are considered input wires (including controls).", + "std::vector", "getWireOperands" + >, + InterfaceMethod< + "Set all operands which are considered input wires (including controls).", + "void", "setWireOperands", (ins "mlir::ValueRange":$replacements) + > + ]; +} + +def QuantumGate : OpInterface<"QuantumGate", [QuantumOperation]> { + let description = [{ + A base class for all unitary quantum operations. + These operations can be inverted and controlled. + }]; + + let cppNamespace = "::catalyst::ref_quantum"; + + let methods = [ + InterfaceMethod< + "Return operands which are considered non-controlled input wire values.", + "mlir::ValueRange", "getNonCtrlWireOperands" + >, + InterfaceMethod< + "Set all operands which are considered non-controlled input wire values.", + "void", "setNonCtrlWireOperands", (ins "mlir::ValueRange":$replacements) + >, + InterfaceMethod< + "Return all operands which are considered controlling input wire values.", + "mlir::ValueRange", "getCtrlWireOperands" + >, + InterfaceMethod< + "Set all operands which are considered controlling input wire values.", + "void", "setCtrlWireOperands", (ins "mlir::ValueRange":$replacements) + >, + InterfaceMethod< + "Return all operands which are considered controlling input boolean values.", + "mlir::ValueRange", "getCtrlValueOperands" + >, + InterfaceMethod< + "Set all operands which are considered controlling input boolean values.", + "void", "setCtrlValueOperands", (ins "mlir::ValueRange":$replacements) + >, + InterfaceMethod< + "Return adjoint flag.", + "bool", "getAdjointFlag" + >, + InterfaceMethod< + "Set adjoint flag.", + "void", "setAdjointFlag", (ins "bool":$adjoint) + > + ]; + + let verify = [{ + auto gate = mlir::cast($_op); + + if (gate.getCtrlValueOperands().size() != gate.getCtrlWireOperands().size()) { + return $_op->emitError() << + "number of controlling wires in input (" << + gate.getCtrlWireOperands().size() << ") " << + "and controlling values (" << + gate.getCtrlValueOperands().size() << + ") must be the same"; + } + + // STL methods to check duplicates will all complain about `mlir::Value` not having a + // comparison method defined, since they all use map/set, which is hash-based + // So we just do it manually + std::vector wireOperands = gate.getWireOperands(); + for (size_t i=0; i < wireOperands.size(); i++) { + for (size_t j=i+1; j < wireOperands.size(); j++) { + if (wireOperands[i] == wireOperands[j]) { + return $_op->emitError() << "all wires on a quantum gate must be " << + "distinct (including controls)"; + } + } + } + + return mlir::success(); + }]; +} + +def ParametrizedGate : OpInterface<"ParametrizedGate", [QuantumGate]> { + let description = [{ + This interface provides a generic way to interact with parametrized + quantum instructions. These are quantum operations with arbitrary + classical gate parameters. + }]; + + let cppNamespace = "::catalyst::ref_quantum"; + + let methods = [ + InterfaceMethod< + "Return all operands which are considered gate parameters.", + "mlir::ValueRange", "getAllParams" + >, + InterfaceMethod< + "Return the param operand at the requested index.", + "mlir::Value", "getParam", (ins "size_t":$idx), /*methodBody=*/[{}], + /*defaultImplementation=*/[{ return mlir::cast($_op).getAllParams()[idx]; }] + >, + ]; +} + + +#endif // REFQUANTUM_INTERFACES diff --git a/mlir/include/RefQuantum/IR/RefQuantumOps.h b/mlir/include/RefQuantum/IR/RefQuantumOps.h new file mode 100644 index 0000000000..389f6288b1 --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumOps.h @@ -0,0 +1,36 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "llvm/ADT/StringRef.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/LogicalResult.h" + +#include "Quantum/IR/QuantumDialect.h" +#include "RefQuantum/IR/RefQuantumDialect.h" +#include "RefQuantum/IR/RefQuantumInterfaces.h" + +//===----------------------------------------------------------------------===// +// RefQuantum ops declarations. +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "RefQuantum/IR/RefQuantumOps.h.inc" diff --git a/mlir/include/RefQuantum/IR/RefQuantumOps.td b/mlir/include/RefQuantum/IR/RefQuantumOps.td new file mode 100644 index 0000000000..408bd516df --- /dev/null +++ b/mlir/include/RefQuantum/IR/RefQuantumOps.td @@ -0,0 +1,453 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REF_QUANTUM_OPS +#define REF_QUANTUM_OPS + +include "mlir/IR/EnumAttr.td" +include "mlir/IR/OpBase.td" +// TODO: I probably need bufferization. Remove if ended up not using it. +// Or do I? If this dialect is only supposed to be at a high level, +// i.e. connection to lower parts of the pipeline are done via +// ref dialect ---> value dialect ---> bufferization ---> .... +include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" + +include "Quantum/IR/QuantumAttrDefs.td" +include "Quantum/IR/QuantumDialect.td" +include "Quantum/IR/QuantumTypes.td" +include "RefQuantum/IR/RefQuantumDialect.td" +include "RefQuantum/IR/RefQuantumInterfaces.td" + +//===----------------------------------------------------------------------===// +// RefQuantum dialect operations. +//===----------------------------------------------------------------------===// + +// ----- + +class Memory_Op traits = []> : RefQuantum_Op; + +// TODO: (dynamic) alloc and deallocs +// Note that extracts and inserts are not a thing in reference semantics + +// ----- + +class Gate_Op traits = []> : + RefQuantum_Op { + + code extraBaseClassDeclaration = [{ + std::vector getWireOperands() { + std::vector values; + values.insert(values.end(), getWires().begin(), getWires().end()); + return values; + } + + void setWireOperands(mlir::ValueRange replacements) { + mlir::MutableOperandRange wires = getWiresMutable(); + assert(wires.size() == replacements.size() && "must provide values for all wires"); + wires.assign(replacements); + } + }]; + + let extraClassDeclaration = extraBaseClassDeclaration; +} + +class UnitaryGate_Op traits = []> : + Gate_Op { + + code extraBaseClassDeclaration = [{ + std::vector getWireOperands() { + std::vector values; + values.insert(values.end(), getWires().begin(), getWires().end()); + values.insert(values.end(), getCtrlWires().begin(), getCtrlWires().end()); + return values; + } + + void setWireOperands(mlir::ValueRange replacements) { + mlir::MutableOperandRange wires = getWiresMutable(); + mlir::MutableOperandRange ctrls = getCtrlWiresMutable(); + assert(wires.size() + ctrls.size() == replacements.size() && + "must provide values for all wires (including controls)"); + + wires.assign(replacements.take_front(wires.size())); + ctrls.assign(replacements.take_back(ctrls.size())); + } + + mlir::ValueRange getNonCtrlWireOperands() { + return getWires(); + } + + void setNonCtrlWireOperands(mlir::ValueRange replacements) { + mlir::MutableOperandRange wires = getWiresMutable(); + assert(wires.size() == replacements.size() && + "must provide values for all non-ctrl wire values"); + wires.assign(replacements); + } + + mlir::ValueRange getCtrlWireOperands() { + return getCtrlWires(); + } + + void setCtrlWireOperands(mlir::ValueRange replacements) { + mlir::MutableOperandRange ctrlWires = getCtrlWiresMutable(); + assert(ctrlWires.size() == replacements.size() && + "must provide values for all ctrl wire values"); + ctrlWires.assign(replacements); + } + + mlir::ValueRange getCtrlValueOperands() { + return getCtrlValues(); + } + + void setCtrlValueOperands(mlir::ValueRange replacements) { + mlir::MutableOperandRange ctrlValues = getCtrlValuesMutable(); + assert(ctrlValues.size() == replacements.size() && + "must provide values for all control values"); + ctrlValues.assign(replacements); + } + + bool getAdjointFlag() { + return getAdjoint(); + } + + void setAdjointFlag(bool adjoint) { + if (adjoint) { + (*this)->setAttr("adjoint", mlir::UnitAttr::get(this->getContext())); + } else { + (*this)->removeAttr("adjoint"); + } + }; + }]; + + let extraClassDeclaration = extraBaseClassDeclaration; +} + +def SetStateOp : Gate_Op<"set_state"> { + let summary = "Set state to a complex vector."; + let description = [{ + This operation is useful for simulators implementing state preparation. + Instead of decomposing state preparation into multiple operations, this + operation shortcuts all of that into a single operation. + + .. note:: + This op is not bufferizable at the moment, and must take in the state as a tensor + instead of a memref. To execute reference semantics quantum dialect, please convert + to the value semantics quantum dialect, where ops are bufferizable and lowerable to + LLVM IR. + }]; + + let arguments = (ins + 1DTensorOf<[Complex]>:$in_state, + Variadic:$wires + ); + + let assemblyFormat = [{ + `(` $in_state `)` $wires attr-dict `:` type(operands) + }]; + +} + + +def SetBasisStateOp : Gate_Op<"set_basis_state"> { + let summary = "Set basis state."; + let description = [{ + This operation is useful for simulators implementing set basis state. + Instead of decomposing basis state into multiple operations, this + operation shortcuts all of that into a single operation. + This signature matches the one in pennylane-lightning which expects + only a single integer as opposed to a binary digit. + + .. note:: + This op is not bufferizable at the moment, and must take in the basis state as a tensor + instead of a memref. To execute reference semantics quantum dialect, please convert + to the value semantics quantum dialect, where ops are bufferizable and lowerable to + LLVM IR. + }]; + + let arguments = (ins + 1DTensorOf<[I1]>:$basis_state, + Variadic:$wires + ); + + let assemblyFormat = [{ + `(` $basis_state`)` $wires attr-dict `:` type(operands) + }]; +} + + +def CustomOp : UnitaryGate_Op<"custom", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "A generic quantum gate on n qubits with m floating point parameters."; + let description = [{ + }]; + + let arguments = (ins + Variadic:$params, + Variadic:$wires, // Perhaps add a static version? But not super high priority on my list. We have arith.constant. + StrAttr:$gate_name, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, // Ditto regarding static. + Variadic:$ctrl_values + ); + + // TODO: add convenience builders + + let assemblyFormat = [{ + $gate_name `(` $params `)` $wires (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($wires) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getParams(); + } + }]; + + // TODO: anything needed here? + // let hasCanonicalizeMethod = 1; + // let hasVerifier = 1; +} + +def PauliRotOp : UnitaryGate_Op<"paulirot", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "Apply a Pauli Product Rotation"; + let description = [{ + The `ref_quantum.paulirot` operation applies a rotation around a Pauli product + operator to the state-vector. + The arguments are the rotation angle `angle`, a string representing the + Pauli product operator, and a set of qubits the operation acts on. + Note that this operation is currently not excutable. There isn't a valid + lowering path to the LLVM IR. + }]; + + let arguments = (ins + F64:$angle, + PauliWord:$pauli_product, + Variadic:$wires, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, + Variadic:$ctrl_values + ); + + let assemblyFormat = [{ + $pauli_product `(` $angle `)` $wires (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($wires) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(0); // `angle` is the 0th operand + } + }]; + + let hasVerifier = 1; +} + +def GlobalPhaseOp : UnitaryGate_Op<"gphase", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "Global Phase."; + + let description = [{ + Applies global phase to the current system. + }]; + + let arguments = (ins + F64:$params, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, + Variadic:$ctrl_values + ); + + let assemblyFormat = [{ + `(` $params `)` (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($params) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(0); // `params` is the 0-th operand + } + + // Simulate missing operands and results for the default impl of the quantum gate interface. + mlir::OperandRange getWires() { + return {getOperands().begin(), getOperands().begin()}; + } + mlir::MutableOperandRange getWiresMutable() { + return mlir::MutableOperandRange(getOperation(), 0, 0); + } + }]; +} + +def MultiRZOp : UnitaryGate_Op<"multirz", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "Apply an arbitrary multi Z rotation"; + let description = [{ + The `ref_quantum.multirz` operation applies an arbitrary multi Z rotation to the state-vector. + The arguments are the rotation angle `theta` and a set of wires the operation acts on. + + .. note:: + This operation is one of the few quantum operations that is not applied via + ``ref_quantum.custom``. The reason for this is that its quantum dialect counterpart + needs to be handled in a special way during the lowering due to its C function being + variadic on the number of qubits. + }]; + + let arguments = (ins + F64:$theta, + Variadic:$wires, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, + Variadic:$ctrl_values + ); + + let assemblyFormat = [{ + `(` $theta `)` $wires (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($wires) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(0); // `theta` is the 0-th operand + } + }]; + + // TODO + // let hasCanonicalizeMethod = 1; +} + +def PCPhaseOp : UnitaryGate_Op<"pcphase", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "Apply a projector-controlled phase gate"; + let description = [{ + This gate is built from simpler gates like `PhaseShift` and `PauliX` and acts on a group + of wires and takes a rotation angle. + It also takes another number, an integer called `dim`, which defines a specific part + of the quantum state. The gate then applies a positive phase shift to a portion of the + state defined by `dim`. At the same time, it applies a negative phase shift to the rest + of the state. + + .. note:: + This operation is one of the few quantum operations that is not applied via + ``quantum.custom``. The reason for this is that its quantum dialect counterpart needs + to be handled in a special way during the lowering due to its C function being variadic + on the number of qubits. + + .. note:: + `dim` is currently captured as a float number for compatibility with + runtime and device integration. + + }]; + + let arguments = (ins + F64:$theta, + F64:$dim, + Variadic:$wires, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, + Variadic:$ctrl_values + ); + + let assemblyFormat = [{ + `(` $theta `,` $dim `)` $wires (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($wires) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(0); // `theta` is the 0-th operand + } + }]; + + // TODO + // let hasCanonicalizeMethod = 1; +} + +def QubitUnitaryOp : UnitaryGate_Op<"unitary", [ParametrizedGate, AttrSizedOperandSegments]> { + let summary = "Apply an arbitrary fixed unitary matrix"; + let description = [{ + The `ref_quantum.unitary` operation applies an arbitrary fixed unitary matrix to the + state-vector. The arguments are a set of qubits and a 2-dim matrix of complex numbers + that represents a Unitary matrix of size 2^(number of qubits) * 2^(number of qubits). + + .. note:: + This op is not bufferizable at the moment, and must take in the matrix as a tensor + instead of a memref. To execute reference semantics quantum dialect, please convert + to the value semantics quantum dialect, where ops are bufferizable and lowerable to + LLVM IR. + }]; + + let arguments = (ins + 2DTensorOf<[Complex]>:$matrix, + Variadic:$wires, + UnitAttr:$adjoint, + Variadic:$ctrl_wires, + Variadic:$ctrl_values + ); + + let assemblyFormat = [{ + `(` $matrix `:` type($matrix) `)` $wires (`adj` $adjoint^)? attr-dict + ( `ctrls` `(` $ctrl_wires^ `)` )? + ( `ctrlvals` `(` $ctrl_values^ `)` )? + `:` type($wires) (`ctrls` type($ctrl_wires)^ )? + }]; + + let extraClassDeclaration = extraBaseClassDeclaration # [{ + mlir::ValueRange getAllParams() { + return getODSOperands(0); // `matrix` is the first operand + } + }]; + + let hasVerifier = 1; +} +// ----- + +// Observable ops are not meaningful on their own: their purpose is to be sent into a measurement +// Hence they are Pure, i.e. removable if no users + +class Observable_Op traits = []> : + RefQuantum_Op; + +def NamedObsOp : Observable_Op<"namedobs"> { + let summary = "Define a Named observable for use in measurements"; + let description = [{ + The `ref_quantum.namedobs` operation defines a quantum observable to be used by measurement + processes. The specific observable defined here represents one of 5 named observables + {Identity, PauliX, PauliY, PauliZ, Hadamard} on a qubit. The arguments are a wire to + measure as well as an encoding operator for the qubit as an integer between 0-4. + }]; + + let arguments = (ins + I64:$wire, + NamedObservableAttr:$type + ); + + let results = (outs + ObservableType:$obs + ); + + let assemblyFormat = [{ + $wire `[` $type `]` attr-dict `:` type(results) + }]; +} + +// ----- + +// class Measurement_Op traits = []> : +// Quantum_Op; + +#endif // REF_QUANTUM_OPS diff --git a/mlir/include/RefQuantum/Transforms/CMakeLists.txt b/mlir/include/RefQuantum/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..d597f46814 --- /dev/null +++ b/mlir/include/RefQuantum/Transforms/CMakeLists.txt @@ -0,0 +1,4 @@ +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name RefQuantum) +add_public_tablegen_target(MLIRRefQuantumPassIncGen) +add_mlir_doc(Passes RefQuantumPasses ./ -gen-pass-doc) diff --git a/mlir/include/RefQuantum/Transforms/Passes.h b/mlir/include/RefQuantum/Transforms/Passes.h new file mode 100644 index 0000000000..74031d0161 --- /dev/null +++ b/mlir/include/RefQuantum/Transforms/Passes.h @@ -0,0 +1,27 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/Pass/Pass.h" + +namespace catalyst { +namespace ref_quantum { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "RefQuantum/Transforms/Passes.h.inc" + +} // namespace ref_quantum +} // namespace catalyst diff --git a/mlir/include/RefQuantum/Transforms/Passes.td b/mlir/include/RefQuantum/Transforms/Passes.td new file mode 100644 index 0000000000..3da6ec7b89 --- /dev/null +++ b/mlir/include/RefQuantum/Transforms/Passes.td @@ -0,0 +1,29 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef REF_QUANTUM_PASSES +#define REF_QUANTUM_PASSES + +include "mlir/Pass/PassBase.td" + +def RQHelloWorldPass : Pass<"rq-hw"> { + let summary = "RefQuantum Hello world!"; + + // let dependentDialects = [ + // "scf::SCFDialect", + // "catalyst::quantum::QuantumDialect" + // ]; +} + +#endif // REF_QUANTUM_PASSES diff --git a/mlir/include/RefQuantum/Transforms/Patterns.h b/mlir/include/RefQuantum/Transforms/Patterns.h new file mode 100644 index 0000000000..92a5ceaf26 --- /dev/null +++ b/mlir/include/RefQuantum/Transforms/Patterns.h @@ -0,0 +1,25 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/PatternMatch.h" + +namespace catalyst { +namespace ref_quantum { + +void populateRQHelloWorldPatterns(mlir::RewritePatternSet &patterns); + +} // namespace ref_quantum +} // namespace catalyst diff --git a/mlir/include/RegisterAllPasses.h b/mlir/include/RegisterAllPasses.h index 1ff2397b0b..966ed90b6f 100644 --- a/mlir/include/RegisterAllPasses.h +++ b/mlir/include/RegisterAllPasses.h @@ -23,6 +23,7 @@ #include "QEC/Transforms/Passes.h" #include "Quantum/Transforms/Passes.h" #include "RTIO/Transforms/Passes.h" +#include "RefQuantum/Transforms/Passes.h" #include "Test/Transforms/Passes.h" #include "hlo-extensions/Transforms/Passes.h" @@ -39,6 +40,7 @@ inline void registerAllPasses() pauli_frame::registerPauliFramePasses(); qec::registerQECPasses(); quantum::registerQuantumPasses(); + ref_quantum::registerRefQuantumPasses(); rtio::registerRTIOPasses(); test::registerTestPasses(); } diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index 6f874d8502..96d352f957 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -9,5 +9,6 @@ add_subdirectory(Mitigation) add_subdirectory(PauliFrame) add_subdirectory(QEC) add_subdirectory(Quantum) +add_subdirectory(RefQuantum) add_subdirectory(RTIO) add_subdirectory(Test) diff --git a/mlir/lib/Driver/CMakeLists.txt b/mlir/lib/Driver/CMakeLists.txt index 9147c10f78..6a211f625d 100644 --- a/mlir/lib/Driver/CMakeLists.txt +++ b/mlir/lib/Driver/CMakeLists.txt @@ -34,6 +34,8 @@ set(LIBS catalyst-transforms MLIRQuantum quantum-transforms + MLIRRefQuantum + ref-quantum-transforms MLIRQEC qec-transforms MLIRGradient diff --git a/mlir/lib/RefQuantum/CMakeLists.txt b/mlir/lib/RefQuantum/CMakeLists.txt new file mode 100644 index 0000000000..9f57627c32 --- /dev/null +++ b/mlir/lib/RefQuantum/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/RefQuantum/IR/CMakeLists.txt b/mlir/lib/RefQuantum/IR/CMakeLists.txt new file mode 100644 index 0000000000..f4600de447 --- /dev/null +++ b/mlir/lib/RefQuantum/IR/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_library(MLIRRefQuantum + RefQuantumDialect.cpp + RefQuantumInterfaces.cpp + RefQuantumOps.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/RefQuantum + + DEPENDS + MLIRQuantumInterfacesIncGen + MLIRRefQuantumOpsIncGen +) diff --git a/mlir/lib/RefQuantum/IR/RefQuantumDialect.cpp b/mlir/lib/RefQuantum/IR/RefQuantumDialect.cpp new file mode 100644 index 0000000000..0378038357 --- /dev/null +++ b/mlir/lib/RefQuantum/IR/RefQuantumDialect.cpp @@ -0,0 +1,35 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/IR/Builders.h" + +#include "RefQuantum/IR/RefQuantumDialect.h" +#include "RefQuantum/IR/RefQuantumOps.h" + +using namespace mlir; +using namespace catalyst::ref_quantum; + +//===----------------------------------------------------------------------===// +// RefQuantum dialect definitions. +//===----------------------------------------------------------------------===// + +#include "RefQuantum/IR/RefQuantumOpsDialect.cpp.inc" + +void RefQuantumDialect::initialize() +{ + addOperations< +#define GET_OP_LIST +#include "RefQuantum/IR/RefQuantumOps.cpp.inc" + >(); +} diff --git a/mlir/lib/RefQuantum/IR/RefQuantumInterfaces.cpp b/mlir/lib/RefQuantum/IR/RefQuantumInterfaces.cpp new file mode 100644 index 0000000000..9a7a858218 --- /dev/null +++ b/mlir/lib/RefQuantum/IR/RefQuantumInterfaces.cpp @@ -0,0 +1,24 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "RefQuantum/IR/RefQuantumInterfaces.h" + +using namespace mlir; +using namespace catalyst::ref_quantum; + +//===----------------------------------------------------------------------===// +// RefQuantum interface definitions. +//===----------------------------------------------------------------------===// + +#include "RefQuantum/IR/RefQuantumInterfaces.cpp.inc" diff --git a/mlir/lib/RefQuantum/IR/RefQuantumOps.cpp b/mlir/lib/RefQuantum/IR/RefQuantumOps.cpp new file mode 100644 index 0000000000..6019cad6b4 --- /dev/null +++ b/mlir/lib/RefQuantum/IR/RefQuantumOps.cpp @@ -0,0 +1,80 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/IR/OpImplementation.h" +#include "llvm/ADT/StringSet.h" + +#include "RefQuantum/IR/RefQuantumDialect.h" +#include "RefQuantum/IR/RefQuantumOps.h" + +using namespace mlir; +using namespace catalyst::ref_quantum; + +//===----------------------------------------------------------------------===// +// RefQuantum op definitions. +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "RefQuantum/IR/RefQuantumOps.cpp.inc" + +namespace catalyst::ref_quantum { + +// Utils +static LogicalResult verifyTensorResult(Type ty, int64_t length0, int64_t length1) +{ + ShapedType tensor = cast(ty); + if (!tensor.hasStaticShape() || tensor.getShape().size() != 2 || + tensor.getShape()[0] != length0 || tensor.getShape()[1] != length1) { + return failure(); + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// RefQuantum op verifiers. +//===----------------------------------------------------------------------===// + +static const mlir::StringSet<> validPauliWords = {"X", "Y", "Z", "I"}; + +LogicalResult PauliRotOp::verify() +{ + size_t pauliWordLength = getPauliProduct().size(); + size_t numWires = getWires().size(); + if (pauliWordLength != numWires) { + return emitOpError() << "length of Pauli word (" << pauliWordLength + << ") and number of wires (" << numWires << ") must be the same"; + } + + if (!llvm::all_of(getPauliProduct(), [](mlir::Attribute attr) { + auto pauliStr = llvm::cast(attr); + return validPauliWords.contains(pauliStr.getValue()); + })) { + return emitOpError() << "Only \"X\", \"Y\", \"Z\", and \"I\" are valid Pauli words."; + } + + return success(); +} + +LogicalResult QubitUnitaryOp::verify() +{ + size_t dim = 1 << getWires().size(); + if (failed(verifyTensorResult(cast(getMatrix().getType()), dim, dim))) { + return emitOpError("The Unitary matrix must be of size 2^(num_wires) * 2^(num_wires)"); + } + + return success(); +} + +} // namespace catalyst::ref_quantum diff --git a/mlir/lib/RefQuantum/Transforms/CMakeLists.txt b/mlir/lib/RefQuantum/Transforms/CMakeLists.txt new file mode 100644 index 0000000000..2b3c3ffcb6 --- /dev/null +++ b/mlir/lib/RefQuantum/Transforms/CMakeLists.txt @@ -0,0 +1,27 @@ +set(LIBRARY_NAME ref-quantum-transforms) + + +file(GLOB SRC + HelloWorldPatterns.cpp + hello_world.cpp +) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +set(LIBS + ${dialect_libs} + ${conversion_libs} + MLIRRefQuantum +) + +set(DEPENDS + MLIRRefQuantumPassIncGen +) + +add_mlir_library(${LIBRARY_NAME} STATIC ${SRC} LINK_LIBS PRIVATE ${LIBS} DEPENDS ${DEPENDS}) +target_compile_features(${LIBRARY_NAME} PUBLIC cxx_std_20) + +target_include_directories(${LIBRARY_NAME} PUBLIC + . + ${PROJECT_SOURCE_DIR}/include + ${CMAKE_BINARY_DIR}/include) diff --git a/mlir/lib/RefQuantum/Transforms/HelloWorldPatterns.cpp b/mlir/lib/RefQuantum/Transforms/HelloWorldPatterns.cpp new file mode 100644 index 0000000000..382f9a7aad --- /dev/null +++ b/mlir/lib/RefQuantum/Transforms/HelloWorldPatterns.cpp @@ -0,0 +1,51 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "rq-hello-world" + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" + +// #include "RefQuantum/IR/RefQuantumOps.h" +#include "Quantum/IR/QuantumDialect.h" +#include "Quantum/IR/QuantumOps.h" +#include "RefQuantum/Transforms/Patterns.h" + +using namespace mlir; + +namespace { + +struct RQHelloWorldPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(catalyst::quantum::CustomOp op, + PatternRewriter &rewriter) const override + { + llvm::errs() << "hello world! Visiting " << op << "\n"; + return success(); + } +}; + +} // namespace + +namespace catalyst { +namespace ref_quantum { + +void populateRQHelloWorldPatterns(RewritePatternSet &patterns) +{ + patterns.add(patterns.getContext()); +} + +} // namespace ref_quantum +} // namespace catalyst diff --git a/mlir/lib/RefQuantum/Transforms/hello_world.cpp b/mlir/lib/RefQuantum/Transforms/hello_world.cpp new file mode 100644 index 0000000000..60d9c8a45d --- /dev/null +++ b/mlir/lib/RefQuantum/Transforms/hello_world.cpp @@ -0,0 +1,51 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#define DEBUG_TYPE "rq-hello-world" + +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "RefQuantum/IR/RefQuantumOps.h" +#include "RefQuantum/Transforms/Patterns.h" + +using namespace mlir; + +namespace catalyst { +namespace ref_quantum { + +#define GEN_PASS_DECL_RQHELLOWORLDPASS +#define GEN_PASS_DEF_RQHELLOWORLDPASS +#include "RefQuantum/Transforms/Passes.h.inc" + +struct RQHelloWorldPass : impl::RQHelloWorldPassBase { + using RQHelloWorldPassBase::RQHelloWorldPassBase; + + void runOnOperation() final + { + Operation *mod = getOperation(); + + RewritePatternSet patterns(&getContext()); + populateRQHelloWorldPatterns(patterns); + + if (failed(applyPatternsGreedily(mod, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +} // namespace ref_quantum +} // namespace catalyst diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt index b4519ab273..4f52d7acdb 100644 --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -20,6 +20,7 @@ set(DIALECT_TESTS_DEPEND set(TEST_SUITES Quantum + RefQuantum Gradient Mitigation Catalyst diff --git a/mlir/test/RefQuantum/DialectTest/UnitTests.mlir b/mlir/test/RefQuantum/DialectTest/UnitTests.mlir new file mode 100644 index 0000000000..ad34a034a0 --- /dev/null +++ b/mlir/test/RefQuantum/DialectTest/UnitTests.mlir @@ -0,0 +1,196 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test basic parsing. +// +// RUN: quantum-opt --split-input-file --verify-diagnostics %s + + +func.func @test_set_state(%arg0 : tensor<2xcomplex>, %w0: i64) { + ref_quantum.set_state(%arg0) %w0 : tensor<2xcomplex>, i64 + return +} + +// ----- + +func.func @test_basis_state(%arg0 : tensor<1xi1>, %w0: i64) { + ref_quantum.set_basis_state(%arg0) %w0 : tensor<1xi1>, i64 + return +} + +// ----- + +func.func @test_custom_op(%w0: i64, %w1: i64, %w2: i64, %w3: i64, %param0: f64, %param1: f64) { + + // Basic + ref_quantum.custom "Hadamard"() %w0 : i64 + ref_quantum.custom "CNOT"() %w0, %w1 : i64, i64 + + // With params + ref_quantum.custom "RX"(%param0) %w0 : i64 + ref_quantum.custom "Rot"(%param0, %param1, %param1) %w0 : i64 + + // With adjoint + ref_quantum.custom "PauliX"() %w0 adj : i64 + ref_quantum.custom "CNOT"() %w0, %w1 adj : i64, i64 + + // With control + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + ref_quantum.custom "PauliZ"() %w0 ctrls (%w1) ctrlvals (%true) : i64 ctrls i64 + ref_quantum.custom "RY"(%param0) %w0 ctrls (%w1) ctrlvals (%true) : i64 ctrls i64 + ref_quantum.custom "SWAP"() %w0, %w1 ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + // With params, control and adjoint altogether + ref_quantum.custom "Rot"(%param0, %param1, %param1) %w0, %w1 adj ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + return +} + + +// ----- + +func.func @test_paulirot_op(%w0: i64, %w1: i64, %w2: i64, %w3: i64, %angle: f64) { + + // Basic + ref_quantum.paulirot ["Z"](%angle) %w0 : i64 + ref_quantum.paulirot ["Z", "X"](%angle) %w0, %w1 : i64, i64 + + // With adjoint + ref_quantum.paulirot ["Z", "X", "I"](%angle) %w0, %w1, %w2 adj : i64, i64, i64 + + // With control + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + ref_quantum.paulirot ["Y", "I"](%angle) %w0, %w1 ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + // With params, control and adjoint altogether + ref_quantum.paulirot ["I", "X"](%angle) %w0, %w1 adj ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + return +} + +// ----- + +func.func @test_global_phase(%w0: i64, %cv: i1, %param: f64) { + + // Basic + ref_quantum.gphase(%param) : f64 + + // With adjoint + ref_quantum.gphase(%param) adj : f64 + + // With control + ref_quantum.gphase(%param) ctrls (%w0) ctrlvals (%cv) : f64 ctrls i64 + + // With control and adjoint + ref_quantum.gphase(%param) adj ctrls (%w0) ctrlvals (%cv) : f64 ctrls i64 + + return +} + +// ----- + +func.func @test_multirz(%w0: i64, %w1: i64, %w2: i64, %w3: i64, %theta: f64) { + + // Basic + ref_quantum.multirz (%theta) %w0 : i64 + ref_quantum.multirz (%theta) %w0, %w1 : i64, i64 + + // With adjoint + ref_quantum.multirz (%theta) %w0, %w1, %w2 adj : i64, i64, i64 + + // With control + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + ref_quantum.multirz (%theta) %w0 ctrls (%w1) ctrlvals (%true) : i64 ctrls i64 + ref_quantum.multirz (%theta) %w0, %w1 ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + // With control and adjoint + ref_quantum.multirz (%theta) %w0, %w1 adj ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + return +} + +// ----- + +func.func @test_pcphase(%w0: i64, %w1: i64, %w2: i64, %w3: i64, %theta: f64, %dim: f64) { + + // Basic + ref_quantum.pcphase (%theta, %dim) %w0 : i64 + ref_quantum.pcphase (%theta, %dim) %w0, %w1, %w2 : i64, i64, i64 + + // With adjoint + ref_quantum.pcphase (%theta, %dim) %w0, %w1 adj : i64, i64 + + // With control + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + ref_quantum.pcphase (%theta, %dim) %w0 ctrls (%w1) ctrlvals (%true) : i64 ctrls i64 + ref_quantum.pcphase (%theta, %dim) %w0, %w1 ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + // With control and adjoint + ref_quantum.pcphase (%theta, %dim) %w0, %w1 adj ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + return +} + +// ----- + +func.func @test_qubit_unitary(%w0: i64, %w1: i64, %w2: i64, %w3: i64) { + + // Basic + %matrix22 = tensor.empty() : tensor<2x2xcomplex> + %matrix44 = tensor.empty() : tensor<4x4xcomplex> + + ref_quantum.unitary (%matrix22 : tensor<2x2xcomplex>) %w0 : i64 + ref_quantum.unitary (%matrix44 : tensor<4x4xcomplex>) %w0, %w1 : i64, i64 + + // With adjoint + ref_quantum.unitary (%matrix22 : tensor<2x2xcomplex>) %w0 adj : i64 + + // With control + %true = llvm.mlir.constant (1 : i1) :i1 + %false = llvm.mlir.constant (0 : i1) :i1 + ref_quantum.unitary (%matrix22 : tensor<2x2xcomplex>) %w0 ctrls (%w1) ctrlvals (%true) : i64 ctrls i64 + ref_quantum.unitary (%matrix44 : tensor<4x4xcomplex>) %w0, %w1 ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + // With control and adjoint + ref_quantum.unitary (%matrix44 : tensor<4x4xcomplex>) %w0, %w1 adj ctrls (%w2, %w3) ctrlvals (%true, %false) : i64, i64 ctrls i64, i64 + + return +} + +// ----- + +func.func @test_namedobs_op(%w0: i64) { + + %ox = ref_quantum.namedobs %w0 [ PauliX] : !quantum.obs + %oy = ref_quantum.namedobs %w0 [ PauliY] : !quantum.obs + %oz = ref_quantum.namedobs %w0 [ PauliZ] : !quantum.obs + %oi = ref_quantum.namedobs %w0 [ Identity] : !quantum.obs + %oh = ref_quantum.namedobs %w0 [ Hadamard] : !quantum.obs + + return +} + +// ----- + +func.func @test_expval_circuit() -> f64 { + %0 = arith.constant 0 : i64 + ref_quantum.custom "Hadamard"() %0 : i64 + %obs = ref_quantum.namedobs %0 [ PauliX] : !quantum.obs + %expval = quantum.expval %obs : f64 + return %expval : f64 +} diff --git a/mlir/test/RefQuantum/DialectTest/VerifierTests.mlir b/mlir/test/RefQuantum/DialectTest/VerifierTests.mlir new file mode 100644 index 0000000000..953c25e346 --- /dev/null +++ b/mlir/test/RefQuantum/DialectTest/VerifierTests.mlir @@ -0,0 +1,175 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Test verifiers. +// +// RUN: quantum-opt --split-input-file --verify-diagnostics %s + +// ----- + +func.func @test_controlled1(%w0: i64, %w1: i64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.custom "PauliZ"() %w0 ctrls (%w1) ctrlvals (%true, %true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_controlled2(%w0: i64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (0) and controlling values (1) must be the same}} + ref_quantum.custom "PauliZ"() %w0 ctrls () ctrlvals (%true) : i64 + return +} + +// ----- + +func.func @test_controlled3(%w0: i64, %w1: i64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (0) must be the same}} + ref_quantum.custom "PauliZ"() %w0 ctrls (%w1) ctrlvals () : i64 ctrls i64 + return +} + +// ----- + +func.func @test_duplicate_wires1(%w0: i64) { + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.custom "CNOT"() %w0, %w0 : i64, i64 + return +} + +// ----- + +func.func @test_duplicate_wires2(%w0: i64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.custom "PauliX"() %w0 ctrls (%w0) ctrlvals (%true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_paulirot_length_mismatch(%w0: i64, %angle: f64) { + // expected-error@+1 {{length of Pauli word (2) and number of wires (1) must be the same}} + ref_quantum.paulirot ["Z", "X"](%angle) %w0 : i64 + return +} + +// ----- + +func.func @test_paulirot_bad_pauli_word(%w0: i64, %angle: f64) { + // expected-error@+1 {{Only "X", "Y", "Z", and "I" are valid Pauli words.}} + ref_quantum.paulirot ["bad"](%angle) %w0 : i64 + return +} + +// ----- + +func.func @test_paulirot_control(%w0: i64, %w1: i64, %angle: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.paulirot ["Z"](%angle) %w0 ctrls (%w1) ctrlvals (%true, %true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_paulirot_duplicate_wires(%w0: i64, %angle: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.paulirot ["Z", "I"](%angle) %w0, %w0 : i64, i64 + return +} + +// ----- + +func.func @test_gphase_control(%w0: i64, %param: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.gphase(%param) ctrls (%w0) ctrlvals (%true, %true) : f64 ctrls i64 + return +} + +// ----- + +func.func @test_multirz_control(%w0: i64, %w1: i64, %theta: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.multirz(%theta) %w0 ctrls (%w1) ctrlvals (%true, %true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_multirz_duplicate_wires(%w0: i64, %theta: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.multirz(%theta) %w0, %w0 : i64, i64 + return +} + +// ----- + +func.func @test_pcphase_control(%w0: i64, %w1: i64, %theta: f64, %dim: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.pcphase(%theta, %dim) %w0 ctrls (%w1) ctrlvals (%true, %true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_pcphase_duplicate_wires(%w0: i64, %theta: f64, %dim: f64) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.pcphase(%theta, %dim) %w0, %w0 : i64, i64 + return +} + +// ----- + +func.func @test_unitary_bad_matrix_shape(%w0: i64, %matrix: tensor<37x42xcomplex>) { + // expected-error@+1 {{The Unitary matrix must be of size 2^(num_wires) * 2^(num_wires)}} + ref_quantum.unitary (%matrix : tensor<37x42xcomplex>) %w0 : i64 + return +} + +// ----- + +func.func @test_unitary_control(%w0: i64, %w1: i64, %matrix: tensor<2x2xcomplex>) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{number of controlling wires in input (1) and controlling values (2) must be the same}} + ref_quantum.unitary(%matrix: tensor<2x2xcomplex>) %w0 ctrls (%w1) ctrlvals (%true, %true) : i64 ctrls i64 + return +} + +// ----- + +func.func @test_unitary_duplicate_wires(%w0: i64, %matrix: tensor<4x4xcomplex>) { + %true = llvm.mlir.constant (1 : i1) :i1 + // expected-error@+1 {{all wires on a quantum gate must be distinct (including controls)}} + ref_quantum.unitary(%matrix: tensor<4x4xcomplex>) %w0, %w0 : i64, i64 + return +} + +// ----- + +func.func @test_namedobs_op_bad_attribute(%w0: i64) { + // expected-error@+2 {{expected catalyst::quantum::NamedObservable to be one of: Identity, PauliX, PauliY, PauliZ, Hadamard}} + // expected-error@+1 {{failed to parse NamedObservableAttr parameter 'value' which is to be a `catalyst::quantum::NamedObservable`}} + %0 = ref_quantum.namedobs %w0 [ bad] : !quantum.obs + return +} diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 39bd44ecf4..c51a720e17 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -30,6 +30,8 @@ set(LIBS catalyst-stablehlo-transforms MLIRQuantum quantum-transforms + MLIRRefQuantum + ref-quantum-transforms MLIRQEC qec-transforms MLIRGradient diff --git a/mlir/tools/quantum-lsp-server/CMakeLists.txt b/mlir/tools/quantum-lsp-server/CMakeLists.txt index 517fc8a1d9..26e693f3b6 100644 --- a/mlir/tools/quantum-lsp-server/CMakeLists.txt +++ b/mlir/tools/quantum-lsp-server/CMakeLists.txt @@ -8,6 +8,7 @@ set(LIBS MLIRRegisterAllDialects MLIRCatalyst MLIRQuantum + MLIRRefQuantum MLIRQEC MLIRGradient MLIRMBQC diff --git a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp index d29335b392..871e354196 100644 --- a/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp +++ b/mlir/tools/quantum-lsp-server/quantum-lsp-server.cpp @@ -25,6 +25,7 @@ #include "QEC/IR/QECDialect.h" #include "Quantum/IR/QuantumDialect.h" #include "RTIO/IR/RTIODialect.h" +#include "RefQuantum/IR/RefQuantumDialect.h" #include "stablehlo/dialect/Register.h" @@ -34,6 +35,7 @@ int main(int argc, char **argv) mlir::registerAllDialects(registry); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/mlir/tools/quantum-opt/CMakeLists.txt b/mlir/tools/quantum-opt/CMakeLists.txt index c3057fe4c5..5376c961a8 100644 --- a/mlir/tools/quantum-opt/CMakeLists.txt +++ b/mlir/tools/quantum-opt/CMakeLists.txt @@ -14,6 +14,8 @@ set(LIBS catalyst-stablehlo-transforms MLIRQuantum quantum-transforms + MLIRRefQuantum + ref-quantum-transforms MLIRQEC qec-transforms MLIRGradient diff --git a/mlir/tools/quantum-opt/quantum-opt.cpp b/mlir/tools/quantum-opt/quantum-opt.cpp index bacdd4ccab..8d370cec4c 100644 --- a/mlir/tools/quantum-opt/quantum-opt.cpp +++ b/mlir/tools/quantum-opt/quantum-opt.cpp @@ -45,6 +45,7 @@ #include "Quantum/IR/QuantumDialect.h" #include "Quantum/Transforms/BufferizableOpInterfaceImpl.h" #include "RTIO/IR/RTIODialect.h" +#include "RefQuantum/IR/RefQuantumDialect.h" #include "RegisterAllPasses.h" namespace test { @@ -66,6 +67,7 @@ int main(int argc, char **argv) mlir::func::registerAllExtensions(registry); registry.insert(); registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index b0ad6d5de1..ff88fe21c2 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -10,4 +10,5 @@ function(add_catalyst_unittest test_dirname) endfunction() add_subdirectory(Example) +add_subdirectory(RefQuantum) add_subdirectory(Utils) diff --git a/mlir/unittests/RefQuantum/CMakeLists.txt b/mlir/unittests/RefQuantum/CMakeLists.txt new file mode 100644 index 0000000000..bb95d17fac --- /dev/null +++ b/mlir/unittests/RefQuantum/CMakeLists.txt @@ -0,0 +1,12 @@ +add_catalyst_unittest(CatalystRefQuantumUnitTests + InterfaceTest.cpp +) + +target_link_libraries(CatalystRefQuantumUnitTests PRIVATE + MLIRArithDialect + MLIRFuncDialect + MLIRRefQuantum + MLIRQuantum + MLIRIR + MLIRParser +) diff --git a/mlir/unittests/RefQuantum/InterfaceTest.cpp b/mlir/unittests/RefQuantum/InterfaceTest.cpp new file mode 100644 index 0000000000..8fb263a6e6 --- /dev/null +++ b/mlir/unittests/RefQuantum/InterfaceTest.cpp @@ -0,0 +1,262 @@ +// Copyright 2025 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest.h" + +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/AsmParser/AsmParserState.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser/Parser.h" +#include "llvm/Support/SourceMgr.h" + +#include "RefQuantum/IR/RefQuantumInterfaces.h" +#include "RefQuantum/IR/RefQuantumOps.h" + +using namespace mlir; + +namespace { + +TEST(InterfaceTests, Getters) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %w1: i64, %param: f64, %bool: i1) { + ref_quantum.custom "Rot"(%param, %param) %w0 adj ctrls (%w1) ctrlvals (%bool) : i64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + std::vector wireOperands = customOp.getWireOperands(); + ASSERT_TRUE(wireOperands.size() == 2 && wireOperands[0] == args[0] && + wireOperands[1] == args[1]); + + ValueRange nonCtrlWireOperands = customOp.getNonCtrlWireOperands(); + ASSERT_TRUE(nonCtrlWireOperands.size() == 1 && nonCtrlWireOperands[0] == args[0]); + + ValueRange ctrlWireOperands = customOp.getCtrlWireOperands(); + ASSERT_TRUE(ctrlWireOperands.size() == 1 && ctrlWireOperands[0] == args[1]); + + ValueRange ctrlValueOperands = customOp.getCtrlValueOperands(); + ASSERT_TRUE(ctrlValueOperands.size() == 1 && ctrlValueOperands[0] == args[3]); + + ASSERT_TRUE(customOp.getAdjointFlag()); + + ValueRange allParams = customOp.getAllParams(); + ASSERT_TRUE(allParams.size() == 2 && allParams[0] == args[2] && allParams[1] == args[2]); + + ASSERT_TRUE(customOp.getParam(0) == args[2]); + ASSERT_TRUE(customOp.getParam(1) == args[2]); +} + +TEST(InterfaceTests, setWireOperands) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %w1: i64, %bool: i1) { + ref_quantum.custom "Rot"() %w0 ctrls (%w1) ctrlvals (%bool) : i64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + customOp.setWireOperands({args[1], args[0]}); + std::vector wireOperands = customOp.getWireOperands(); + ASSERT_TRUE(wireOperands.size() == 2 && wireOperands[0] == args[1] && + wireOperands[1] == args[0]); +} + +TEST(InterfaceTests, setNonCtrlWireOperands) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %w1: i64, %w2: i64, %bool: i1) { + ref_quantum.custom "Rot"() %w0 ctrls (%w1) ctrlvals (%bool) : i64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + customOp.setNonCtrlWireOperands({args[2]}); + ValueRange nonCtrlWireOperands = customOp.getNonCtrlWireOperands(); + ASSERT_TRUE(nonCtrlWireOperands.size() == 1 && nonCtrlWireOperands[0] == args[2]); +} + +TEST(InterfaceTests, setCtrlWireOperands) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %w1: i64, %w2: i64, %bool: i1) { + ref_quantum.custom "Rot"() %w0 ctrls (%w1) ctrlvals (%bool) : i64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + customOp.setCtrlWireOperands({args[2]}); + ValueRange ctrlWireOperands = customOp.getCtrlWireOperands(); + ASSERT_TRUE(ctrlWireOperands.size() == 1 && ctrlWireOperands[0] == args[2]); +} + +TEST(InterfaceTests, setCtrlValueOperands) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %w1: i64, %bool: i1, %other_bool: i1) { + ref_quantum.custom "Rot"() %w0 ctrls (%w1) ctrlvals (%bool) : i64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + customOp.setCtrlValueOperands({args[3]}); + ValueRange ctrlValueOperands = customOp.getCtrlValueOperands(); + ASSERT_TRUE(ctrlValueOperands.size() == 1 && ctrlValueOperands[0] == args[3]); +} + +TEST(InterfaceTests, setAdjointFlag) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64) { + ref_quantum.custom "PauliX"() %w0 : i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::CustomOp customOp = *f.getOps().begin(); + + // Run checks + customOp.setAdjointFlag(true); + ASSERT_TRUE(customOp.getAdjointFlag()); + + customOp.setAdjointFlag(false); + ASSERT_TRUE(!customOp.getAdjointFlag()); +} + +TEST(InterfaceTests, globalPhase) +{ + std::string moduleStr = R"mlir( +func.func @f(%w0: i64, %cv: i1, %param: f64) { + ref_quantum.gphase(%param) adj ctrls (%w0) ctrlvals (%cv) : f64 ctrls i64 + return +} + )mlir"; + + // Parsing boilerplate + DialectRegistry registry; + registry.insert(); + MLIRContext context(registry); + ParserConfig config(&context, /*verifyAfterParse=*/false); + OwningOpRef mod = parseSourceString(moduleStr, config); + + // Parse ops + func::FuncOp f = *(*mod).getOps().begin(); + catalyst::ref_quantum::GlobalPhaseOp gphaseOp = + *f.getOps().begin(); + + Block &bb = f.getCallableRegion()->front(); + auto args = bb.getArguments(); + + // Run checks + ValueRange allParams = gphaseOp.getAllParams(); + ASSERT_TRUE(allParams.size() == 1 && allParams[0] == args[2]); + + ASSERT_TRUE(gphaseOp.getAdjointFlag()); + + ValueRange ctrlWireOperands = gphaseOp.getCtrlWireOperands(); + ASSERT_TRUE(ctrlWireOperands.size() == 1 && ctrlWireOperands[0] == args[0]); + + ValueRange ctrlValueOperands = gphaseOp.getCtrlValueOperands(); + ASSERT_TRUE(ctrlValueOperands.size() == 1 && ctrlValueOperands[0] == args[1]); +} + +} // namespace