From 9ec3bce563e8956849663137beb8d8ce8655be24 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Mon, 15 Dec 2025 14:38:59 +0000 Subject: [PATCH] [amdgpu] Add Python bindings for TDM types --- mlir/include/mlir-c/Dialect/AMDGPU.h | 33 +++++++++++ mlir/lib/Bindings/Python/DialectAMDGPU.cpp | 65 ++++++++++++++++++++++ mlir/lib/CAPI/Dialect/AMDGPU.cpp | 53 ++++++++++++++++++ mlir/python/CMakeLists.txt | 15 +++++ mlir/python/mlir/dialects/amdgpu.py | 1 + mlir/test/python/dialects/amdgpu.py | 26 +++++++++ 6 files changed, 193 insertions(+) create mode 100644 mlir/lib/Bindings/Python/DialectAMDGPU.cpp diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h index 142044f7f3afe..83cfe8f5dd65e 100644 --- a/mlir/include/mlir-c/Dialect/AMDGPU.h +++ b/mlir/include/mlir-c/Dialect/AMDGPU.h @@ -18,6 +18,39 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu); +//===---------------------------------------------------------------------===// +// TDMBaseType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID(); + +MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, + MlirType elementType); + +//===---------------------------------------------------------------------===// +// TDMDescriptorType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID(); + +MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// TDMGatherBaseType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID(); + +MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, + MlirType elementType, + MlirType indexType); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectAMDGPU.cpp b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp new file mode 100644 index 0000000000000..26ffc0e427e41 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectAMDGPU.cpp @@ -0,0 +1,65 @@ +//===--- DialectAMDGPU.cpp - Pybind module for AMDGPU dialect API support -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/AMDGPU.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectAMDGPUSubmodule(const nb::module_ &m) { + auto amdgpuTDMBaseType = + mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType, + mlirAMDGPUTDMBaseTypeGetTypeID); + + amdgpuTDMBaseType.def_classmethod( + "get", + [](const nb::object &cls, MlirType elementType, MlirContext ctx) { + return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType)); + }, + "Gets an instance of TDMBaseType in the same context", nb::arg("cls"), + nb::arg("element_type"), nb::arg("ctx") = nb::none()); + + auto amdgpuTDMDescriptorType = mlir_type_subclass( + m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType, + mlirAMDGPUTDMDescriptorTypeGetTypeID); + + amdgpuTDMDescriptorType.def_classmethod( + "get", + [](const nb::object &cls, MlirContext ctx) { + return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx)); + }, + "Gets an instance of TDMDescriptorType in the same context", + nb::arg("cls"), nb::arg("ctx") = nb::none()); + + auto amdgpuTDMGatherBaseType = mlir_type_subclass( + m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType, + mlirAMDGPUTDMGatherBaseTypeGetTypeID); + + amdgpuTDMGatherBaseType.def_classmethod( + "get", + [](const nb::object &cls, MlirType elementType, MlirType indexType, + MlirContext ctx) { + return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType)); + }, + "Gets an instance of TDMGatherBaseType in the same context", + nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"), + nb::arg("ctx") = nb::none()); +}; + +NB_MODULE(_mlirDialectsAMDGPU, m) { + m.doc() = "MLIR AMDGPU dialect."; + + populateDialectAMDGPUSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp index d877ca2dff375..ddca1cb55edab 100644 --- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp +++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp @@ -12,3 +12,56 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu, mlir::amdgpu::AMDGPUDialect) + +using namespace mlir; +using namespace mlir::amdgpu; + +//===---------------------------------------------------------------------===// +// TDMBaseType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID() { + return wrap(amdgpu::TDMBaseType::getTypeID()); +} + +MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) { + return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType))); +} + +//===---------------------------------------------------------------------===// +// TDMDescriptorType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID() { + return wrap(amdgpu::TDMDescriptorType::getTypeID()); +} + +MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) { + return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// TDMGatherBaseType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID() { + return wrap(amdgpu::TDMGatherBaseType::getTypeID()); +} + +MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType, + MlirType indexType) { + return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType), + unwrap(indexType))); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 2acb6ee6cfda5..6e449e275f782 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -804,6 +804,21 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter MLIRCAPITransformDialectTransforms ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Pybind + MODULE_NAME _mlirDialectsAMDGPU + ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectAMDGPU.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIAMDGPU +) + + # TODO: Figure out how to put this in the test tree. # This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py index 43d905d0c481c..1c4d274bc31af 100644 --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -4,3 +4,4 @@ from ._amdgpu_ops_gen import * from ._amdgpu_enum_gen import * +from .._mlir_libs._mlirDialectsAMDGPU import * diff --git a/mlir/test/python/dialects/amdgpu.py b/mlir/test/python/dialects/amdgpu.py index b479576dac093..10415ec1e842a 100644 --- a/mlir/test/python/dialects/amdgpu.py +++ b/mlir/test/python/dialects/amdgpu.py @@ -5,6 +5,12 @@ from mlir.dialects import amdgpu, func +def run(f): + print("\nTEST:", f.__name__) + f() + return f + + def constructAndPrintInModule(f): print("\nTEST:", f.__name__) with Context(), Location.unknown(): @@ -43,3 +49,23 @@ def testFatRawBufferCastOpParams(): # CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] resetOffset : memref to memref> # CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) : memref to memref> # CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) resetOffset : memref to memref> + + +# CHECK-LABEL: testTDMTypes +@run +def testTDMTypes(): + with Context(): + f32 = F32Type.get() + i32 = IntegerType.get_signless(32) + + # CHECK: !amdgpu.tdm_base + tdm_base = amdgpu.TDMBaseType.get(f32) + print(tdm_base) + + # CHECK: !amdgpu.tdm_descriptor + tdm_descriptor = amdgpu.TDMDescriptorType.get() + print(tdm_descriptor) + + # CHECK: !amdgpu.tdm_gather_base + tdm_gather_base = amdgpu.TDMGatherBaseType.get(f32, i32) + print(tdm_gather_base)