Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions mlir/include/mlir-c/Dialect/AMDGPU.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,39 @@ extern "C" {

MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu);

//===---------------------------------------------------------------------===//
// TDMBaseType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID();

MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
MlirType elementType);

//===---------------------------------------------------------------------===//
// TDMDescriptorType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID();

MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);

//===---------------------------------------------------------------------===//
// TDMGatherBaseType
//===---------------------------------------------------------------------===//

MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type);

MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID();

MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
MlirType elementType,
MlirType indexType);

#ifdef __cplusplus
}
#endif
Expand Down
65 changes: 65 additions & 0 deletions mlir/lib/Bindings/Python/DialectAMDGPU.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//===--- DialectAMDGPU.cpp - Pybind module for AMDGPU dialect API support -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Dialect/AMDGPU.h"
#include "mlir-c/IR.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "nanobind/nanobind.h"

namespace nb = nanobind;
using namespace llvm;
using namespace mlir;
using namespace mlir::python;
using namespace mlir::python::nanobind_adaptors;

static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
auto amdgpuTDMBaseType =
mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
mlirAMDGPUTDMBaseTypeGetTypeID);

amdgpuTDMBaseType.def_classmethod(
"get",
[](const nb::object &cls, MlirType elementType, MlirContext ctx) {
return cls(mlirAMDGPUTDMBaseTypeGet(ctx, elementType));
},
"Gets an instance of TDMBaseType in the same context", nb::arg("cls"),
nb::arg("element_type"), nb::arg("ctx") = nb::none());

auto amdgpuTDMDescriptorType = mlir_type_subclass(
m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
mlirAMDGPUTDMDescriptorTypeGetTypeID);

amdgpuTDMDescriptorType.def_classmethod(
"get",
[](const nb::object &cls, MlirContext ctx) {
return cls(mlirAMDGPUTDMDescriptorTypeGet(ctx));
},
"Gets an instance of TDMDescriptorType in the same context",
nb::arg("cls"), nb::arg("ctx") = nb::none());

auto amdgpuTDMGatherBaseType = mlir_type_subclass(
m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
mlirAMDGPUTDMGatherBaseTypeGetTypeID);

amdgpuTDMGatherBaseType.def_classmethod(
"get",
[](const nb::object &cls, MlirType elementType, MlirType indexType,
MlirContext ctx) {
return cls(mlirAMDGPUTDMGatherBaseTypeGet(ctx, elementType, indexType));
},
"Gets an instance of TDMGatherBaseType in the same context",
nb::arg("cls"), nb::arg("element_type"), nb::arg("index_type"),
nb::arg("ctx") = nb::none());
};

NB_MODULE(_mlirDialectsAMDGPU, m) {
m.doc() = "MLIR AMDGPU dialect.";

populateDialectAMDGPUSubmodule(m);
}
53 changes: 53 additions & 0 deletions mlir/lib/CAPI/Dialect/AMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,56 @@

MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu,
mlir::amdgpu::AMDGPUDialect)

using namespace mlir;
using namespace mlir::amdgpu;

//===---------------------------------------------------------------------===//
// TDMBaseType
//===---------------------------------------------------------------------===//

bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
return isa<amdgpu::TDMBaseType>(unwrap(type));
}

MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID() {
return wrap(amdgpu::TDMBaseType::getTypeID());
}

MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
}

//===---------------------------------------------------------------------===//
// TDMDescriptorType
//===---------------------------------------------------------------------===//

bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) {
return isa<amdgpu::TDMDescriptorType>(unwrap(type));
}

MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID() {
return wrap(amdgpu::TDMDescriptorType::getTypeID());
}

MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
}

//===---------------------------------------------------------------------===//
// TDMGatherBaseType
//===---------------------------------------------------------------------===//

bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
return isa<amdgpu::TDMGatherBaseType>(unwrap(type));
}

MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID() {
return wrap(amdgpu::TDMGatherBaseType::getTypeID());
}

MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
MlirType indexType) {
return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),
unwrap(indexType)));
}
15 changes: 15 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,21 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter
MLIRCAPITransformDialectTransforms
)

declare_mlir_python_extension(MLIRPythonExtension.Dialects.AMDGPU.Pybind
MODULE_NAME _mlirDialectsAMDGPU
ADD_TO_PARENT MLIRPythonSources.Dialects.amdgpu
ROOT_DIR "${PYTHON_SOURCE_DIR}"
PYTHON_BINDINGS_LIBRARY nanobind
SOURCES
DialectAMDGPU.cpp
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
MLIRCAPIIR
MLIRCAPIAMDGPU
)


# TODO: Figure out how to put this in the test tree.
# This should not be included in the main Python extension. However,
# putting it into MLIRPythonTestSources along with the dialect declaration
Expand Down
1 change: 1 addition & 0 deletions mlir/python/mlir/dialects/amdgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from ._amdgpu_ops_gen import *
from ._amdgpu_enum_gen import *
from .._mlir_libs._mlirDialectsAMDGPU import *
26 changes: 26 additions & 0 deletions mlir/test/python/dialects/amdgpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from mlir.dialects import amdgpu, func


def run(f):
print("\nTEST:", f.__name__)
f()
return f


def constructAndPrintInModule(f):
print("\nTEST:", f.__name__)
with Context(), Location.unknown():
Expand Down Expand Up @@ -43,3 +49,23 @@ def testFatRawBufferCastOpParams():
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-NEXT: amdgpu.fat_raw_buffer_cast %[[ARG0]] boundsCheck(false) resetOffset : memref<?x?xf32> to memref<?x?xf32, #amdgpu.address_space<fat_raw_buffer>>


# CHECK-LABEL: testTDMTypes
@run
def testTDMTypes():
with Context():
f32 = F32Type.get()
i32 = IntegerType.get_signless(32)

# CHECK: !amdgpu.tdm_base<f32>
tdm_base = amdgpu.TDMBaseType.get(f32)
print(tdm_base)

# CHECK: !amdgpu.tdm_descriptor
tdm_descriptor = amdgpu.TDMDescriptorType.get()
print(tdm_descriptor)

# CHECK: !amdgpu.tdm_gather_base<f32, i32>
tdm_gather_base = amdgpu.TDMGatherBaseType.get(f32, i32)
print(tdm_gather_base)