From e77750c96969120de29c66698a1b1f4b79ca00eb Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Fri, 17 Mar 2023 17:54:22 +0000 Subject: [PATCH 1/6] [SYCL-MLIR] Opaque pointer Polygeist to LLVM Signed-off-by: Lukas Sommer --- .../PolygeistToLLVM/PolygeistToLLVM.cpp | 80 +++--- .../Polygeist/Transforms/BareMemRefToLLVM.cpp | 78 +++--- .../test/polygeist-opt/bareptrlowering.mlir | 265 +++++++++--------- 3 files changed, 199 insertions(+), 224 deletions(-) diff --git a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp index 5d51bc6de6599..8e876721f6aa4 100644 --- a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp +++ b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp @@ -174,11 +174,12 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { // Handle the general (non-SYCL) case first. if (convViewElemType == - prev.getType().cast().getElementType()) { + transformed.getSource().getType().cast().getElementType()) { auto memRefDesc = createMemRefDescriptor( loc, viewMemRefType, targetMemRef.allocatedPtr(rewriter, loc), - rewriter.create(loc, prev.getType(), prev, idxs), sizes, - strides, rewriter); + rewriter.create(loc, prev.getType(), convViewElemType, + prev, idxs), + sizes, strides, rewriter); rewriter.replaceOp(subViewOp, {memRefDesc}); return success(); @@ -187,6 +188,7 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { "Expecting struct type"); // SYCL case + // TODO(Lukas): Opaque pointer handling for SYCL case assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && "Expecting the input and output MemRef ranks to be the same"); @@ -200,8 +202,9 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { // polygeist.subindex operation should be a memref of the element type of // the struct. 
auto elemPtrTy = LLVM::LLVMPointerType::get( - convViewElemType, viewMemRefType.getMemorySpaceAsInt()); - auto gep = rewriter.create(loc, elemPtrTy, prev, indices); + convViewElemType.getContext(), viewMemRefType.getMemorySpaceAsInt()); + auto gep = rewriter.create(loc, elemPtrTy, convViewElemType, + prev, indices); auto memRefDesc = createMemRefDescriptor(loc, viewMemRefType, gep, gep, sizes, strides, rewriter); LLVM_DEBUG(llvm::dbgs() << "SubIndexOpLowering: gep: " << *gep << "\n"); @@ -256,15 +259,16 @@ struct SubIndexBarePtrOpLowering : public BaseSubIndexOpLowering { Type resType = getTypeConverter()->convertType(subViewOp.getType()); // Handle the general (non-SYCL) case first. - if (convViewElemType == - target.getType().cast().getElementType()) { - rewriter.replaceOpWithNewOp(subViewOp, resType, target, idx); + if (convViewElemType == convSourceElemType) { + rewriter.replaceOpWithNewOp(subViewOp, resType, + convViewElemType, target, idx); return success(); } assert(convSourceElemType.isa() && "Expecting struct type"); // SYCL case + // TODO(Lukas): Opaque pointer handling for SYCL case assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && "Expecting the input and output MemRef ranks to be the same"); @@ -278,8 +282,8 @@ struct SubIndexBarePtrOpLowering : public BaseSubIndexOpLowering { // polygeist.subindex operation should be a memref of the element type of // the struct. 
- rewriter.replaceOpWithNewOp(subViewOp, resType, target, - indices); + rewriter.replaceOpWithNewOp(subViewOp, resType, + convViewElemType, target, indices); return success(); } @@ -303,12 +307,14 @@ struct Memref2PointerOpLowering Value baseOffset = targetMemRef.offset(rewriter, loc); Value ptr = targetMemRef.alignedPtr(rewriter, loc); Value idxs[] = {baseOffset}; - ptr = rewriter.create(loc, ptr.getType(), ptr, idxs); + ptr = rewriter.create( + loc, ptr.getType(), + transformed.getSource().getType().cast().getElementType(), + ptr, idxs); assert(ptr.getType().cast().getAddressSpace() == op.getType().getAddressSpace() && "Expecting Memref2PointerOp source and result types to have the " "same address space"); - ptr = rewriter.create(loc, op.getType(), ptr); rewriter.replaceOp(op, {ptr}); return success(); @@ -335,8 +341,7 @@ struct Pointer2MemrefOpLowering op.getType().cast().getMemorySpaceAsInt() && "Expecting Pointer2MemrefOp source and result types to have the " "same address space"); - auto ptr = rewriter.create( - op.getLoc(), descr.getElementPtrType(), adaptor.getSource()); + auto ptr = adaptor.getSource(); // Extract all strides and offsets and verify they are static. int64_t offset; @@ -396,6 +401,7 @@ struct BareMemref2PointerOpLowering return failure(); const auto target = transformed.getSource(); + // TODO(Lukas): Can we eliminate this bitcast? rewriter.replaceOpWithNewOp(op, op.getType(), target); return success(); @@ -416,6 +422,7 @@ struct BarePointer2MemrefOpLowering const auto convertedType = getTypeConverter()->convertType(op.getType()); if (!convertedType) return failure(); + // TODO(Lukas): Can we eliminate this bitcast? 
rewriter.replaceOpWithNewOp(op, convertedType, adaptor.getSource()); return success(); @@ -579,14 +586,11 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { } auto voidTy = LLVM::LLVMVoidType::get(ctx); - auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + auto i8Ptr = LLVM::LLVMPointerType::get(ctx); auto resumeOp = moduleBuilder.create( fname, LLVM::LLVMFunctionType::get( - voidTy, {i8Ptr, - LLVM::LLVMPointerType::get( - LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})), - streamTy})); + voidTy, {i8Ptr, LLVM::LLVMPointerType::get(ctx), streamTy})); resumeOp.setPrivate(); return resumeOp; @@ -604,7 +608,7 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { Location loc = execute.getLoc(); auto voidTy = LLVM::LLVMVoidType::get(ctx); - Type voidPtr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + Type voidPtr = LLVM::LLVMPointerType::get(ctx); // Make sure that all constants will be inside the outlined async function // to reduce the number of function arguments. 
@@ -668,11 +672,7 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { } else if (functionInputs.size() == 1 && converter->convertType(functionInputs[0].getType()) .isa()) { - valueMapping.map( - functionInputs[0], - rewriter.create( - execute.getLoc(), - converter->convertType(functionInputs[0].getType()), arg)); + valueMapping.map(functionInputs[0], arg); } else if (functionInputs.size() == 1 && converter->convertType(functionInputs[0].getType()) .isa()) { @@ -685,20 +685,18 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { SmallVector types; for (auto v : functionInputs) types.push_back(converter->convertType(v.getType())); - auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); - auto alloc = rewriter.create( - execute.getLoc(), LLVM::LLVMPointerType::get(ST), arg); + for (auto idx : llvm::enumerate(functionInputs)) { mlir::Value idxs[] = { rewriter.create(loc, 0, 32), rewriter.create(loc, idx.index(), 32), }; + auto nextTy = types[idx.index()]; Value next = rewriter.create( - loc, LLVM::LLVMPointerType::get(idx.value().getType()), alloc, - idxs); + loc, LLVM::LLVMPointerType::get(ctx), nextTy, arg, idxs); valueMapping.map(idx.value(), - rewriter.create(loc, next)); + rewriter.create(loc, nextTy, next)); } auto freef = getFreeFn(*getTypeConverter(), module); Value args[] = {arg}; @@ -729,8 +727,7 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { } else if (crossing.size() == 1 && converter->convertType(crossing[0].getType()) .isa()) { - vals.push_back(rewriter.create(execute.getLoc(), - voidPtr, crossing[0])); + vals.push_back(crossing[0]); } else if (crossing.size() == 1 && converter->convertType(crossing[0].getType()) .isa()) { @@ -748,10 +745,8 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { loc, rewriter.getI64Type(), rewriter.create(loc, rewriter.getIndexType(), ST))}; - mlir::Value alloc = rewriter.create( - loc, LLVM::LLVMPointerType::get(ST), - rewriter.create(loc, mallocf, args) - .getResult()); + mlir::Value 
alloc = + rewriter.create(loc, mallocf, args).getResult(); rewriter.setInsertionPoint(execute); for (auto idx : llvm::enumerate(crossing)) { @@ -759,13 +754,12 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { rewriter.create(loc, 0, 32), rewriter.create(loc, idx.index(), 32), }; - Value next = rewriter.create( - loc, LLVM::LLVMPointerType::get(idx.value().getType()), alloc, - idxs); + Value next = + rewriter.create(loc, LLVM::LLVMPointerType::get(ctx), + idx.value().getType(), alloc, idxs); rewriter.create(loc, idx.value(), next); } - vals.push_back( - rewriter.create(execute.getLoc(), voidPtr, alloc)); + vals.push_back(alloc); } vals.push_back( rewriter.create(execute.getLoc(), func)); @@ -922,6 +916,7 @@ struct ConvertPolygeistToLLVMPass LowerToLLVMOptions options(&getContext(), dataLayoutAnalysis.getAtOrAbove(m)); options.useBarePtrCallConv = true; + options.useOpaquePointers = true; if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); @@ -931,7 +926,6 @@ struct ConvertPolygeistToLLVMPass LLVMTypeConverter converter(&getContext(), options, &dataLayoutAnalysis); RewritePatternSet patterns(&getContext()); - // Keep these at the top; these should be run before the rest of // function conversion patterns. populateReturnOpTypeConversionPattern(patterns, converter); diff --git a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp index fc78e7c35b32b..274a5f4c0b512 100644 --- a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp +++ b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp @@ -34,44 +34,31 @@ struct GetGlobalMemrefOpLowering if (!canBeLoweredToBarePtr(memrefTy)) return failure(); - const auto arrayTy = - convertGlobalMemrefTypeToLLVM(memrefTy, *typeConverter); - if (!arrayTy) + // LLVM type for a global memref will be a multi-dimension array. 
For + // declarations or uninitialized global memrefs, we can potentially flatten + // this to a 1D array. However, for memref.global's with an initial value, + // we do not intend to flatten the ElementsAttribute when going from std -> + // LLVM dialect, so the LLVM type needs to be a multi-dimension array. + const auto convElemType = + typeConverter->convertType(memrefTy.getElementType()); + if (!convElemType) return failure(); const auto addressOf = static_cast(rewriter.create( getGlobalOp.getLoc(), - LLVM::LLVMPointerType::get(arrayTy, memrefTy.getMemorySpaceAsInt()), + LLVM::LLVMPointerType::get(memrefTy.getContext(), + memrefTy.getMemorySpaceAsInt()), adaptor.getName())); // Get the address of the first element in the array by creating a GEP with // the address of the GV as the base, and (rank + 1) number of 0 indices. rewriter.replaceOpWithNewOp( - getGlobalOp, typeConverter->convertType(memrefTy), addressOf, - SmallVector(memrefTy.getRank() + 1, 0), + getGlobalOp, typeConverter->convertType(memrefTy), convElemType, + addressOf, SmallVector(memrefTy.getRank() + 1, 0), /* inbounds */ true); return success(); } - -private: - /// Returns the LLVM type of the global variable given the memref type `type`. - static Type convertGlobalMemrefTypeToLLVM(MemRefType type, - TypeConverter &typeConverter) { - // LLVM type for a global memref will be a multi-dimension array. For - // declarations or uninitialized global memrefs, we can potentially flatten - // this to a 1D array. However, for memref.global's with an initial value, - // we do not intend to flatten the ElementsAttribute when going from std -> - // LLVM dialect, so the LLVM type needs to me a multi-dimension array. 
- const auto convElemTy = typeConverter.convertType(type.getElementType()); - if (!convElemTy) - return {}; - // Shape has the outermost dim at index 0, so need to walk it backwards - const auto shape = type.getShape(); - return std::accumulate( - shape.rbegin(), shape.rend(), convElemTy, - [](auto ty, auto dim) { return LLVM::LLVMArrayType::get(ty, dim); }); - } }; /// Simply replace by the source, as we don't care about the shape. @@ -108,17 +95,19 @@ struct AllocaMemrefOpLowering const auto ptrType = typeConverter->convertType(allocaOp.getType()); if (!ptrType) return failure(); + const auto convElemType = + typeConverter->convertType(memrefType.getElementType()); const auto loc = allocaOp.getLoc(); auto nullPtr = rewriter.create(loc, ptrType); auto gepPtr = rewriter.create( - loc, ptrType, nullPtr, - createIndexConstant(rewriter, loc, - allocaOp.getType().getNumElements())); + loc, ptrType, convElemType, nullPtr, + createIndexConstant(rewriter, loc, memrefType.getNumElements())); auto sizeBytes = rewriter.create(loc, getIndexType(), gepPtr); rewriter.replaceOpWithNewOp( - allocaOp, ptrType, sizeBytes, allocaOp.getAlignment().value_or(0)); + allocaOp, ptrType, convElemType, sizeBytes, + allocaOp.getAlignment().value_or(0)); return success(); } }; @@ -166,10 +155,10 @@ struct AllocMemrefOpLowering : public ConvertOpToLLVMPattern { const auto allocFuncOp = getAllocFn(*getTypeConverter(), module, getIndexType()); - const auto results = - rewriter.create(loc, allocFuncOp, sizeBytes).getResults(); auto alignedPtr = static_cast( - rewriter.create(loc, elementPtrType, results)); + rewriter.create(loc, allocFuncOp, sizeBytes) + .getResults() + .front()); if (alignment) { // Compute the aligned pointer. const auto allocatedInt = static_cast( @@ -198,12 +187,8 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { // Insert the `free` declaration if it is not already present. 
const auto freeFunc = getFreeFn(*getTypeConverter(), deallocOp->getParentOfType()); - const auto casted = - rewriter - .create(deallocOp.getLoc(), getVoidPtrType(), - adaptor.getMemref()) - .getRes(); - rewriter.replaceOpWithNewOp(deallocOp, freeFunc, casted); + rewriter.replaceOpWithNewOp(deallocOp, freeFunc, + adaptor.getMemref()); return success(); } }; @@ -275,9 +260,11 @@ struct MemAccessLowering : public ConvertToLLVMPattern { const auto elementPtrType = getTypeConverter()->convertType(type); if (!elementPtrType) return {}; - return index - ? rewriter.create(loc, elementPtrType, base, index) - : base; + const auto convElemType = + getTypeConverter()->convertType(type.getElementType()); + return index ? rewriter.create(loc, elementPtrType, + convElemType, base, index) + : base; } }; @@ -300,7 +287,8 @@ struct LoadMemRefOpLowering : public MemAccessLowering { adaptor.getIndices(), rewriter); if (!DataPtr) return failure(); - rewriter.replaceOpWithNewOp(op, DataPtr); + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(loadOp.getType()), DataPtr); return success(); } }; @@ -345,9 +333,7 @@ void mlir::polygeist::populateBareMemRefToLLVMConversionPatterns( converter.addConversion([&](MemRefType type) -> Optional { if (!canBeLoweredToBarePtr(type)) return std::nullopt; - const auto elemType = converter.convertType(type.getElementType()); - if (!elemType) - return Type{}; - return LLVM::LLVMPointerType::get(elemType, type.getMemorySpaceAsInt()); + return LLVM::LLVMPointerType::get(type.getContext(), + type.getMemorySpaceAsInt()); }); } diff --git a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir index a09d8f7660a03..19385545f4d96 100644 --- a/polygeist/test/polygeist-opt/bareptrlowering.mlir +++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir @@ -1,36 +1,36 @@ // RUN: polygeist-opt %s --convert-polygeist-to-llvm --split-input-file | FileCheck %s -// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> 
!llvm.ptr +// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr func.func private @ptr_ret_static(%arg0: i64) -> memref<4xi64> // ----- -// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr +// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr func.func private @ptr_ret_dynamic(%arg0: i64) -> memref // ----- -// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr +// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr func.func private @ptr_ret_nd_static(%arg0: i64) -> memref<4x4xi64> // ----- -// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr +// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr func.func private @ptr_ret_nd_dynamic(%arg0: i64) -> memref // ----- -// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr, !llvm.ptr) -> !llvm.ptr +// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr, !llvm.ptr) -> !llvm.ptr func.func private @ptr_args_and_ret(%arg0: memref<1xi64>, %arg1: memref) -> memref // ----- -// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr {llvm.byval = i64}, !llvm.ptr {llvm.byval = i64}) -> !llvm.ptr +// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr {llvm.byval = i64}, !llvm.ptr {llvm.byval = i64}) -> !llvm.ptr func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval = i64}, %arg1: memref {llvm.byval = i64}) -> memref @@ -40,8 +40,8 @@ func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval gpu.module @kernels { // CHECK-LABEL: llvm.func @kernel( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64}, -// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} { +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64}, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} { // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -57,10 +57,10 @@ gpu.module @kernels { memref.global 
@global : memref<3xi64> -// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr> -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr>) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK-NEXT: } func.func private @get_global() -> memref<3xi64> { @@ -74,10 +74,10 @@ func.func private @get_global() -> memref<3xi64> { memref.global @global_addrspace : memref<3xi64, 4> -// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global_addrspace : !llvm.ptr, 4> -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr, 4>) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global_addrspace : !llvm.ptr<4> +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr<4>) -> !llvm.ptr<4>, i64 +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<4> // CHECK-NEXT: } func.func private @get_global_addrspace() -> memref<3xi64, 4> { @@ -90,9 +90,9 @@ func.func private @get_global_addrspace() -> memref<3xi64, 4> { memref.global "private" constant @shape : memref<2xi64> = dense<[2, 2]> // CHECK-LABEL: llvm.func @reshape( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 
0] : (!llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> { @@ -106,9 +106,9 @@ func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> { memref.global "private" constant @shape : memref<1xindex> // CHECK-LABEL: llvm.func @reshape_dyn( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { @@ -120,11 +120,11 @@ func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { // ----- // CHECK-LABEL: llvm.func @alloca() -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -136,11 +136,11 @@ func.func private @alloca() { // ----- // CHECK-LABEL: llvm.func @alloca_nd() -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = 
llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -152,11 +152,11 @@ func.func private @alloca_nd() { // ----- // CHECK-LABEL: llvm.func @alloca_aligned() -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -168,11 +168,11 @@ func.func private @alloca_aligned() { // ----- // CHECK-LABEL: llvm.func @alloca_nd_aligned() -// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint 
%[[VAL_2]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -183,16 +183,15 @@ func.func private @alloca_nd_aligned() { // ----- -// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr +// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr // CHECK-LABEL: llvm.func @alloc() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.bitcast %[[VAL_5]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -210,11 +209,10 @@ func.func private @alloc() { // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : 
(!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -228,20 +226,19 @@ func.func private @alloc_nd() { // CHECK-LABEL: llvm.func @alloc_aligned() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_4]], %[[VAL_5]] : i64 -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: %[[VAL_9:.*]] = llvm.ptrtoint %[[VAL_8]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i64) : i64 -// CHECK-NEXT: %[[VAL_11:.*]] = llvm.sub %[[VAL_5]], %[[VAL_10]] : i64 -// CHECK-NEXT: %[[VAL_12:.*]] = llvm.add %[[VAL_9]], %[[VAL_11]] : i64 -// CHECK-NEXT: %[[VAL_13:.*]] = llvm.urem 
%[[VAL_12]], %[[VAL_5]] : i64 -// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_12]], %[[VAL_13]] : i64 -// CHECK-NEXT: %[[VAL_15:.*]] = llvm.inttoptr %[[VAL_14]] : i64 to !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.sub %[[VAL_5]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.add %[[VAL_8]], %[[VAL_10]] : i64 +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.urem %[[VAL_11]], %[[VAL_5]] : i64 +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.sub %[[VAL_11]], %[[VAL_12]] : i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.inttoptr %[[VAL_13]] : i64 to !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -259,20 +256,19 @@ func.func private @alloc_aligned() { // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK-NEXT: %[[VAL_10:.*]] = llvm.add %[[VAL_8]], %[[VAL_9]] : i64 -// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr to i64 -// CHECK-NEXT: %[[VAL_14:.*]] = 
llvm.mlir.constant(1 : i64) : i64 -// CHECK-NEXT: %[[VAL_15:.*]] = llvm.sub %[[VAL_9]], %[[VAL_14]] : i64 -// CHECK-NEXT: %[[VAL_16:.*]] = llvm.add %[[VAL_13]], %[[VAL_15]] : i64 -// CHECK-NEXT: %[[VAL_17:.*]] = llvm.urem %[[VAL_16]], %[[VAL_9]] : i64 -// CHECK-NEXT: %[[VAL_18:.*]] = llvm.sub %[[VAL_16]], %[[VAL_17]] : i64 -// CHECK-NEXT: %[[VAL_19:.*]] = llvm.inttoptr %[[VAL_18]] : i64 to !llvm.ptr +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.ptrtoint %[[VAL_11]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_9]], %[[VAL_13]] : i64 +// CHECK-NEXT: %[[VAL_15:.*]] = llvm.add %[[VAL_12]], %[[VAL_14]] : i64 +// CHECK-NEXT: %[[VAL_16:.*]] = llvm.urem %[[VAL_15]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_17:.*]] = llvm.sub %[[VAL_15]], %[[VAL_16]] : i64 +// CHECK-NEXT: %[[VAL_18:.*]] = llvm.inttoptr %[[VAL_17]] : i64 to !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -284,9 +280,8 @@ func.func private @alloc_nd_aligned() { // ----- // CHECK-LABEL: llvm.func @dealloc( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: llvm.call @free(%[[VAL_1]]) : (!llvm.ptr) -> () +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) +// CHECK-NEXT: llvm.call @free(%[[VAL_0]]) : (!llvm.ptr) -> () // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -298,8 +293,8 @@ func.func private @dealloc(%arg0: memref) { // ----- // CHECK-LABEL: llvm.func @cast( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @cast(%arg0: memref<2xi32>) -> memref { @@ -310,10 +305,10 @@ func.func private @cast(%arg0: memref<2xi32>) -> memref { // ----- // CHECK-LABEL: llvm.func 
@load( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_3]] : f32 // CHECK-NEXT: } @@ -325,14 +320,14 @@ func.func private @load(%arg0: memref<100xf32>, %index: index) -> f32 { // ----- // CHECK-LABEL: llvm.func @load_nd( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64, // CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_7]] : f32 // CHECK-NEXT: } @@ -344,14 +339,14 @@ func.func private @load_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: // ----- // CHECK-LABEL: llvm.func @load_nd_dyn( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64, // CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 -// 
CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_7]] : f32 // CHECK-NEXT: } @@ -363,11 +358,11 @@ func.func private @load_nd_dyn(%arg0: memref, %index0: index, %index1 // ----- // CHECK-LABEL: llvm.func @store( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64, // CHECK-SAME: %[[VAL_2:.*]]: f32) -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : f32, !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -379,14 +374,14 @@ func.func private @store(%arg0: memref<100xf32>, %index: index, %val: f32) { // ----- // CHECK-LABEL: llvm.func @store_nd( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, // CHECK-SAME: %[[VAL_3:.*]]: f32) // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : f32, !llvm.ptr // CHECK-NEXT: 
llvm.return // CHECK-NEXT: } @@ -398,14 +393,14 @@ func.func private @store_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: // ----- // CHECK-LABEL: llvm.func @store_nd_dyn( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, // CHECK-SAME: %[[VAL_3:.*]]: f32) // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : f32, !llvm.ptr // CHECK-NEXT: llvm.return // CHECK-NEXT: } @@ -416,15 +411,15 @@ func.func private @store_nd_dyn(%arg0: memref, %index0: index, %index // ----- -// CHECK-LABEL: llvm.func @impl(!llvm.ptr, i64) -> !llvm.ptr +// CHECK-LABEL: llvm.func @impl(!llvm.ptr, i64) -> !llvm.ptr func.func private @impl(%arg0: memref, %arg1: index) -> memref // CHECK-LABEL: llvm.func @call( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr // CHECK-NEXT: } func.func private @call(%arg0: memref, %arg1: index) -> memref { @@ -435,12 +430,12 @@ func.func private @call(%arg0: memref, %arg1: index) -> memref { // ----- // CHECK-LABEL: llvm.func @subindexop_memref( 
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(4 : i64) : i64 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mul %[[VAL_1]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4xf32> { @@ -451,10 +446,10 @@ func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> me // ----- // CHECK-LABEL: llvm.func @subindexop_memref_same_dim( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, -// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4x4xf32> { @@ -465,11 +460,11 @@ func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: ind // ----- // CHECK-LABEL: llvm.func @subindexop_memref_struct( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>) -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = 
llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr>, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr, i64, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>>) -> memref { @@ -481,11 +476,11 @@ func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>> // ----- // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>>) -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct<(struct<(f32)>)>>) -> memref { @@ -497,11 +492,11 @@ func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct< // ----- // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_ptr( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>)>>) -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr>)>>, i64, i64, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr 
%[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.struct<(ptr>)>>) -> memref { @@ -513,11 +508,11 @@ func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.str // ----- // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>)>>) -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr>)>>, i64, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr, i64, i64, i64, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>) -> memref { @@ -529,26 +524,26 @@ func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.s // ----- // CHECK-LABEL: llvm.func @memref2ptr( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK-NEXT: } -func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr { - %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr - return %res : !llvm.ptr +func.func private @memref2ptr(%arg0: memref<4xf32>) -> 
!llvm.ptr { + %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr + return %res : !llvm.ptr } // ----- // CHECK-LABEL: llvm.func @ptr2memref( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK-NEXT: } -func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { - %res = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr) -> memref +func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { + %res = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr) -> memref return %res : memref } @@ -557,13 +552,13 @@ func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { #layout = affine_map<(s0) -> (s0 - 1)> // CHECK-LABEL: llvm.func @non_bare_due_to_layout( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64 -// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_6]] : i64 // CHECK-NEXT: } From 9bbb47ca45649df245bc317f4668c03b443575fc Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: 
Fri, 24 Mar 2023 12:51:31 +0000
Subject: [PATCH 2/6] Remove unnecessary bitcasts and fix tests

Signed-off-by: Lukas Sommer
---
 .../PolygeistToLLVM/PolygeistToLLVM.cpp       | 17 +++++-----
 .../test/polygeist-opt/bareptrlowering.mlir   | 34 +++++--------------
 polygeist/test/polygeist-opt/sycl/cast.mlir   | 18 +++++-----
 .../test/polygeist-opt/sycl/subindex.mlir     | 22 ++++++------
 4 files changed, 38 insertions(+), 53 deletions(-)

diff --git a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp
index 8e876721f6aa4..8b8efd4d307b5 100644
--- a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp
+++ b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp
@@ -103,8 +103,10 @@ struct BaseSubIndexOpLowering : public ConvertOpToLLVMPattern {
         assert(t.getBody().size() == 1 && "Expecting single member type");
         currType = t.getBody()[0];
       })
-          .Case(
+          .Case(
               [&](auto t) { currType = t.getElementType(); })
+          .Case(
+              [&](auto t) { assert(false && "Pointer type not allowed"); })
           .Default([&](Type t) {
             currType = t;
             assert(currType == resElemType &&
@@ -188,7 +190,6 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering {
            "Expecting struct type");
 
     // SYCL case
-    // TODO(Lukas): Opaque pointer handling for SYCL case
     assert(sourceMemRefType.getRank() == viewMemRefType.getRank() &&
            "Expecting the input and output MemRef ranks to be the same");
 
@@ -268,7 +269,6 @@ struct SubIndexBarePtrOpLowering : public BaseSubIndexOpLowering {
            "Expecting struct type");
 
     // SYCL case
-    // TODO(Lukas): Opaque pointer handling for SYCL case
     assert(sourceMemRefType.getRank() == viewMemRefType.getRank() &&
            "Expecting the input and output MemRef ranks to be the same");
 
@@ -401,8 +401,9 @@ struct BareMemref2PointerOpLowering
       return failure();
 
     const auto target = transformed.getSource();
-    // TODO(Lukas): Can we eliminate this bitcast? 
- rewriter.replaceOpWithNewOp(op, op.getType(), target); + // In an opaque pointer world, a bitcast is a no-op, so no need to insert + // one here. + rewriter.replaceOp(op, target); return success(); } @@ -422,9 +423,9 @@ struct BarePointer2MemrefOpLowering const auto convertedType = getTypeConverter()->convertType(op.getType()); if (!convertedType) return failure(); - // TODO(Lukas): CAn we eliminate this bitcast? - rewriter.replaceOpWithNewOp(op, convertedType, - adaptor.getSource()); + // In an opaque pointer world, a bitcast is a no-op, so no need to insert + // one here. + rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; diff --git a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir index 19385545f4d96..4609998ad6d83 100644 --- a/polygeist/test/polygeist-opt/bareptrlowering.mlir +++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir @@ -463,7 +463,7 @@ func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: ind // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr, i64, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } @@ -476,10 +476,10 @@ func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>> // ----- // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr -> !llvm.ptr +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr 
%[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } @@ -491,27 +491,11 @@ func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct< // ----- -// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_ptr( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr -// CHECK-NEXT: } - -func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.struct<(ptr>)>>) -> memref { - %c_0 = arith.constant 0 : index - %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(ptr>)>>, index) -> memref - return %res : memref -} - -// ----- - // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr, i64, i64, i64, i64) -> !llvm.ptr, f32 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } @@ -525,8 +509,7 @@ func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.s // CHECK-LABEL: llvm.func @memref2ptr( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = 
llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr { @@ -538,8 +521,7 @@ func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr { // CHECK-LABEL: llvm.func @ptr2memref( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { @@ -552,9 +534,9 @@ func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { #layout = affine_map<(s0) -> (s0 - 1)> // CHECK-LABEL: llvm.func @non_bare_due_to_layout( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr diff --git a/polygeist/test/polygeist-opt/sycl/cast.mlir b/polygeist/test/polygeist-opt/sycl/cast.mlir index 4d52f3c1344e5..89917ff781b26 100644 --- a/polygeist/test/polygeist-opt/sycl/cast.mlir +++ b/polygeist/test/polygeist-opt/sycl/cast.mlir @@ -4,9 +4,9 @@ !sycl_range_1_ = !sycl.range<[1], (!sycl_array_1_)> // CHECK-LABEL: llvm.func @test1( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>>) -> !llvm.ptr)>> { -// CHECK: 
%[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>> to !llvm.ptr)>> -// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>> +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK: } func.func @test1(%arg0: memref) -> memref { @@ -17,9 +17,9 @@ func.func @test1(%arg0: memref) -> memref { // ----- // CHECK-LABEL: llvm.func @test2( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>>) -> !llvm.ptr)>> { -// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>> to !llvm.ptr)>> -// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>> +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK: } !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> @@ -32,9 +32,9 @@ func.func @test2(%arg0: memref) -> memref { // ----- // CHECK-LABEL: llvm.func @test_addrspaces( -// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>, 4>) -> !llvm.ptr)>, 4> { -// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>, 4> to !llvm.ptr)>, 4> -// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>, 4> +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<4>) -> !llvm.ptr<4> { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<4> to !llvm.ptr<4> +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<4> // CHECK: } !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> diff --git a/polygeist/test/polygeist-opt/sycl/subindex.mlir b/polygeist/test/polygeist-opt/sycl/subindex.mlir index 4fb19075221e0..50eab37fdeaff 100644 --- a/polygeist/test/polygeist-opt/sycl/subindex.mlir +++ b/polygeist/test/polygeist-opt/sycl/subindex.mlir @@ -2,7 +2,7 @@ // CHECK-LABEL: @test_1 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr !llvm.ptr<[[SYCLIDSTRUCT]], {{.*}} +// CHECK-NEXT: [[GEP:%.*]] = 
llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::id.1", {{.*}} // CHECK-NEXT: llvm.return [[GEP]] !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> @@ -15,7 +15,8 @@ func.func @test_1(%arg0: memref>) -> memref !llvm.ptr, !llvm.struct<"class.sycl::_V1::detail::AccessorImplDevice.1", {{.*}} +// CHECK-NEXT: llvm.return [[GEP]] !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> !sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> @@ -30,10 +31,11 @@ func.func @test_2(%arg0: memref) -> memref>) -> !llvm.ptr { +// CHECK: llvm.func @test_3([[A0:.*]]: !llvm.ptr) -> !llvm.ptr { +// CHECK: [[IDX_ZERO:%.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], 0] : (!llvm.ptr>, i64) -> !llvm.ptr -// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr +// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], [[IDX_ZERO]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, i32 +// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr func.func @test_3(%arg0: memref>) -> memref { %c0 = arith.constant 0 : index @@ -43,9 +45,9 @@ func.func @test_3(%arg0: memref>) -> memref { // ----- -// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr\)>\)>]])>>, [[A5:%.*]]: i64) -> !llvm.ptr)>)>)>> { -// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr>, i64) -> !llvm.ptr> -// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr> +// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr, [[A5:%.*]]: i64) -> !llvm.ptr { +// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: index) -> memref> { @@ -55,10 +57,10 @@ func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: 
index) -> // ----- -// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<[[ARRTYPE:struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>]], 4>) -> !llvm.ptr { +// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<4>) -> !llvm.ptr<4> { // CHECK-DAG: [[ZERO1:%.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-DAG: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], 0, [[ZERO1]]] : (!llvm.ptr<[[ARRTYPE]], 4>, i64, i64) -> !llvm.ptr +// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], [[ZERO2]], [[ZERO1]]] : (!llvm.ptr<4>, i64, i64, i64) -> !llvm.ptr<4>, i64 !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> func.func @test_5(%arg0: memref)>, 4>) -> memref<1xi64, 4> { From 49f93e1e78443c78676c19e5d4335d07949b1fd9 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Fri, 24 Mar 2023 16:04:23 +0000 Subject: [PATCH 3/6] Version BareMemRefToLLVM patterns Signed-off-by: Lukas Sommer --- .../Dialect/Polygeist/Transforms/Passes.h | 3 +- .../Polygeist/Transforms/BareMemRefToLLVM.cpp | 355 +++++++++++++++++- 2 files changed, 345 insertions(+), 13 deletions(-) diff --git a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h index bc4db810755dd..bdec6b31a485a 100644 --- a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h +++ b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h @@ -26,7 +26,8 @@ namespace polygeist { /// MemRef dialect to the LLVM dialect forcing a "bare pointer" calling /// convention. 
void populateBareMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool useOpaquePointer = false); #define GEN_PASS_DECL #include "mlir/Dialect/Polygeist/Transforms/Passes.h.inc" diff --git a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp index 274a5f4c0b512..79ffc85a3518b 100644 --- a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp +++ b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp @@ -318,22 +318,353 @@ struct StoreMemRefOpLowering : public MemAccessLowering { }; } // namespace +// The following patterns are outdated and only used in case typed pointers +// should be used for the lowering. They will be removed soon. +namespace { +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. +struct GetGlobalMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::GetGlobalOp getGlobalOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefTy = getGlobalOp.getType(); + if (!canBeLoweredToBarePtr(memrefTy)) + return failure(); + + const auto arrayTy = + convertGlobalMemrefTypeToLLVM(memrefTy, *typeConverter); + if (!arrayTy) + return failure(); + const auto addressOf = + static_cast(rewriter.create( + getGlobalOp.getLoc(), + LLVM::LLVMPointerType::get(arrayTy, memrefTy.getMemorySpaceAsInt()), + adaptor.getName())); + + // Get the address of the first element in the array by creating a GEP with + // the address of the GV as the base, and (rank + 1) number of 0 indices. 
+ rewriter.replaceOpWithNewOp( + getGlobalOp, typeConverter->convertType(memrefTy), addressOf, + SmallVector(memrefTy.getRank() + 1, 0), + /* inbounds */ true); + + return success(); + } + +private: + /// Returns the LLVM type of the global variable given the memref type `type`. + static Type convertGlobalMemrefTypeToLLVM(MemRefType type, + TypeConverter &typeConverter) { + // LLVM type for a global memref will be a multi-dimension array. For + // declarations or uninitialized global memrefs, we can potentially flatten + // this to a 1D array. However, for memref.global's with an initial value, + // we do not intend to flatten the ElementsAttribute when going from std -> + // LLVM dialect, so the LLVM type needs to me a multi-dimension array. + const auto convElemTy = typeConverter.convertType(type.getElementType()); + if (!convElemTy) + return {}; + // Shape has the outermost dim at index 0, so need to walk it backwards + const auto shape = type.getShape(); + return std::accumulate( + shape.rbegin(), shape.rend(), convElemTy, + [](auto ty, auto dim) { return LLVM::LLVMArrayType::get(ty, dim); }); + } +}; + +/// Simply replace by the source, as we don't care about the shape. +struct ReshapeMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ReshapeOp reshape, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(reshape.getType()) || + !canBeLoweredToBarePtr( + reshape.getSource().getType().cast())) + return failure(); + + rewriter.replaceOp(reshape, adaptor.getSource()); + return success(); + } +}; + +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. 
+struct AllocaMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefType = allocaOp.getType(); + if (!memrefType.hasStaticShape() || !memrefType.getLayout().isIdentity()) + return failure(); + + const auto ptrType = typeConverter->convertType(allocaOp.getType()); + if (!ptrType) + return failure(); + const auto loc = allocaOp.getLoc(); + auto nullPtr = rewriter.create(loc, ptrType); + auto gepPtr = rewriter.create( + loc, ptrType, nullPtr, + createIndexConstant(rewriter, loc, + allocaOp.getType().getNumElements())); + auto sizeBytes = + rewriter.create(loc, getIndexType(), gepPtr); + + rewriter.replaceOpWithNewOp( + allocaOp, ptrType, sizeBytes, allocaOp.getAlignment().value_or(0)); + return success(); + } +}; + +static Value createAlignedOld(ConversionPatternRewriter &rewriter, Location loc, + Value input, Value alignment) { + auto one = rewriter.create(loc, alignment.getType(), 1); + auto bump = rewriter.create(loc, alignment, one); + auto bumped = rewriter.create(loc, input, bump); + auto mod = rewriter.create(loc, bumped, alignment); + return rewriter.create(loc, bumped, mod); +} + +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. 
+struct AllocMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefType = allocOp.getType(); + const auto elementPtrType = typeConverter->convertType(memrefType); + if (!elementPtrType || !memrefType.hasStaticShape() || + !memrefType.getLayout().isIdentity()) + return failure(); + + const auto loc = allocOp.getLoc(); + SmallVector sizes; + SmallVector strides; + Value sizeBytes; + getMemRefDescriptorSizes(loc, memrefType, adaptor.getOperands(), rewriter, + sizes, strides, sizeBytes); + + const auto alignment = + llvm::transformOptional(allocOp.getAlignment(), [&](auto val) { + return createIndexConstant(rewriter, loc, val); + }); + if (alignment) { + // Adjust the allocation size to consider alignment. + sizeBytes = rewriter.create(loc, sizeBytes, *alignment); + } + + auto module = allocOp->getParentOfType(); + const auto allocFuncOp = + getAllocFn(*getTypeConverter(), module, getIndexType()); + + const auto results = + rewriter.create(loc, allocFuncOp, sizeBytes).getResults(); + auto alignedPtr = static_cast( + rewriter.create(loc, elementPtrType, results)); + if (alignment) { + // Compute the aligned pointer. + const auto allocatedInt = static_cast( + rewriter.create(loc, getIndexType(), alignedPtr)); + const auto alignmentInt = + createAlignedOld(rewriter, loc, allocatedInt, *alignment); + alignedPtr = + rewriter.create(loc, elementPtrType, alignmentInt); + } + rewriter.replaceOp(allocOp, {alignedPtr}); + return success(); + } +}; + +/// Conversion similar to the canonical one, but not extracting the allocated +/// pointer from a struct. 
+struct DeallocOpLoweringOld : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp deallocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr( + deallocOp.getMemref().getType().cast())) + return failure(); + // Insert the `free` declaration if it is not already present. + const auto freeFunc = + getFreeFn(*getTypeConverter(), deallocOp->getParentOfType()); + const auto casted = + rewriter + .create(deallocOp.getLoc(), getVoidPtrType(), + adaptor.getMemref()) + .getRes(); + rewriter.replaceOpWithNewOp(deallocOp, freeFunc, casted); + return success(); + } +}; + +/// Lowers to an identity operation. +struct CastMemrefOpLoweringOld : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult match(memref::CastOp castOp) const override { + const auto srcType = castOp.getOperand().getType().cast(); + const auto dstType = castOp.getType().cast(); + + // This will be replaced by an identity function, so we need input and + // output types to match. 
+ return success(canBeLoweredToBarePtr(dstType) && + canBeLoweredToBarePtr(srcType) && + typeConverter->convertType(srcType) == + typeConverter->convertType(dstType)); + } + + void rewrite(memref::CastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(castOp, adaptor.getSource()); + } +}; + +struct MemorySpaceCastMemRefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::MemorySpaceCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto newTy = getTypeConverter()->convertType(castOp.getType()); + rewriter.replaceOpWithNewOp(castOp, newTy, + adaptor.getSource()); + return success(); + } +}; + +/// Base class for lowering operations implementing memory accesses. +struct MemAccessLoweringOld : public ConvertToLLVMPattern { + using ConvertToLLVMPattern::ConvertToLLVMPattern; + + /// Obtains offset from a memory access indices + Value getStridedElementBarePtr(Location loc, MemRefType type, Value base, + ValueRange indices, + ConversionPatternRewriter &rewriter) const { + int64_t offset; + SmallVector strides; + LogicalResult successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + + auto index = + offset == 0 ? Value{} : createIndexConstant(rewriter, loc, offset); + + for (const auto &iter : llvm::enumerate(llvm::zip(indices, strides))) { + auto increment = std::get<0>(iter.value()); + const auto stride = std::get<1>(iter.value()); + if (stride != 1) { // Skip if stride is 1. + increment = rewriter.create( + loc, increment, createIndexConstant(rewriter, loc, stride)); + } + index = index ? 
rewriter.create(loc, index, increment) + : increment; + } + const auto elementPtrType = getTypeConverter()->convertType(type); + if (!elementPtrType) + return {}; + return index + ? rewriter.create(loc, elementPtrType, base, index) + : base; + } +}; + +struct LoadMemRefOpLoweringOld : public MemAccessLowering { + LoadMemRefOpLoweringOld(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : MemAccessLowering{memref::LoadOp::getOperationName(), + &typeConverter.getContext(), typeConverter, benefit} { + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef args, + ConversionPatternRewriter &rewriter) const override { + auto loadOp = cast(op); + if (!canBeLoweredToBarePtr(loadOp.getMemRefType())) + return failure(); + memref::LoadOp::Adaptor adaptor{args}; + const Value DataPtr = getStridedElementBarePtr( + loadOp.getLoc(), loadOp.getMemRefType(), adaptor.getMemref(), + adaptor.getIndices(), rewriter); + if (!DataPtr) + return failure(); + rewriter.replaceOpWithNewOp(op, DataPtr); + return success(); + } +}; + +struct StoreMemRefOpLoweringOld : public MemAccessLowering { + StoreMemRefOpLoweringOld(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : MemAccessLowering{memref::StoreOp::getOperationName(), + &typeConverter.getContext(), typeConverter, benefit} { + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef args, + ConversionPatternRewriter &rewriter) const override { + auto storeOp = cast(op); + if (!canBeLoweredToBarePtr(storeOp.getMemRefType())) + return failure(); + memref::StoreOp::Adaptor adaptor{args}; + const Value DataPtr = getStridedElementBarePtr( + storeOp.getLoc(), storeOp.getMemRefType(), adaptor.getMemref(), + adaptor.getIndices(), rewriter); + if (!DataPtr) + return failure(); + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), DataPtr); + return success(); + } +}; +} // namespace + void mlir::polygeist::populateBareMemRefToLLVMConversionPatterns( - mlir::LLVMTypeConverter &converter, 
RewritePatternSet &patterns) { + mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool useOpaquePointer) { assert(converter.getOptions().useBarePtrCallConv && "Expecting \"bare pointer\" calling convention"); - patterns.add( - converter, 2); + if (useOpaquePointer) { + patterns.add( + converter, 2); + } else { + patterns.add(converter, 2); + } // Patterns are tried in reverse add order, so this is tried before the // one added by default. - converter.addConversion([&](MemRefType type) -> Optional { - if (!canBeLoweredToBarePtr(type)) - return std::nullopt; - return LLVM::LLVMPointerType::get(type.getContext(), - type.getMemorySpaceAsInt()); - }); + converter.addConversion( + [&, useOpaquePointer](MemRefType type) -> Optional { + if (!canBeLoweredToBarePtr(type)) + return std::nullopt; + + if (useOpaquePointer) { + return LLVM::LLVMPointerType::get(type.getContext(), + type.getMemorySpaceAsInt()); + } + + const auto elemType = converter.convertType(type.getElementType()); + if (!elemType) + return Type{}; + return LLVM::LLVMPointerType::get(elemType, type.getMemorySpaceAsInt()); + }); } From b1063d4fdc06aa8d1b153698daee53bfeb11598a Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Fri, 24 Mar 2023 16:32:10 +0000 Subject: [PATCH 4/6] Add versioning and opaque pointer option Signed-off-by: Lukas Sommer --- .../mlir/Conversion/PolygeistPasses.td | 5 +- .../PolygeistToLLVM/PolygeistToLLVM.cpp | 629 +++++++++++++++++- .../test/polygeist-opt/bareptrlowering.mlir | 2 +- polygeist/test/polygeist-opt/sycl/cast.mlir | 2 +- .../test/polygeist-opt/sycl/subindex.mlir | 2 +- 5 files changed, 623 insertions(+), 17 deletions(-) diff --git a/polygeist/include/mlir/Conversion/PolygeistPasses.td b/polygeist/include/mlir/Conversion/PolygeistPasses.td index 74c1d9cc74419..2a21de4edfd1a 100644 --- a/polygeist/include/mlir/Conversion/PolygeistPasses.td +++ b/polygeist/include/mlir/Conversion/PolygeistPasses.td @@ -24,7 +24,10 @@ def ConvertPolygeistToLLVM : 
Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> Option<"dataLayout", "data-layout", "std::string", /*default=*/"\"\"", "String description (LLVM format) of the data layout that is " - "expected on the produced module"> + "expected on the produced module">, + Option<"useOpaquePointers", "use-opaque-pointers", "bool", + /*default=*/"false", "Generate LLVM IR using opaque pointers " + "instead of typed pointers">, ]; } diff --git a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp index 8b8efd4d307b5..ae2a94e2bb3aa 100644 --- a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp +++ b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp @@ -54,6 +54,362 @@ namespace mlir { #undef GEN_PASS_DEF_CONVERTPOLYGEISTTOLLVM } // namespace mlir +// FIXME: All the following patterns with an "Old" suffix to their name should +// be removed once we drop typed pointer support. +struct BaseSubIndexOpLoweringOld : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + +protected: + // Compute the indices of the GEP operation we lower the SubIndexOp to. + // The indices are computed based on: + // a) the (converted) source element type, and + // b) the (converted) result element type that is requested + // Examples: + // - src ty: ptr>>, res ty: ptr + // -> idxs = [0, 0, SubIndexOp's index] + // - src ty: ptr>>, res ty: ptr> + // -> idxs = [0, SubIndexOp's index] + // + // Note: when the source element type is a struct with more than one member + // type, the result type that is requested is deemed illegal unless it is one + // of the source member types. 
For example assume: + // - src ty: ptr,i32>> + // - res ty: ptr + // This is illegal because res ty can only be either ptr or + // ptr> + static void computeIndices(const LLVM::LLVMStructType &srcElemType, + const Type &resElemType, + SmallVectorImpl &indices, SubIndexOp op, + OpAdaptor transformed, + ConversionPatternRewriter &rewriter) { + assert(indices.empty() && "Expecting an empty vector"); + + ArrayRef memTypes = srcElemType.getBody(); + unsigned numMembers = memTypes.size(); + assert((numMembers == 1 || + any_of(memTypes, [=](Type t) { return resElemType == t; })) && + "The requested result memref element type is illegal"); + + Type indexType = transformed.getIndex().getType(); + Value zero = rewriter.create( + op.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0)); + indices.push_back(zero); + + if (numMembers == 1) { + Type currType = srcElemType.getBody()[0]; + while (currType != resElemType) { + indices.push_back(zero); + + TypeSwitch(currType) + .Case([&](LLVM::LLVMStructType t) { + assert(t.getBody().size() == 1 && "Expecting single member type"); + currType = t.getBody()[0]; + }) + .Case( + [&](auto t) { currType = t.getElementType(); }) + .Default([&](Type t) { + currType = t; + assert(currType == resElemType && + "requested result type is illegal"); + }); + } + } + + indices.push_back(transformed.getIndex()); + } +}; + +/// Conversion pattern that transforms a subview op into: +/// 1. An `llvm.mlir.undef` operation to create a memref descriptor +/// 2. Updates to the descriptor to introduce the data ptr, offset, size +/// and stride. +/// The subview op is replaced by the descriptor. 
+struct SubIndexOpLoweringOld : public BaseSubIndexOpLoweringOld { + using BaseSubIndexOpLoweringOld::BaseSubIndexOpLoweringOld; + + LogicalResult + matchAndRewrite(SubIndexOp subViewOp, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + assert(subViewOp.getSource().getType().isa() && + "Source operand should be a memref type"); + assert(subViewOp.getType().isa() && + "Result should be a memref type"); + + auto sourceMemRefType = subViewOp.getSource().getType().cast(); + auto viewMemRefType = subViewOp.getType().cast(); + + auto loc = subViewOp.getLoc(); + MemRefDescriptor targetMemRef(transformed.getSource()); + Value prev = targetMemRef.alignedPtr(rewriter, loc); + Value idxs[] = {transformed.getIndex()}; + + SmallVector sizes, strides; + if (sourceMemRefType.getRank() != viewMemRefType.getRank()) { + if (sourceMemRefType.getRank() != viewMemRefType.getRank() + 1) + return failure(); + + size_t sz = 1; + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + if (sourceMemRefType.getShape()[i] == ShapedType::kDynamic) + return failure(); + sz *= sourceMemRefType.getShape()[i]; + } + Value cop = rewriter.create( + loc, idxs[0].getType(), + rewriter.getIntegerAttr(idxs[0].getType(), sz)); + idxs[0] = rewriter.create(loc, idxs[0], cop); + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + sizes.push_back(targetMemRef.size(rewriter, loc, i)); + strides.push_back(targetMemRef.stride(rewriter, loc, i)); + } + } else { + for (int64_t i = 0; i < sourceMemRefType.getRank(); i++) { + sizes.push_back(targetMemRef.size(rewriter, loc, i)); + strides.push_back(targetMemRef.stride(rewriter, loc, i)); + } + } + + Type sourceElemType = sourceMemRefType.getElementType(); + Type convSourceElemType = getTypeConverter()->convertType(sourceElemType); + Type viewElemType = viewMemRefType.getElementType(); + Type convViewElemType = getTypeConverter()->convertType(viewElemType); + + // Handle the general (non-SYCL) case first. 
+ if (convViewElemType == + prev.getType().cast().getElementType()) { + auto memRefDesc = createMemRefDescriptor( + loc, viewMemRefType, targetMemRef.allocatedPtr(rewriter, loc), + rewriter.create(loc, prev.getType(), prev, idxs), sizes, + strides, rewriter); + + rewriter.replaceOp(subViewOp, {memRefDesc}); + return success(); + } + assert(convSourceElemType.isa() && + "Expecting struct type"); + + // SYCL case + assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && + "Expecting the input and output MemRef ranks to be the same"); + + SmallVector indices; + computeIndices(convSourceElemType.cast(), + convViewElemType, indices, subViewOp, transformed, rewriter); + assert(!indices.empty() && "Expecting a least one index"); + + // Note: MLIRScanner::InitializeValueByInitListExpr() in clang-mlir.cc, when + // a memref element type is a struct type, the return type of a + // polygeist.subindex operation should be a memref of the element type of + // the struct. + auto elemPtrTy = LLVM::LLVMPointerType::get( + convViewElemType, viewMemRefType.getMemorySpaceAsInt()); + auto gep = rewriter.create(loc, elemPtrTy, prev, indices); + auto memRefDesc = createMemRefDescriptor(loc, viewMemRefType, gep, gep, + sizes, strides, rewriter); + LLVM_DEBUG(llvm::dbgs() << "SubIndexOpLowering: gep: " << *gep << "\n"); + + rewriter.replaceOp(subViewOp, {memRefDesc}); + return success(); + } +}; + +struct SubIndexBarePtrOpLoweringOld : public BaseSubIndexOpLoweringOld { + using BaseSubIndexOpLoweringOld::BaseSubIndexOpLoweringOld; + + LogicalResult + matchAndRewrite(SubIndexOp subViewOp, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + assert(subViewOp.getSource().getType().isa() && + "Source operand should be a memref type"); + assert(subViewOp.getType().isa() && + "Result should be a memref type"); + + auto sourceMemRefType = subViewOp.getSource().getType().cast(); + auto viewMemRefType = subViewOp.getType().cast(); + if 
(!canBeLoweredToBarePtr(sourceMemRefType) || + !canBeLoweredToBarePtr(viewMemRefType)) + return failure(); + + const auto loc = subViewOp.getLoc(); + const auto target = transformed.getSource(); + auto idx = transformed.getIndex(); + + if (sourceMemRefType.getRank() != viewMemRefType.getRank()) { + if (sourceMemRefType.getRank() != viewMemRefType.getRank() + 1) + return failure(); + + size_t sz = 1; + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + if (sourceMemRefType.getShape()[i] == ShapedType::kDynamic) + return failure(); + sz *= sourceMemRefType.getShape()[i]; + } + Value cop = rewriter.create( + loc, idx.getType(), rewriter.getIntegerAttr(idx.getType(), sz)); + idx = rewriter.create(loc, idx, cop); + } + + Type sourceElemType = sourceMemRefType.getElementType(); + Type convSourceElemType = getTypeConverter()->convertType(sourceElemType); + if (!convSourceElemType) + return failure(); + Type viewElemType = viewMemRefType.getElementType(); + Type convViewElemType = getTypeConverter()->convertType(viewElemType); + Type resType = getTypeConverter()->convertType(subViewOp.getType()); + + // Handle the general (non-SYCL) case first. + if (convViewElemType == + target.getType().cast().getElementType()) { + rewriter.replaceOpWithNewOp(subViewOp, resType, target, idx); + return success(); + } + assert(convSourceElemType.isa() && + "Expecting struct type"); + + // SYCL case + assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && + "Expecting the input and output MemRef ranks to be the same"); + + SmallVector indices; + computeIndices(convSourceElemType.cast(), + convViewElemType, indices, subViewOp, transformed, rewriter); + assert(!indices.empty() && "Expecting a least one index"); + + // Note: MLIRScanner::InitializeValueByInitListExpr() in clang-mlir.cc, when + // a memref element type is a struct type, the return type of a + // polygeist.subindex operation should be a memref of the element type of + // the struct. 
+ + rewriter.replaceOpWithNewOp(subViewOp, resType, target, + indices); + + return success(); + } +}; + +struct Memref2PointerOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Memref2PointerOp op, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // MemRefDescriptor sourceMemRef(operands.front()); + MemRefDescriptor targetMemRef( + transformed.getSource()); // MemRefDescriptor::undef(rewriter, loc, + // targetDescTy); + + // Offset. + Value baseOffset = targetMemRef.offset(rewriter, loc); + Value ptr = targetMemRef.alignedPtr(rewriter, loc); + Value idxs[] = {baseOffset}; + ptr = rewriter.create(loc, ptr.getType(), ptr, idxs); + assert(ptr.getType().cast().getAddressSpace() == + op.getType().getAddressSpace() && + "Expecting Memref2PointerOp source and result types to have the " + "same address space"); + ptr = rewriter.create(loc, op.getType(), ptr); + + rewriter.replaceOp(op, {ptr}); + return success(); + } +}; + +struct Pointer2MemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Pointer2MemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // MemRefDescriptor sourceMemRef(operands.front()); + auto convertedType = getTypeConverter()->convertType(op.getType()); + assert(convertedType && "unexpected failure in memref type conversion"); + auto descr = MemRefDescriptor::undef(rewriter, loc, convertedType); + assert(adaptor.getSource() + .getType() + .cast() + .getAddressSpace() == + op.getType().cast().getMemorySpaceAsInt() && + "Expecting Pointer2MemrefOp source and result types to have the " + "same address space"); + auto ptr = rewriter.create( + op.getLoc(), descr.getElementPtrType(), adaptor.getSource()); + + // Extract all strides and offsets and verify 
they are static. + int64_t offset; + SmallVector strides; + auto result = getStridesAndOffset(op.getType(), strides, offset); + (void)result; + assert(succeeded(result) && "unexpected failure in stride computation"); + assert(offset != ShapedType::kDynamic && "expected static offset"); + + bool first = true; + assert(!llvm::any_of(strides, [&](int64_t stride) { + if (first) { + first = false; + return false; + } + return stride == ShapedType::kDynamic; + }) && "expected static strides except first element"); + + descr.setAllocatedPtr(rewriter, loc, ptr); + descr.setAlignedPtr(rewriter, loc, ptr); + descr.setConstantOffset(rewriter, loc, offset); + + // Fill in sizes and strides + for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { + descr.setConstantSize(rewriter, loc, i, op.getType().getDimSize(i)); + descr.setConstantStride(rewriter, loc, i, strides[i]); + } + + rewriter.replaceOp(op, {descr}); + return success(); + } +}; + +/// Lowers to a bitcast operation +struct BareMemref2PointerOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Memref2PointerOp op, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(op.getSource().getType())) + return failure(); + + const auto target = transformed.getSource(); + rewriter.replaceOpWithNewOp(op, op.getType(), target); + + return success(); + } +}; + +/// Lowers to a bitcast operation +struct BarePointer2MemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Pointer2MemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(op.getType())) + return failure(); + + const auto convertedType = getTypeConverter()->convertType(op.getType()); + if (!convertedType) + return failure(); + rewriter.replaceOpWithNewOp(op, 
convertedType, + adaptor.getSource()); + return success(); + } +}; + struct BaseSubIndexOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -495,16 +851,32 @@ void populatePolygeistToLLVMConversionPatterns(LLVMTypeConverter &converter, assert(converter.getOptions().useBarePtrCallConv && "These patterns only work with bare pointer calling convention"); - patterns.add(converter); - // When adding these patterns (and other patterns changing the default - // conversion of operations on MemRef values), a higher benefit is passed - // (2), so that these patterns have a higher priority than the ones - // performing the default conversion, which should only run if the "bare - // pointer" ones fail. - patterns.add(converter, - /*benefit*/ 2); + if (converter.useOpaquePointers()) { + patterns.add(converter); + // When adding these patterns (and other patterns changing the + // default conversion of operations on MemRef values), a higher + // benefit is passed (2), so that these patterns have a higher + // priority than the ones performing the default conversion, which + // should only run if the "bare pointer" ones fail. + patterns.add(converter, + /*benefit*/ 2); + } else { + // FIXME: This 'else'-part should be removed completely when we drop typed + // pointer support. + patterns.add( + converter); + // When adding these patterns (and other patterns changing the + // default conversion of operations on MemRef values), a higher + // benefit is passed (2), so that these patterns have a higher + // priority than the ones performing the default conversion, which + // should only run if the "bare pointer" ones fail. + patterns.add(converter, + /*benefit*/ 2); + } } namespace { @@ -573,6 +945,230 @@ struct URLLVMOpLowering } }; +// FIXME: The following function and pattern with the "Old" suffix should be +// removed once we drop typed pointer support. 
+ +// TODO lock this wrt module +static LLVM::LLVMFuncOp addMocCUDAFunctionOld(ModuleOp module, Type streamTy) { + const char fname[] = "fake_cuda_dispatch"; + + MLIRContext *ctx = module.getContext(); + auto loc = module.getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + for (auto fn : module.getBody()->getOps()) { + if (fn.getName() == fname) + return fn; + } + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto i8Ptr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + + auto resumeOp = moduleBuilder.create( + fname, LLVM::LLVMFunctionType::get( + voidTy, {i8Ptr, + LLVM::LLVMPointerType::get( + LLVM::LLVMFunctionType::get(voidTy, {i8Ptr})), + streamTy})); + resumeOp.setPrivate(); + + return resumeOp; +} + +struct AsyncOpLoweringOld : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = execute->getParentOfType(); + + MLIRContext *ctx = module.getContext(); + Location loc = execute.getLoc(); + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + Type voidPtr = LLVM::LLVMPointerType::get(IntegerType::get(ctx, 8)); + + // Make sure that all constants will be inside the outlined async function + // to reduce the number of function arguments. + Region &funcReg = execute.getRegion(); + + // Collect all outlined function inputs. + SetVector functionInputs; + + getUsedValuesDefinedAbove(execute.getRegion(), funcReg, functionInputs); + SmallVector toErase; + for (auto a : functionInputs) { + Operation *op = a.getDefiningOp(); + if (op && op->hasTrait()) + toErase.push_back(a); + } + for (auto a : toErase) { + functionInputs.remove(a); + } + + // Collect types for the outlined function inputs and outputs. 
+ TypeConverter *converter = getTypeConverter(); + auto typesRange = llvm::map_range(functionInputs, [&](Value value) { + return converter->convertType(value.getType()); + }); + SmallVector inputTypes(typesRange.begin(), typesRange.end()); + + Type ftypes[] = {voidPtr}; + auto funcType = LLVM::LLVMFunctionType::get(voidTy, ftypes); + + // TODO: Derive outlined function name from the parent FuncOp (support + // multiple nested async.execute operations). + auto moduleBuilder = + ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + static int off = 0; + off++; + auto func = moduleBuilder.create( + execute.getLoc(), + "kernelbody." + std::to_string((long long int)&execute) + "." + + std::to_string(off), + funcType); + + rewriter.setInsertionPointToStart(func.addEntryBlock()); + IRMapping valueMapping; + for (Value capture : toErase) { + Operation *op = capture.getDefiningOp(); + for (auto r : + llvm::zip(op->getResults(), + rewriter.clone(*op, valueMapping)->getResults())) { + valueMapping.map(rewriter.getRemappedValue(std::get<0>(r)), + std::get<1>(r)); + } + } + // Prepare for coroutine conversion by creating the body of the function. + { + // Map from function inputs defined above the execute op to the function + // arguments. 
+ auto arg = func.getArgument(0); + + if (functionInputs.size() == 0) { + } else if (functionInputs.size() == 1 && + converter->convertType(functionInputs[0].getType()) + .isa()) { + valueMapping.map( + functionInputs[0], + rewriter.create( + execute.getLoc(), + converter->convertType(functionInputs[0].getType()), arg)); + } else if (functionInputs.size() == 1 && + converter->convertType(functionInputs[0].getType()) + .isa()) { + valueMapping.map( + functionInputs[0], + rewriter.create( + execute.getLoc(), + converter->convertType(functionInputs[0].getType()), arg)); + } else { + SmallVector types; + for (auto v : functionInputs) + types.push_back(converter->convertType(v.getType())); + auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); + auto alloc = rewriter.create( + execute.getLoc(), LLVM::LLVMPointerType::get(ST), arg); + for (auto idx : llvm::enumerate(functionInputs)) { + + mlir::Value idxs[] = { + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), + }; + Value next = rewriter.create( + loc, LLVM::LLVMPointerType::get(idx.value().getType()), alloc, + idxs); + valueMapping.map(idx.value(), + rewriter.create(loc, next)); + } + auto freef = getFreeFn(*getTypeConverter(), module); + Value args[] = {arg}; + rewriter.create(loc, freef, args); + } + + // Clone all operations from the execute operation body into the outlined + // function body. + for (Operation &op : execute.getBody()->without_terminator()) + rewriter.clone(op, valueMapping); + + rewriter.create(execute.getLoc(), ValueRange()); + } + + // Replace the original `async.execute` with a call to outlined function. 
+ { + rewriter.setInsertionPoint(execute); + SmallVector crossing; + for (auto tup : llvm::zip(functionInputs, inputTypes)) { + Value val = std::get<0>(tup); + crossing.push_back(val); + } + + SmallVector vals; + if (crossing.size() == 0) { + vals.push_back( + rewriter.create(execute.getLoc(), voidPtr)); + } else if (crossing.size() == 1 && + converter->convertType(crossing[0].getType()) + .isa()) { + vals.push_back(rewriter.create(execute.getLoc(), + voidPtr, crossing[0])); + } else if (crossing.size() == 1 && + converter->convertType(crossing[0].getType()) + .isa()) { + vals.push_back(rewriter.create(execute.getLoc(), + voidPtr, crossing[0])); + } else { + SmallVector types; + for (auto v : crossing) + types.push_back(v.getType()); + auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); + + auto mallocf = getAllocFn(*getTypeConverter(), module, getIndexType()); + + Value args[] = {rewriter.create( + loc, rewriter.getI64Type(), + rewriter.create(loc, rewriter.getIndexType(), + ST))}; + mlir::Value alloc = rewriter.create( + loc, LLVM::LLVMPointerType::get(ST), + rewriter.create(loc, mallocf, args) + .getResult()); + rewriter.setInsertionPoint(execute); + for (auto idx : llvm::enumerate(crossing)) { + + mlir::Value idxs[] = { + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), + }; + Value next = rewriter.create( + loc, LLVM::LLVMPointerType::get(idx.value().getType()), alloc, + idxs); + rewriter.create(loc, idx.value(), next); + } + vals.push_back( + rewriter.create(execute.getLoc(), voidPtr, alloc)); + } + vals.push_back( + rewriter.create(execute.getLoc(), func)); + for (auto dep : execute.getDependencies()) { + auto ctx = dep.getDefiningOp(); + vals.push_back(ctx.getSource()); + } + assert(vals.size() == 3); + + auto f = addMocCUDAFunctionOld(execute->getParentOfType(), + vals.back().getType()); + + rewriter.create(execute.getLoc(), f, vals); + rewriter.eraseOp(execute); + } + + return success(); + } +}; + // TODO lock this wrt module 
static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { const char fname[] = "fake_cuda_dispatch"; @@ -917,7 +1513,7 @@ struct ConvertPolygeistToLLVMPass LowerToLLVMOptions options(&getContext(), dataLayoutAnalysis.getAtOrAbove(m)); options.useBarePtrCallConv = true; - options.useOpaquePointers = true; + options.useOpaquePointers = useOpaquePointers; if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); @@ -959,7 +1555,8 @@ struct ConvertPolygeistToLLVMPass // Run these instead of the ones provided by the dialect to avoid lowering // memrefs to a struct. - populateBareMemRefToLLVMConversionPatterns(converter, patterns); + populateBareMemRefToLLVMConversionPatterns(converter, patterns, + useOpaquePointers); // Legality callback for operations that checks whether their operand and // results types are converted. @@ -1014,7 +1611,13 @@ struct ConvertPolygeistToLLVMPass if (i == 1) { target.addIllegalOp(); - patterns.add(converter); + if (useOpaquePointers) { + patterns.add(converter); + } else { + // FIXME: This part should be removed when we drop typed pointer + // support. 
+ patterns.add(converter); + } patterns.add(converter); } if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir index 4609998ad6d83..cc832b3289a2c 100644 --- a/polygeist/test/polygeist-opt/bareptrlowering.mlir +++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir @@ -1,4 +1,4 @@ -// RUN: polygeist-opt %s --convert-polygeist-to-llvm --split-input-file | FileCheck %s +// RUN: polygeist-opt %s --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file | FileCheck %s // CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr diff --git a/polygeist/test/polygeist-opt/sycl/cast.mlir b/polygeist/test/polygeist-opt/sycl/cast.mlir index 89917ff781b26..b355119ba8f43 100644 --- a/polygeist/test/polygeist-opt/sycl/cast.mlir +++ b/polygeist/test/polygeist-opt/sycl/cast.mlir @@ -1,4 +1,4 @@ -// RUN: polygeist-opt --convert-polygeist-to-llvm --split-input-file %s | FileCheck %s +// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> !sycl_range_1_ = !sycl.range<[1], (!sycl_array_1_)> diff --git a/polygeist/test/polygeist-opt/sycl/subindex.mlir b/polygeist/test/polygeist-opt/sycl/subindex.mlir index 50eab37fdeaff..730261b9c67dd 100644 --- a/polygeist/test/polygeist-opt/sycl/subindex.mlir +++ b/polygeist/test/polygeist-opt/sycl/subindex.mlir @@ -1,4 +1,4 @@ -// RUN: polygeist-opt --convert-polygeist-to-llvm --split-input-file %s | FileCheck %s +// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s // CHECK-LABEL: @test_1 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 From cb0c9b042ed24d26ffa50ab88d3f130ed428a1d0 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Mon, 27 Mar 2023 14:07:24 +0100 Subject: [PATCH 5/6] Address PR feedback Signed-off-by: Lukas Sommer --- 
.../Dialect/Polygeist/Transforms/Passes.h | 2 +- .../PolygeistToLLVM/PolygeistToLLVM.cpp | 11 +++--- .../Polygeist/Transforms/BareMemRefToLLVM.cpp | 8 ++-- .../test/polygeist-opt/bareptrlowering.mlir | 38 +++++++++---------- polygeist/test/polygeist-opt/sycl/cast.mlir | 6 +-- .../test/polygeist-opt/sycl/subindex.mlir | 2 +- 6 files changed, 32 insertions(+), 35 deletions(-) diff --git a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h index bdec6b31a485a..f53b01231a977 100644 --- a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h +++ b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h @@ -27,7 +27,7 @@ namespace polygeist { /// convention. void populateBareMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool useOpaquePointer = false); + bool useOpaquePointers = false); #define GEN_PASS_DECL #include "mlir/Dialect/Polygeist/Transforms/Passes.h.inc" diff --git a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp index ae2a94e2bb3aa..81489044ff7c0 100644 --- a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp +++ b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp @@ -462,7 +462,7 @@ struct BaseSubIndexOpLowering : public ConvertOpToLLVMPattern { .Case( [&](auto t) { currType = t.getElementType(); }) .Case( - [&](auto t) { assert(false && "Pointer type not allowed"); }) + [&](auto t) { llvm_unreachable("Pointer type not allowed"); }) .Default([&](Type t) { currType = t; assert(currType == resElemType && @@ -532,7 +532,7 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { // Handle the general (non-SYCL) case first. 
if (convViewElemType == - transformed.getSource().getType().cast().getElementType()) { + cast(transformed.getSource().getType()).getElementType()) { auto memRefDesc = createMemRefDescriptor( loc, viewMemRefType, targetMemRef.allocatedPtr(rewriter, loc), rewriter.create(loc, prev.getType(), convViewElemType, @@ -663,10 +663,9 @@ struct Memref2PointerOpLowering Value baseOffset = targetMemRef.offset(rewriter, loc); Value ptr = targetMemRef.alignedPtr(rewriter, loc); Value idxs[] = {baseOffset}; - ptr = rewriter.create( - loc, ptr.getType(), - transformed.getSource().getType().cast().getElementType(), - ptr, idxs); + auto elemType = getTypeConverter()->convertType( + op.getSource().getType().getElementType()); + ptr = rewriter.create(loc, ptr.getType(), elemType, ptr, idxs); assert(ptr.getType().cast().getAddressSpace() == op.getType().getAddressSpace() && "Expecting Memref2PointerOp source and result types to have the " diff --git a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp index 79ffc85a3518b..56614b3fbcd1d 100644 --- a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp +++ b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp @@ -633,10 +633,10 @@ struct StoreMemRefOpLoweringOld : public MemAccessLowering { void mlir::polygeist::populateBareMemRefToLLVMConversionPatterns( mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, - bool useOpaquePointer) { + bool useOpaquePointers) { assert(converter.getOptions().useBarePtrCallConv && "Expecting \"bare pointer\" calling convention"); - if (useOpaquePointer) { + if (useOpaquePointers) { patterns.add Optional { + [&, useOpaquePointers](MemRefType type) -> Optional { if (!canBeLoweredToBarePtr(type)) return std::nullopt; - if (useOpaquePointer) { + if (useOpaquePointers) { return LLVM::LLVMPointerType::get(type.getContext(), type.getMemorySpaceAsInt()); } diff --git 
a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir index cc832b3289a2c..b5512dae30c13 100644 --- a/polygeist/test/polygeist-opt/bareptrlowering.mlir +++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir @@ -59,7 +59,7 @@ memref.global @global : memref<3xi64> // CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr -// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, i64 // CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr // CHECK-NEXT: } @@ -91,8 +91,7 @@ memref.global "private" constant @shape : memref<2xi64> = dense<[2, 2]> // CHECK-LABEL: llvm.func @reshape( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> { @@ -107,8 +106,7 @@ memref.global "private" constant @shape : memref<1xindex> // CHECK-LABEL: llvm.func @reshape_dyn( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr -// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr) -> !llvm.ptr -// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr // CHECK-NEXT: } func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { @@ -122,7 +120,7 @@ func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { // CHECK-LABEL: llvm.func @alloca() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: 
%[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -138,7 +136,7 @@ func.func private @alloca() { // CHECK-LABEL: llvm.func @alloca_nd() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -154,7 +152,7 @@ func.func private @alloca_nd() { // CHECK-LABEL: llvm.func @alloca_aligned() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -170,7 +168,7 @@ func.func private @alloca_aligned() { // CHECK-LABEL: llvm.func @alloca_nd_aligned() // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, 
i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -189,7 +187,7 @@ func.func private @alloca_nd_aligned() { // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -210,7 +208,7 @@ func.func private @alloc() { // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr // CHECK-NEXT: llvm.return @@ -227,7 +225,7 @@ func.func private @alloc_nd() { // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) 
-> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_4]], %[[VAL_5]] : i64 @@ -257,7 +255,7 @@ func.func private @alloc_aligned() { // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr -// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32 // CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 // CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(8 : index) : i64 // CHECK-NEXT: %[[VAL_10:.*]] = llvm.add %[[VAL_8]], %[[VAL_9]] : i64 @@ -307,7 +305,7 @@ func.func private @cast(%arg0: memref<2xi32>) -> memref { // CHECK-LABEL: llvm.func @load( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32 -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_3]] : f32 // CHECK-NEXT: } @@ -326,7 +324,7 @@ func.func private @load(%arg0: memref<100xf32>, %index: index) -> f32 { // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // 
CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_7]] : f32 // CHECK-NEXT: } @@ -345,7 +343,7 @@ func.func private @load_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_7]] : f32 // CHECK-NEXT: } @@ -434,7 +432,7 @@ func.func private @call(%arg0: memref, %arg1: index) -> memref { // CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(4 : i64) : i64 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mul %[[VAL_1]], %[[VAL_2]] : i64 -// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr // CHECK-NEXT: } @@ -448,7 +446,7 @@ func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> me // CHECK-LABEL: llvm.func @subindexop_memref_same_dim( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, // CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr -// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr // CHECK-NEXT: } @@ -479,7 +477,7 @@ func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>> // CHECK-SAME: 
%[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, f32 // CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr // CHECK-NEXT: } @@ -539,7 +537,7 @@ func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { // CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64 -// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr // CHECK-NEXT: llvm.return %[[VAL_6]] : i64 // CHECK-NEXT: } diff --git a/polygeist/test/polygeist-opt/sycl/cast.mlir b/polygeist/test/polygeist-opt/sycl/cast.mlir index b355119ba8f43..7bd28e4428911 100644 --- a/polygeist/test/polygeist-opt/sycl/cast.mlir +++ b/polygeist/test/polygeist-opt/sycl/cast.mlir @@ -10,7 +10,7 @@ // CHECK: } func.func @test1(%arg0: memref) -> memref { - %0 = "sycl.cast"(%arg0) : (memref) -> memref + %0 = sycl.cast %arg0 : memref to memref func.return %0 : memref } @@ -25,7 +25,7 @@ func.func @test1(%arg0: memref) -> memref { !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> !sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)> func.func @test2(%arg0: memref) -> memref { - %0 = "sycl.cast"(%arg0) : (memref) -> memref + %0 = sycl.cast %arg0 : memref to memref func.return %0: memref } @@ -40,6 +40,6 @@ func.func 
@test2(%arg0: memref) -> memref { !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> !sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)> func.func @test_addrspaces(%arg0: memref) -> memref { - %0 = "sycl.cast"(%arg0) : (memref) -> memref + %0 = sycl.cast %arg0 : memref to memref func.return %0: memref } diff --git a/polygeist/test/polygeist-opt/sycl/subindex.mlir b/polygeist/test/polygeist-opt/sycl/subindex.mlir index 730261b9c67dd..a917b5a6e919c 100644 --- a/polygeist/test/polygeist-opt/sycl/subindex.mlir +++ b/polygeist/test/polygeist-opt/sycl/subindex.mlir @@ -46,7 +46,7 @@ func.func @test_3(%arg0: memref>) -> memref { // ----- // CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr, [[A5:%.*]]: i64) -> !llvm.ptr { -// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(struct<"class.sycl::_V1::id.1", {{.*}})> // CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> From f7823e089e664f009c1e57ea70c6b17f3e8539f8 Mon Sep 17 00:00:00 2001 From: Lukas Sommer Date: Tue, 28 Mar 2023 09:06:14 +0100 Subject: [PATCH 6/6] Add typed pointer tests and test for nested pointer Signed-off-by: Lukas Sommer --- .../bareptrlowering-typed-pointer.mlir | 573 ++++++++++++++++++ .../test/polygeist-opt/bareptrlowering.mlir | 16 + .../sycl/cast-typed-pointer.mlir | 45 ++ .../sycl/subindex-typed-pointer.mlir | 68 +++ 4 files changed, 702 insertions(+) create mode 100644 polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir create mode 100644 polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir create mode 100644 polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir diff --git a/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir b/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir new file mode 100644 index 0000000000000..dd98613cb7f18 --- /dev/null +++ 
b/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir @@ -0,0 +1,573 @@ +// RUN: polygeist-opt %s --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr + +func.func private @ptr_ret_static(%arg0: i64) -> memref<4xi64> + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr + +func.func private @ptr_ret_dynamic(%arg0: i64) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr + +func.func private @ptr_ret_nd_static(%arg0: i64) -> memref<4x4xi64> + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr + +func.func private @ptr_ret_nd_dynamic(%arg0: i64) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr, !llvm.ptr) -> !llvm.ptr + +func.func private @ptr_args_and_ret(%arg0: memref<1xi64>, %arg1: memref) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr {llvm.byval = i64}, !llvm.ptr {llvm.byval = i64}) -> !llvm.ptr + +func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval = i64}, + %arg1: memref {llvm.byval = i64}) -> memref + +// ----- + +gpu.module @kernels { + +// CHECK-LABEL: llvm.func @kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64}, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} { +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + + gpu.func @kernel(%arg0: memref<1xi64> {llvm.byval = i64}, + %arg1: memref {llvm.byval = i64}) kernel { + gpu.return + } +} + +// ----- + +// CHECK-LABEL: llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.array<3 x i64> + +memref.global @global : memref<3xi64> + +// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr> +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 
0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @get_global() -> memref<3xi64> { + %0 = memref.get_global @global : memref<3xi64> + return %0 : memref<3xi64> +} + +// ----- + +// CHECK-LABEL: llvm.mlir.global external @global_addrspace() {addr_space = 4 : i32} : !llvm.array<3 x i64> + +memref.global @global_addrspace : memref<3xi64, 4> + +// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global_addrspace : !llvm.ptr, 4> +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr, 4>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @get_global_addrspace() -> memref<3xi64, 4> { + %0 = memref.get_global @global_addrspace : memref<3xi64, 4> + return %0 : memref<3xi64, 4> +} + +// ----- + +memref.global "private" constant @shape : memref<2xi64> = dense<[2, 2]> + +// CHECK-LABEL: llvm.func @reshape( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> { + %shape = memref.get_global @shape : memref<2xi64> + %0 = memref.reshape %arg0(%shape) : (memref<4xi32>, memref<2xi64>) -> memref<2x2xi32> + return %0 : memref<2x2xi32> +} + +// ----- + +memref.global "private" constant @shape : memref<1xindex> + +// CHECK-LABEL: llvm.func @reshape_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { + %shape = memref.get_global @shape : memref<1xindex> + %0 = memref.reshape %arg0(%shape) : 
(memref<4xi32>, memref<1xindex>) -> memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca() { + %0 = memref.alloca() : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_nd() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_nd() { + %0 = memref.alloca() : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_aligned() { + %0 = memref.alloca() {alignment = 8} : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_nd_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// 
CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_nd_aligned() { + %0 = memref.alloca() {alignment = 8} : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr + +// CHECK-LABEL: llvm.func @alloc() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.bitcast %[[VAL_5]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc() { + %0 = memref.alloc() : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_nd() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr 
to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc_nd() { + %0 = memref.alloc() : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_4]], %[[VAL_5]] : i64 +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.ptrtoint %[[VAL_8]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.sub %[[VAL_5]], %[[VAL_10]] : i64 +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.add %[[VAL_9]], %[[VAL_11]] : i64 +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.urem %[[VAL_12]], %[[VAL_5]] : i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_12]], %[[VAL_13]] : i64 +// CHECK-NEXT: %[[VAL_15:.*]] = llvm.inttoptr %[[VAL_14]] : i64 to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc_aligned() { + %0 = memref.alloc() {alignment = 8} : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_nd_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: 
%[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.add %[[VAL_8]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_15:.*]] = llvm.sub %[[VAL_9]], %[[VAL_14]] : i64 +// CHECK-NEXT: %[[VAL_16:.*]] = llvm.add %[[VAL_13]], %[[VAL_15]] : i64 +// CHECK-NEXT: %[[VAL_17:.*]] = llvm.urem %[[VAL_16]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_18:.*]] = llvm.sub %[[VAL_16]], %[[VAL_17]] : i64 +// CHECK-NEXT: %[[VAL_19:.*]] = llvm.inttoptr %[[VAL_18]] : i64 to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc_nd_aligned() { + %0 = memref.alloc() {alignment = 8} : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @dealloc( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.call @free(%[[VAL_1]]) : (!llvm.ptr) -> () +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @dealloc(%arg0: memref) { + memref.dealloc %arg0 : memref + return +} + +// ----- + +// CHECK-LABEL: llvm.func @cast( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @cast(%arg0: 
memref<2xi32>) -> memref { + %0 = memref.cast %arg0 : memref<2xi32> to memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @load( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : f32 +// CHECK-NEXT: } + +func.func private @load(%arg0: memref<100xf32>, %index: index) -> f32 { + %0 = memref.load %arg0[%index] : memref<100xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: llvm.func @load_nd( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, +// CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_7]] : f32 +// CHECK-NEXT: } + +func.func private @load_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: index) -> f32 { + %0 = memref.load %arg0[%index0, %index1] : memref<100x100xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: llvm.func @load_nd_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, +// CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr +// 
CHECK-NEXT: llvm.return %[[VAL_7]] : f32 +// CHECK-NEXT: } + +func.func private @load_nd_dyn(%arg0: memref, %index0: index, %index1: index) -> f32 { + %0 = memref.load %arg0[%index0, %index1] : memref + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: llvm.func @store( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, +// CHECK-SAME: %[[VAL_2:.*]]: f32) +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @store(%arg0: memref<100xf32>, %index: index, %val: f32) { + memref.store %val, %arg0[%index] : memref<100xf32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @store_nd( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, +// CHECK-SAME: %[[VAL_3:.*]]: f32) +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @store_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: index, %val: f32) { + memref.store %val, %arg0[%index0, %index1] : memref<100x100xf32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @store_nd_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64, +// CHECK-SAME: %[[VAL_3:.*]]: f32) +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr 
%[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @store_nd_dyn(%arg0: memref, %index0: index, %index1: index, %val: f32) { + memref.store %val, %arg0[%index0, %index1] : memref + return +} + +// ----- + +// CHECK-LABEL: llvm.func @impl(!llvm.ptr, i64) -> !llvm.ptr + +func.func private @impl(%arg0: memref, %arg1: index) -> memref + +// CHECK-LABEL: llvm.func @call( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @call(%arg0: memref, %arg1: index) -> memref { + %res = func.call @impl(%arg0, %arg1) : (memref, index) -> memref + return %res : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(4 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mul %[[VAL_1]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4xf32> { + %res = "polygeist.subindex"(%arg0 , %arg1) : (memref<4x4xf32>, index) -> memref<4xf32> + return %res : memref<4xf32> +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref_same_dim( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private 
@subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4x4xf32> { + %res = "polygeist.subindex"(%arg0 , %arg1) : (memref<4x4xf32>, index) -> memref<4x4xf32> + return %res : memref<4x4xf32> +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref_struct( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr>, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>>) -> memref { + %c_0 = arith.constant 0 : index + %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(f32)>>, index) -> memref + return %res : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>>) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr)>>, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct<(struct<(f32)>)>>) -> memref { + %c_0 = arith.constant 0 : index + %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(struct<(f32)>)>>, index) -> memref + return %res : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_ptr( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>)>>) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 
%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr>)>>, i64, i64, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.struct<(ptr>)>>) -> memref { + %c_0 = arith.constant 0 : index + %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(ptr>)>>, index) -> memref + return %res : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr>)>>) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr>)>>, i64, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>) -> memref { + %c_0 = arith.constant 0 : index + %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>, index) -> memref + return %res : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @memref2ptr( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr { + %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr + return %res : !llvm.ptr +} + +// ----- + +// CHECK-LABEL: llvm.func @ptr2memref( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref { + %res = "polygeist.pointer2memref"(%arg0) : 
(!llvm.ptr) -> memref + return %res : memref +} + +// ----- + +#layout = affine_map<(s0) -> (s0 - 1)> + +// CHECK-LABEL: llvm.func @non_bare_due_to_layout( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_6]] : i64 +// CHECK-NEXT: } + +func.func private @non_bare_due_to_layout(%arg0: memref<100xi64, #layout>, %arg1: index) -> i64 { + %res = memref.load %arg0[%arg1] : memref<100xi64, #layout> + return %res : i64 +} diff --git a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir index b5512dae30c13..327c2dc75c842 100644 --- a/polygeist/test/polygeist-opt/bareptrlowering.mlir +++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir @@ -489,6 +489,22 @@ func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct< // ----- +// CHECK-LABEL: llvm.func @subindexop_memref_nested_ptr( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @subindexop_memref_nested_ptr(%arg0: memref<4x!llvm.struct<(ptr)>>) -> memref { + %c_0 = arith.constant 0 : index + %res = "polygeist.subindex"(%arg0, %c_0) : 
(memref<4x!llvm.struct<(ptr)>>, index) -> memref + return %res : memref +} + +// ----- + // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array( // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64 diff --git a/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir b/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir new file mode 100644 index 0000000000000..07129e4873b59 --- /dev/null +++ b/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir @@ -0,0 +1,45 @@ +// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file %s | FileCheck %s + +!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> +!sycl_range_1_ = !sycl.range<[1], (!sycl_array_1_)> + +// CHECK-LABEL: llvm.func @test1( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>>) -> !llvm.ptr)>> { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>> to !llvm.ptr)>> +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>> +// CHECK: } + +func.func @test1(%arg0: memref) -> memref { + %0 = "sycl.cast"(%arg0) : (memref) -> memref + func.return %0 : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @test2( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>>) -> !llvm.ptr)>> { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>> to !llvm.ptr)>> +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>> +// CHECK: } + +!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> +!sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)> +func.func @test2(%arg0: memref) -> memref { + %0 = "sycl.cast"(%arg0) : (memref) -> memref + func.return %0: memref +} + +// ----- + +// CHECK-LABEL: llvm.func @test_addrspaces( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)>)>, 4>) -> !llvm.ptr)>, 4> { +// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr)>)>, 4> to !llvm.ptr)>, 4> +// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr)>, 4> +// CHECK: } + +!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)> +!sycl_id_1_ = !sycl.id<[1], 
(!sycl_array_1_)> +func.func @test_addrspaces(%arg0: memref) -> memref { + %0 = "sycl.cast"(%arg0) : (memref) -> memref + func.return %0: memref +} diff --git a/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir b/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir new file mode 100644 index 0000000000000..98c3c057a4eaa --- /dev/null +++ b/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir @@ -0,0 +1,68 @@ +// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file %s | FileCheck %s + +// CHECK-LABEL: @test_1 +// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr !llvm.ptr<[[SYCLIDSTRUCT]], {{.*}} +// CHECK-NEXT: llvm.return [[GEP]] + +!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> +func.func @test_1(%arg0: memref>) -> memref { + %c0 = arith.constant 0 : index + %0 = "polygeist.subindex"(%arg0, %c0) : (memref>, index) -> memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: @test_2 +// CHECK: llvm.return %{{.*}} : !llvm.ptr)>)> +!sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> +!sycl_accessor_impl_device_1_ = !sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)> +!sycl_accessor_1_ = !sycl.accessor<[1, i32, read_write, global_buffer], (!sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>, !llvm.struct<(ptr)>)> + +func.func @test_2(%arg0: memref) -> memref { + %c0 = arith.constant 0 : index + %0 = "polygeist.subindex"(%arg0, %c0) : (memref, index) -> memref + return %0 : memref +} + +// ----- + +// CHECK: llvm.func @test_3([[A0:.*]]: !llvm.ptr>) -> !llvm.ptr { +// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], 0] : (!llvm.ptr>, i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr + +func.func @test_3(%arg0: memref>) -> memref 
{ + %c0 = arith.constant 0 : index + %0 = "polygeist.subindex"(%arg0, %c0) : (memref>, index) -> memref + return %0 : memref +} + +// ----- + +// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr\)>\)>]])>>, [[A5:%.*]]: i64) -> !llvm.ptr)>)>)>> { +// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr>, i64) -> !llvm.ptr> +// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr> + +!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> +func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: index) -> memref> { + %0 = "polygeist.subindex"(%arg0, %arg1) : (memref<1x!llvm.struct<(!sycl_id_1_)>>, index) -> memref> + return %0 : memref> +} + +// ----- + +// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<[[ARRTYPE:struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>]], 4>) -> !llvm.ptr { +// CHECK-DAG: [[ZERO1:%.*]] = llvm.mlir.constant(0 : index) : i64 +// CHECK-DAG: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], 0, [[ZERO1]]] : (!llvm.ptr<[[ARRTYPE]], 4>, i64, i64) -> !llvm.ptr + +!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)> +func.func @test_5(%arg0: memref)>, 4>) -> memref<1xi64, 4> { + %c0 = arith.constant 0 : index + %0 = "polygeist.subindex"(%arg0, %c0) : (memref)>, 4>, index) -> memref<1xi64, 4> + return %0 : memref<1xi64, 4> +}