diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index c5916307ca..9289f6c985 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -1,4 +1,4 @@ -# Copyright 2022-2023 Xanadu Quantum Technologies Inc. +# Copyright 2022-2026 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -430,6 +430,7 @@ def get_cli_command(self, tmp_infile_name, output_ir_name, module_name, workspac ) return cmd + # pylint: disable=too-many-branches @debug_logger def run_from_ir(self, ir: str, module_name: str, workspace: Directory): """Compile a shared object from a textual IR (MLIR or LLVM). @@ -484,8 +485,15 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): else: out_IR = None + # If target is llvm-ir, only return LLVM IR without linking + if self.options.target == "llvmir": + output = output_ir_name if os.path.exists(output_ir_name) else None + if os.path.exists(tmp_infile_name): + os.remove(tmp_infile_name) + return output, out_IR + output = LinkerDriver.run(output_object_name, options=self.options) - output_object_name = str(pathlib.Path(output).absolute()) + output = str(pathlib.Path(output).absolute()) # Clean up temporary files if os.path.exists(tmp_infile_name): @@ -493,7 +501,7 @@ def run_from_ir(self, ir: str, module_name: str, workspace: Directory): if os.path.exists(output_ir_name): os.remove(output_ir_name) - return output_object_name, out_IR + return output, out_IR def has_xdsl_passes_in_transform_modules(self, mlir_module): """Check if the MLIR module contains xDSL passes in transform dialect. diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index f17051e41b..823047d6ad 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -1,4 +1,4 @@ -# Copyright 2022-2024 Xanadu Quantum Technologies Inc. +# Copyright 2022-2026 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -603,6 +603,10 @@ def __call__(self, *args, **kwargs): requires_promotion = self.jit_compile(args, **kwargs) + # For llvm-ir target, compilation is complete, no execution needed + if self.compile_options.target == "llvmir": + return None + # If we receive tracers as input, dispatch to the JAX integration. if any(isinstance(arg, jax.core.Tracer) for arg in tree_flatten(args)[0]): if self.jaxed_function is None: @@ -621,16 +625,18 @@ def aot_compile(self): self.workspace = self._get_workspace() # TODO: awkward, refactor or redesign the target feature - if self.compile_options.target in ("jaxpr", "mlir", "binary"): + if self.compile_options.target in ("jaxpr", "mlir", "llvmir", "binary"): self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture( self.user_sig or () ) - if self.compile_options.target in ("mlir", "binary"): + if self.compile_options.target in ("mlir", "llvmir", "binary"): self.mlir_module = self.generate_ir() - if self.compile_options.target in ("binary",): + if self.compile_options.target in ("llvmir", "binary"): self.compiled_function, _ = self.compile() + + if self.compile_options.target in ("binary",): self.fn_cache.insert( self.compiled_function, self.user_sig, self.out_treedef, self.workspace ) @@ -656,6 +662,15 @@ def jit_compile(self, args, **kwargs): bool: whether the provided arguments will require promotion to be used with the compiled function """ + if self.compile_options.target == "llvmir": + if self.mlir_module is not None: + return False + self.workspace = self._get_workspace() + self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args, **kwargs) + self.mlir_module = self.generate_ir() + self.compiled_function, _ = self.compile() + return False + cached_fn, requires_promotion = self.fn_cache.lookup(args) if cached_fn is None: @@ -795,6 +810,7 @@ def compile(self): Returns: Tuple[CompiledFunction, str]: the compilation result and LLVMIR + For targets that skip execution, returns (None, llvm_ir) instead. """ # WARNING: assumption is that the first function is the entry point to the compiled program. entry_point_func = self.mlir_module.body.operations[0] @@ -820,6 +836,9 @@ def compile(self): else: shared_object, llvm_ir = self.compiler.run(self.mlir_module, self.workspace) + if self.compile_options.target == "llvmir": + return None, llvm_ir + compiled_fn = CompiledFunction( shared_object, func_name, restype, self.out_type, self.compile_options ) diff --git a/frontend/catalyst/third_party/oqd/__init__.py b/frontend/catalyst/third_party/oqd/__init__.py index 7e7d2bb40b..c509cb991a 100644 --- a/frontend/catalyst/third_party/oqd/__init__.py +++ b/frontend/catalyst/third_party/oqd/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2024-2026 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,6 +15,7 @@ This submodule contains classes for the OQD device and its properties. """ +from .oqd_compile import compile_to_artiq from .oqd_device import OQDDevice, OQDDevicePipeline -__all__ = ["OQDDevice", "OQDDevicePipeline"] +__all__ = ["OQDDevice", "OQDDevicePipeline", "compile_to_artiq"] diff --git a/frontend/catalyst/third_party/oqd/oqd_compile.py b/frontend/catalyst/third_party/oqd/oqd_compile.py new file mode 100644 index 0000000000..9d1dc3a861 --- /dev/null +++ b/frontend/catalyst/third_party/oqd/oqd_compile.py @@ -0,0 +1,229 @@ +# Copyright 2026 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +OQD Compiler utilities for compiling and linking LLVM IR to ARTIQ's binary. +""" + +import os +import subprocess +from pathlib import Path +from typing import Optional + + +def compile_to_artiq(circuit, artiq_config, output_elf_name=None, verbose=True): + """Compile a qjit-compiled circuit to ARTIQ's binary. + + This function takes a circuit compiled with target="llvmir", writes the LLVM IR + to a file, and links it to an ARTIQ's binary. + + Args: + circuit: A QJIT-compiled function (must be compiled with target="llvmir") + artiq_config: Dictionary containing ARTIQ configuration: + - kernel_ld: Path to ARTIQ kernel linker script + - llc_path: (optional) Path to llc compiler + - lld_path: (optional) Path to ld.lld linker + output_elf_name: Name of the output ELF file (default: None, uses circuit function name) + verbose: Whether to print verbose output (default: True) + + Returns: + str: Path to the generated binary file + """ + # Get LLVM IR text and write to file + llvm_ir_text = circuit.qir + circuit_name = getattr(circuit, "__name__", "circuit") + llvm_ir_path = os.path.join(str(circuit.workspace), f"{circuit_name}.ll") + with open(llvm_ir_path, "w", encoding="utf-8") as f: + f.write(llvm_ir_text) + print(f"LLVM IR file written to: {llvm_ir_path}") + + # Link to ARTIQ's binary + if output_elf_name is None: + output_elf_name = f"{circuit_name}.elf" + + # Output ELF file to current working directory if workspace is in /private (temp dir), + # otherwise use workspace directory + workspace_str = str(circuit.workspace) + if "/private" in workspace_str: + output_elf_path = os.path.join(os.getcwd(), output_elf_name) + else: + output_elf_path = os.path.join(workspace_str, output_elf_name) + + link_to_artiq_elf( + llvm_ir_path=llvm_ir_path, + output_elf_path=output_elf_path, + kernel_ld=artiq_config["kernel_ld"], + llc_path=artiq_config.get("llc_path"), + lld_path=artiq_config.get("lld_path"), + verbose=verbose, + ) + + return output_elf_path + + +def _validate_paths(llvm_ir_path: Path, kernel_ld: Path) -> None: + """Validate that required input files exist.""" + if not llvm_ir_path.exists(): + raise FileNotFoundError(f"LLVM IR file not found: {llvm_ir_path}") + if not kernel_ld.exists(): + raise FileNotFoundError(f"ARTIQ kernel.ld not found: {kernel_ld}") + + +def _get_tool_command(tool_path: Optional[str], default_name: str) -> str: + """Get tool command path, validating if custom path is provided.""" + if tool_path is None: + return default_name + tool_path_obj = Path(tool_path) + if not tool_path_obj.exists(): + raise FileNotFoundError(f"{default_name} not found: {tool_path}") + return tool_path + + +def _compile_llvm_to_object( + llvm_ir_path: Path, object_file: Path, llc_cmd: str, verbose: bool +) -> None: + """Compile LLVM IR to object file with llc. + + Args: + llvm_ir_path: Path to LLVM IR file + object_file: Path to object file + llc_cmd: Command to use for llc compiler + verbose: Whether to print verbose output + + Raises: + RuntimeError: If compilation fails + FileNotFoundError: If llc is not found + """ + llc_args = [ + llc_cmd, + "-mtriple=armv7-unknown-linux-gnueabihf", + "-mcpu=cortex-a9", + "-filetype=obj", + "-relocation-model=pic", + "-o", + str(object_file), + str(llvm_ir_path), + ] + + if verbose: + print(f"[ARTIQ] Compiling with external LLC: {' '.join(llc_args)}") + + try: + result = subprocess.run(llc_args, check=True, capture_output=True, text=True) + if verbose and result.stderr: + print(f"[ARTIQ] LLC stderr: {result.stderr}") + except subprocess.CalledProcessError as e: + error_msg = f"External LLC failed with exit code: {e.returncode}" + if e.stderr: + error_msg += f"\n{e.stderr}" + raise RuntimeError(error_msg) from e + except FileNotFoundError as exc: + raise FileNotFoundError( + "llc not found. Please install LLVM or provide path via llc_path argument." + ) from exc + + if not object_file.exists(): + raise RuntimeError(f"Object file was not created: {object_file}") + + +def _link_object_to_elf( + object_file: Path, output_elf_path: Path, kernel_ld: Path, lld_cmd: str, verbose: bool +) -> None: + """Link object file to ELF format with ld.lld. + + Args: + object_file: Path to object file + output_elf_path: Path to output ELF file + kernel_ld: Path to kernel linker script + lld_cmd: Command to use for ld.lld linker + verbose: Whether to print verbose output + + Raises: + RuntimeError: If linking fails + FileNotFoundError: If ld.lld is not found + """ + lld_args = [ + lld_cmd, + "-shared", + "--eh-frame-hdr", + "-m", + "armelf_linux_eabi", + "--target2=rel", + "-T", + str(kernel_ld), + str(object_file), + "-o", + str(output_elf_path), + ] + + if verbose: + print(f"[ARTIQ] Linking ELF: {' '.join(lld_args)}") + + try: + result = subprocess.run(lld_args, check=True, capture_output=True, text=True) + if verbose and result.stderr: + print(f"[ARTIQ] LLD stderr: {result.stderr}") + except subprocess.CalledProcessError as e: + error_msg = f"LLD linking failed with exit code: {e.returncode}" + if e.stderr: + error_msg += f"\n{e.stderr}" + raise RuntimeError(error_msg) from e + except FileNotFoundError as exc: + raise FileNotFoundError( + "ld.lld not found. Please install LLVM LLD or provide path via lld_path argument." + ) from exc + + if not output_elf_path.exists(): + raise RuntimeError(f"ELF file was not created: {output_elf_path}") + + +# pylint: disable=too-many-arguments,too-many-positional-arguments +def link_to_artiq_elf( + llvm_ir_path: str, + output_elf_path: str, + kernel_ld: str, + llc_path: Optional[str] = None, + lld_path: Optional[str] = None, + verbose: bool = False, +) -> str: + """Link LLVM IR to ARTIQ ELF format. + + Args: + llvm_ir_path: Path to the LLVM IR file (.ll) + output_elf_path: Path to output ELF file + kernel_ld: Path to ARTIQ's kernel.ld linker script + llc_path: Path to llc (LLVM compiler). If None, uses "llc" from PATH + lld_path: Path to ld.lld (LLVM linker). If None, uses "ld.lld" from PATH + verbose: If True, print compilation commands + + Returns: + Path to the generated ELF file + """ + llvm_ir_path_obj = Path(llvm_ir_path) + output_elf_path_obj = Path(output_elf_path) + kernel_ld_obj = Path(kernel_ld) + + _validate_paths(llvm_ir_path_obj, kernel_ld_obj) + + llc_cmd = _get_tool_command(llc_path, "llc") + lld_cmd = _get_tool_command(lld_path, "ld.lld") + + object_file = output_elf_path_obj.with_suffix(".o") + _compile_llvm_to_object(llvm_ir_path_obj, object_file, llc_cmd, verbose) + _link_object_to_elf(object_file, output_elf_path_obj, kernel_ld_obj, lld_cmd, verbose) + + if verbose: + print(f"[ARTIQ] Generated ELF: {output_elf_path_obj}") + + return str(output_elf_path_obj) diff --git a/frontend/catalyst/third_party/oqd/oqd_device.py b/frontend/catalyst/third_party/oqd/oqd_device.py index 5984062730..544bda3734 100644 --- a/frontend/catalyst/third_party/oqd/oqd_device.py +++ b/frontend/catalyst/third_party/oqd/oqd_device.py @@ -1,4 +1,4 @@ -# Copyright 2024 Xanadu Quantum Technologies Inc. +# Copyright 2024-2026 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ trapped-ion quantum computer device. """ from typing import Optional +import os import platform from pennylane import CompilePipeline @@ -29,7 +30,17 @@ BACKENDS = ["default"] -def OQDDevicePipeline(device, qubit, gate): +def get_default_artiq_config(): + """Get default ARTIQ cross-compilation configuration""" + # Check environment variable + kernel_ld = os.environ.get("ARTIQ_KERNEL_LD") + if kernel_ld and os.path.exists(kernel_ld): + return {"kernel_ld": kernel_ld} + + return None + + +def OQDDevicePipeline(device, qubit, gate, device_db=None): """ Generate the compilation pipeline for an OQD device. @@ -37,12 +48,50 @@ def OQDDevicePipeline(device, qubit, gate): device (str): the path to the device toml file specifications. qubit (str): the path to the qubit toml file specifications. gate (str): the path to the gate toml file specifications. + device_db (str, optional): the path to the device_db.json file for ARTIQ. + If provided, generates ARTIQ-compatible output. + If None, uses convert-ion-to-llvm for legacy OQD pipeline. Returns: A list of tuples, with each tuple being a stage in the compilation pipeline. When using ``keep_intermediate=True`` from :func:`~.qjit`, the kept stages correspond to the tuples. """ + # Common gates-to-pulses pass + gates_to_pulses_pass = ( + "func.func(gates-to-pulses{" + + "device-toml-loc=" + + device + + " qubit-toml-loc=" + + qubit + + " gate-to-pulse-toml-loc=" + + gate + + "})" + ) + + # Build OQD pipeline based on whether device_db is provided + if device_db is not None: + oqd_passes = [ + "func.func(ions-decomposition)", + gates_to_pulses_pass, + "convert-ion-to-rtio{" + "device_db=" + device_db + "}", + "convert-rtio-event-to-artiq", + ] + llvm_lowering_passes = [ + "llvm-dialect-lowering-stage", + "emit-artiq-runtime", + ] + else: + # Standard LLVM lowering route (legacy OQD pipeline) + oqd_passes = [ + "func.func(ions-decomposition)", + gates_to_pulses_pass, + "convert-ion-to-llvm", + ] + llvm_lowering_passes = [ + "llvm-dialect-lowering-stage", + ] + return [ ( "device-agnostic-pipeline", @@ -55,30 +104,27 @@ def OQDDevicePipeline(device, qubit, gate): ), ( "oqd_pipeline", - [ - "func.func(ions-decomposition)", - "func.func(gates-to-pulses{" - + "device-toml-loc=" - + device - + " qubit-toml-loc=" - + qubit - + " gate-to-pulse-toml-loc=" - + gate - + "})", - "convert-ion-to-llvm", - ], + oqd_passes, ), ( "llvm-dialect-lowering-stage", - [ - "llvm-dialect-lowering-stage", - ], + llvm_lowering_passes, ), ] class OQDDevice(Device): - """The OQD device allows access to the hardware devices from OQD using Catalyst.""" + """The OQD device allows access to the hardware devices from OQD using Catalyst. + + Args: + wires: The number of wires/qubits. + backend: Backend name (default: "default"). + openapl_file_name: Output file name for OpenAPL. + artiq_config: ARTIQ cross-compilation configuration dict with keys: + - kernel_ld: Path to ARTIQ's kernel.ld linker script + - llc_path: Path to llc + - lld_path: Path to ld.lld + """ config_filepath = get_lib_path("oqd_runtime", "OQD_LIB_DIR") + "/backend" + "/oqd.toml" @@ -96,7 +142,12 @@ def get_c_interface(): return "oqd", lib_path def __init__( - self, wires, backend="default", openapl_file_name="__openapl__output.json", **kwargs + self, + wires, + backend="default", + openapl_file_name="__openapl__output.json", + artiq_config=None, + **kwargs, ): self._backend = backend self._openapl_file_name = openapl_file_name @@ -106,6 +157,16 @@ def __init__( "openapl_file_name": self._openapl_file_name, } + if artiq_config is not None: + self._artiq_config = artiq_config + else: + self._artiq_config = get_default_artiq_config() + + @property + def artiq_config(self): + """ARTIQ cross-compilation configuration.""" + return self._artiq_config + @property def openapl_file_name(self): """The OpenAPL output file name.""" diff --git a/mlir/include/RTIO/Transforms/Passes.td b/mlir/include/RTIO/Transforms/Passes.td index ea8ad1151b..6597caf7a9 100644 --- a/mlir/include/RTIO/Transforms/Passes.td +++ b/mlir/include/RTIO/Transforms/Passes.td @@ -1,4 +1,4 @@ -// Copyright 2025 Xanadu Quantum Technologies Inc. +// Copyright 2025-2026 Xanadu Quantum Technologies Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -39,6 +39,19 @@ def RTIOEventToARTIQPass : Pass<"convert-rtio-event-to-artiq", "mlir::ModuleOp"> ]; } -#endif // RTIO_PASSES +def RTIOEmitARTIQRuntimePass : Pass<"emit-artiq-runtime", "mlir::ModuleOp"> { + let summary = "Emit ARTIQ runtime wrapper (__modinit__) that calls the kernel function"; + let description = [{ + This pass creates the ARTIQ entry point structure: + - __modinit__ (entry point) -> calls kernel function + This pass injects the necessary ARTIQ runtime boilerplate directly into the MLIR module + to allow Catalyst-generated kernels can be loaded and executed by ARTIQ + }]; + let dependentDialects = [ + "mlir::LLVM::LLVMDialect", + "mlir::func::FuncDialect" + ]; +} +#endif // RTIO_PASSES diff --git a/mlir/lib/RTIO/Transforms/CMakeLists.txt b/mlir/lib/RTIO/Transforms/CMakeLists.txt index e2ee789bcf..4c806b5674 100644 --- a/mlir/lib/RTIO/Transforms/CMakeLists.txt +++ b/mlir/lib/RTIO/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ set(LIBRARY_NAME rtio-transforms) file(GLOB SRC RTIOEventToARTIQ.cpp RTIOEventToARTIQPatterns.cpp + RTIOEmitARTIQRuntime.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/lib/RTIO/Transforms/RTIOEmitARTIQRuntime.cpp b/mlir/lib/RTIO/Transforms/RTIOEmitARTIQRuntime.cpp new file mode 100644 index 0000000000..076b68f006 --- /dev/null +++ b/mlir/lib/RTIO/Transforms/RTIOEmitARTIQRuntime.cpp @@ -0,0 +1,164 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/// This pass creates the ARTIQ entry point structure that allows Catalyst-generated +/// kernels to be loaded and executed by ARTIQ +/// +/// The pass transforms: +/// @__kernel__(ptr, ptr, i64) +/// Into: +/// @__modinit__(ptr) -> calls @__kernel__(ptr, nullptr, 0) + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" + +#include "ARTIQRuntimeBuilder.hpp" +#include "RTIO/Transforms/Passes.h" + +using namespace mlir; + +namespace catalyst { +namespace rtio { + +#define GEN_PASS_DEF_RTIOEMITARTIQRUNTIMEPASS +#include "RTIO/Transforms/Passes.h.inc" + +namespace { + +//===----------------------------------------------------------------------===// +// ARTIQ Runtime Constants +//===----------------------------------------------------------------------===// + +namespace ARTIQRuntime { +constexpr StringLiteral modinit = "__modinit__"; +constexpr StringLiteral artiqPersonality = "__artiq_personality"; +} // namespace ARTIQRuntime + +//===----------------------------------------------------------------------===// +// Pass Implementation +//===----------------------------------------------------------------------===// + +struct RTIOEmitARTIQRuntimePass + : public impl::RTIOEmitARTIQRuntimePassBase { + using RTIOEmitARTIQRuntimePassBase::RTIOEmitARTIQRuntimePassBase; + + void runOnOperation() override + { + ModuleOp module = getOperation(); + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + // Check if __modinit__ already exists (the entry point of ARTIQ device) + if (module.lookupSymbol(ARTIQRuntime::modinit)) { + return; + } + + // Find the kernel function (could be LLVM func or func.func) + LLVM::LLVMFuncOp llvmKernelFunc = + module.lookupSymbol(ARTIQFuncNames::kernel); + + if (!llvmKernelFunc) { + module.emitError("Cannot find kernel function"); + return signalPassFailure(); + } + + // Create ARTIQ runtime wrapper + if (failed(emitARTIQRuntimeForLLVMFunc(module, builder, llvmKernelFunc))) { + return signalPassFailure(); + } + } + + private: + /// Emit ARTIQ runtime wrapper for LLVM dialect kernel function + LogicalResult emitARTIQRuntimeForLLVMFunc(ModuleOp module, OpBuilder &builder, + LLVM::LLVMFuncOp kernelFunc) + { + MLIRContext *ctx = builder.getContext(); + Location loc = module.getLoc(); + + // Types + Type voidTy = LLVM::LLVMVoidType::get(ctx); + Type ptrTy = LLVM::LLVMPointerType::get(ctx); + Type i32Ty = IntegerType::get(ctx, 32); + Type i64Ty = IntegerType::get(ctx, 64); + + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(module.getBody()); + + // Declare __artiq_personality (exception handling) + declareARTIQPersonality(module, builder, loc); + + // Create entry function: void @__modinit__(ptr %self) + auto modinitTy = LLVM::LLVMFunctionType::get(voidTy, {ptrTy}); + auto modinitFunc = builder.create(loc, ARTIQRuntime::modinit, modinitTy); + modinitFunc.setLinkage(LLVM::Linkage::External); + + // Set personality function for exception handling + modinitFunc.setPersonalityAttr(FlatSymbolRefAttr::get(ctx, ARTIQRuntime::artiqPersonality)); + + // Create function body + Block *entry = modinitFunc.addEntryBlock(builder); + builder.setInsertionPointToStart(entry); + + // Get the actual kernel function type and create matching arguments + auto kernelFuncTy = kernelFunc.getFunctionType(); + SmallVector callArgs; + + for (Type argTy : kernelFuncTy.getParams()) { + // Create zero/null values for each argument type + if (isa(argTy)) { + callArgs.push_back(builder.create(loc, ptrTy)); + } + else if (argTy.isInteger(64)) { + callArgs.push_back( + builder.create(loc, i64Ty, builder.getI64IntegerAttr(0))); + } + else if (argTy.isInteger(32)) { + callArgs.push_back( + builder.create(loc, i32Ty, builder.getI32IntegerAttr(0))); + } + else { + // For other types, use null pointer as fallback + callArgs.push_back(builder.create(loc, ptrTy)); + } + } + + auto callOp = builder.create(loc, kernelFunc, callArgs); + callOp.setTailCallKind(LLVM::TailCallKind::Tail); + + builder.create(loc, ValueRange{}); + + return success(); + } + + /// Declare __artiq_personality function + void declareARTIQPersonality(ModuleOp module, OpBuilder &builder, Location loc) + { + if (module.lookupSymbol(ARTIQRuntime::artiqPersonality)) { + return; + } + + Type i32Ty = IntegerType::get(builder.getContext(), 32); + auto personalityTy = LLVM::LLVMFunctionType::get(i32Ty, {}, /*isVarArg=*/true); + builder.create(loc, ARTIQRuntime::artiqPersonality, personalityTy, + LLVM::Linkage::External); + } +}; + +} // namespace +} // namespace rtio +} // namespace catalyst diff --git a/mlir/test/RTIO/RTIOEmitARTIQRuntime.mlir b/mlir/test/RTIO/RTIOEmitARTIQRuntime.mlir new file mode 100644 index 0000000000..7d2e8dcca7 --- /dev/null +++ b/mlir/test/RTIO/RTIOEmitARTIQRuntime.mlir @@ -0,0 +1,49 @@ +// Copyright 2026 Xanadu Quantum Technologies Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --emit-artiq-runtime --split-input-file | FileCheck %s + +// Test basic ARTIQ runtime wrapper generation +// CHECK-LABEL: module @test_basic +module @test_basic { + // CHECK: llvm.func @__artiq_personality(...) -> i32 + // CHECK: llvm.func @__modinit__(%arg0: !llvm.ptr) attributes {personality = @__artiq_personality} + // CHECK-SAME: { + // CHECK: llvm.call tail @__kernel__() : () -> () + // CHECK: llvm.return + // CHECK: } + llvm.func @__kernel__() { + llvm.return + } +} + +// ----- + +// Test with kernel function that has arguments +// CHECK-LABEL: module @test_kernel_with_args +module @test_kernel_with_args { + // CHECK: llvm.func @__artiq_personality(...) -> i32 + // CHECK: llvm.func @__modinit__(%arg0: !llvm.ptr) attributes {personality = @__artiq_personality} + // CHECK-SAME: { + // CHECK: %[[ZERO_PTR0:.*]] = llvm.mlir.zero : !llvm.ptr + // CHECK: %[[ZERO_PTR1:.*]] = llvm.mlir.zero : !llvm.ptr + // CHECK: %[[ZERO_I64:.*]] = llvm.mlir.constant(0 : i64) : i64 + // CHECK: llvm.call tail @__kernel__(%[[ZERO_PTR0]], %[[ZERO_PTR1]], %[[ZERO_I64]]) + // CHECK: llvm.return + // CHECK: } + llvm.func @__kernel__(%arg0: !llvm.ptr, %arg1: !llvm.ptr, %arg2: i64) { + llvm.return + } +} + diff --git a/mlir/tools/catalyst-cli/CMakeLists.txt b/mlir/tools/catalyst-cli/CMakeLists.txt index 39bd44ecf4..b83e628dca 100644 --- a/mlir/tools/catalyst-cli/CMakeLists.txt +++ b/mlir/tools/catalyst-cli/CMakeLists.txt @@ -43,6 +43,7 @@ set(LIBS MLIRIon ion-transforms MLIRRTIO + rtio-transforms MLIRCatalystTest ${ENZYME_LIB} CatalystCompilerDriver