Skip to content

Commit 9e638c9

Browse files
committed
hook getTypeIDFunction
1 parent c660714 commit 9e638c9

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

mlir/include/mlir-c/Dialect/AMDGPU.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu);
2424

2525
MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type);
2626

27+
MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID();
28+
2729
MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
2830
MlirType elementType);
2931

@@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx,
3335

3436
MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type);
3537

38+
MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID();
39+
3640
MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
3741

3842
//===---------------------------------------------------------------------===//
@@ -41,6 +45,8 @@ MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx);
4145

4246
MLIR_CAPI_EXPORTED bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type);
4347

48+
MLIR_CAPI_EXPORTED MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID();
49+
4450
MLIR_CAPI_EXPORTED MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx,
4551
MlirType elementType,
4652
MlirType indexType);

mlir/lib/Bindings/Python/DialectAMDGPU.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@ using namespace mlir::python::nanobind_adaptors;
2020

2121
static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
2222
auto amdgpuTDMBaseType =
23-
mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType);
23+
mlir_type_subclass(m, "TDMBaseType", mlirTypeIsAAMDGPUTDMBaseType,
24+
mlirAMDGPUTDMBaseTypeGetTypeID);
2425

2526
amdgpuTDMBaseType.def_classmethod(
2627
"get",
@@ -31,7 +32,8 @@ static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
3132
nb::arg("element_type"), nb::arg("ctx") = nb::none());
3233

3334
auto amdgpuTDMDescriptorType = mlir_type_subclass(
34-
m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType);
35+
m, "TDMDescriptorType", mlirTypeIsAAMDGPUTDMDescriptorType,
36+
mlirAMDGPUTDMDescriptorTypeGetTypeID);
3537

3638
amdgpuTDMDescriptorType.def_classmethod(
3739
"get",
@@ -42,7 +44,8 @@ static void populateDialectAMDGPUSubmodule(const nb::module_ &m) {
4244
nb::arg("cls"), nb::arg("ctx") = nb::none());
4345

4446
auto amdgpuTDMGatherBaseType = mlir_type_subclass(
45-
m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType);
47+
m, "TDMGatherBaseType", mlirTypeIsAAMDGPUTDMGatherBaseType,
48+
mlirAMDGPUTDMGatherBaseTypeGetTypeID);
4649

4750
amdgpuTDMGatherBaseType.def_classmethod(
4851
"get",

mlir/lib/CAPI/Dialect/AMDGPU.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ bool mlirTypeIsAAMDGPUTDMBaseType(MlirType type) {
2424
return isa<amdgpu::TDMBaseType>(unwrap(type));
2525
}
2626

27+
MlirTypeID mlirAMDGPUTDMBaseTypeGetTypeID() {
28+
return wrap(amdgpu::TDMBaseType::getTypeID());
29+
}
30+
2731
MlirType mlirAMDGPUTDMBaseTypeGet(MlirContext ctx, MlirType elementType) {
2832
return wrap(amdgpu::TDMBaseType::get(unwrap(ctx), unwrap(elementType)));
2933
}
@@ -36,6 +40,10 @@ bool mlirTypeIsAAMDGPUTDMDescriptorType(MlirType type) {
3640
return isa<amdgpu::TDMDescriptorType>(unwrap(type));
3741
}
3842

43+
MlirTypeID mlirAMDGPUTDMDescriptorTypeGetTypeID() {
44+
return wrap(amdgpu::TDMDescriptorType::getTypeID());
45+
}
46+
3947
MlirType mlirAMDGPUTDMDescriptorTypeGet(MlirContext ctx) {
4048
return wrap(amdgpu::TDMDescriptorType::get(unwrap(ctx)));
4149
}
@@ -48,6 +56,10 @@ bool mlirTypeIsAAMDGPUTDMGatherBaseType(MlirType type) {
4856
return isa<amdgpu::TDMGatherBaseType>(unwrap(type));
4957
}
5058

59+
MlirTypeID mlirAMDGPUTDMGatherBaseTypeGetTypeID() {
60+
return wrap(amdgpu::TDMGatherBaseType::getTypeID());
61+
}
62+
5163
MlirType mlirAMDGPUTDMGatherBaseTypeGet(MlirContext ctx, MlirType elementType,
5264
MlirType indexType) {
5365
return wrap(amdgpu::TDMGatherBaseType::get(unwrap(ctx), unwrap(elementType),

0 commit comments

Comments
 (0)