diff --git a/polygeist/include/mlir/Conversion/PolygeistPasses.td b/polygeist/include/mlir/Conversion/PolygeistPasses.td index 74c1d9cc74419..2a21de4edfd1a 100644 --- a/polygeist/include/mlir/Conversion/PolygeistPasses.td +++ b/polygeist/include/mlir/Conversion/PolygeistPasses.td @@ -24,7 +24,10 @@ def ConvertPolygeistToLLVM : Pass<"convert-polygeist-to-llvm", "mlir::ModuleOp"> Option<"dataLayout", "data-layout", "std::string", /*default=*/"\"\"", "String description (LLVM format) of the data layout that is " - "expected on the produced module"> + "expected on the produced module">, + Option<"useOpaquePointers", "use-opaque-pointers", "bool", + /*default=*/"false", "Generate LLVM IR using opaque pointers " + "instead of typed pointers">, ]; } diff --git a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h index bc4db810755dd..f53b01231a977 100644 --- a/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h +++ b/polygeist/include/mlir/Dialect/Polygeist/Transforms/Passes.h @@ -26,7 +26,8 @@ namespace polygeist { /// MemRef dialect to the LLVM dialect forcing a "bare pointer" calling /// convention. void populateBareMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns); + RewritePatternSet &patterns, + bool useOpaquePointers = false); #define GEN_PASS_DECL #include "mlir/Dialect/Polygeist/Transforms/Passes.h.inc" diff --git a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp index 5d51bc6de6599..81489044ff7c0 100644 --- a/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp +++ b/polygeist/lib/Conversion/PolygeistToLLVM/PolygeistToLLVM.cpp @@ -54,6 +54,362 @@ namespace mlir { #undef GEN_PASS_DEF_CONVERTPOLYGEISTTOLLVM } // namespace mlir +// FIXME: All the following patterns with an "Old" suffix to their name should +// be removed once we drop typed pointer support. +struct BaseSubIndexOpLoweringOld : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + +protected: + // Compute the indices of the GEP operation we lower the SubIndexOp to. + // The indices are computed based on: + // a) the (converted) source element type, and + // b) the (converted) result element type that is requested + // Examples: + // - src ty: ptr>>, res ty: ptr + // -> idxs = [0, 0, SubIndexOp's index] + // - src ty: ptr>>, res ty: ptr> + // -> idxs = [0, SubIndexOp's index] + // + // Note: when the source element type is a struct with more than one member + // type, the result type that is requested is deemed illegal unless it is one + // of the source member types. 
For example assume: + // - src ty: ptr,i32>> + // - res ty: ptr + // This is illegal because res ty can only be either ptr or + // ptr> + static void computeIndices(const LLVM::LLVMStructType &srcElemType, + const Type &resElemType, + SmallVectorImpl &indices, SubIndexOp op, + OpAdaptor transformed, + ConversionPatternRewriter &rewriter) { + assert(indices.empty() && "Expecting an empty vector"); + + ArrayRef memTypes = srcElemType.getBody(); + unsigned numMembers = memTypes.size(); + assert((numMembers == 1 || + any_of(memTypes, [=](Type t) { return resElemType == t; })) && + "The requested result memref element type is illegal"); + + Type indexType = transformed.getIndex().getType(); + Value zero = rewriter.create( + op.getLoc(), indexType, rewriter.getIntegerAttr(indexType, 0)); + indices.push_back(zero); + + if (numMembers == 1) { + Type currType = srcElemType.getBody()[0]; + while (currType != resElemType) { + indices.push_back(zero); + + TypeSwitch(currType) + .Case([&](LLVM::LLVMStructType t) { + assert(t.getBody().size() == 1 && "Expecting single member type"); + currType = t.getBody()[0]; + }) + .Case( + [&](auto t) { currType = t.getElementType(); }) + .Default([&](Type t) { + currType = t; + assert(currType == resElemType && + "requested result type is illegal"); + }); + } + } + + indices.push_back(transformed.getIndex()); + } +}; + +/// Conversion pattern that transforms a subview op into: +/// 1. An `llvm.mlir.undef` operation to create a memref descriptor +/// 2. Updates to the descriptor to introduce the data ptr, offset, size +/// and stride. +/// The subview op is replaced by the descriptor. +struct SubIndexOpLoweringOld : public BaseSubIndexOpLoweringOld { + using BaseSubIndexOpLoweringOld::BaseSubIndexOpLoweringOld; + + LogicalResult + matchAndRewrite(SubIndexOp subViewOp, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + assert(subViewOp.getSource().getType().isa() && + "Source operand should be a memref type"); + assert(subViewOp.getType().isa() && + "Result should be a memref type"); + + auto sourceMemRefType = subViewOp.getSource().getType().cast(); + auto viewMemRefType = subViewOp.getType().cast(); + + auto loc = subViewOp.getLoc(); + MemRefDescriptor targetMemRef(transformed.getSource()); + Value prev = targetMemRef.alignedPtr(rewriter, loc); + Value idxs[] = {transformed.getIndex()}; + + SmallVector sizes, strides; + if (sourceMemRefType.getRank() != viewMemRefType.getRank()) { + if (sourceMemRefType.getRank() != viewMemRefType.getRank() + 1) + return failure(); + + size_t sz = 1; + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + if (sourceMemRefType.getShape()[i] == ShapedType::kDynamic) + return failure(); + sz *= sourceMemRefType.getShape()[i]; + } + Value cop = rewriter.create( + loc, idxs[0].getType(), + rewriter.getIntegerAttr(idxs[0].getType(), sz)); + idxs[0] = rewriter.create(loc, idxs[0], cop); + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + sizes.push_back(targetMemRef.size(rewriter, loc, i)); + strides.push_back(targetMemRef.stride(rewriter, loc, i)); + } + } else { + for (int64_t i = 0; i < sourceMemRefType.getRank(); i++) { + sizes.push_back(targetMemRef.size(rewriter, loc, i)); + strides.push_back(targetMemRef.stride(rewriter, loc, i)); + } + } + + Type sourceElemType = sourceMemRefType.getElementType(); + Type convSourceElemType = getTypeConverter()->convertType(sourceElemType); + Type viewElemType = viewMemRefType.getElementType(); + Type convViewElemType = 
getTypeConverter()->convertType(viewElemType); + + // Handle the general (non-SYCL) case first. + if (convViewElemType == + prev.getType().cast().getElementType()) { + auto memRefDesc = createMemRefDescriptor( + loc, viewMemRefType, targetMemRef.allocatedPtr(rewriter, loc), + rewriter.create(loc, prev.getType(), prev, idxs), sizes, + strides, rewriter); + + rewriter.replaceOp(subViewOp, {memRefDesc}); + return success(); + } + assert(convSourceElemType.isa() && + "Expecting struct type"); + + // SYCL case + assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && + "Expecting the input and output MemRef ranks to be the same"); + + SmallVector indices; + computeIndices(convSourceElemType.cast(), + convViewElemType, indices, subViewOp, transformed, rewriter); + assert(!indices.empty() && "Expecting a least one index"); + + // Note: MLIRScanner::InitializeValueByInitListExpr() in clang-mlir.cc, when + // a memref element type is a struct type, the return type of a + // polygeist.subindex operation should be a memref of the element type of + // the struct. + auto elemPtrTy = LLVM::LLVMPointerType::get( + convViewElemType, viewMemRefType.getMemorySpaceAsInt()); + auto gep = rewriter.create(loc, elemPtrTy, prev, indices); + auto memRefDesc = createMemRefDescriptor(loc, viewMemRefType, gep, gep, + sizes, strides, rewriter); + LLVM_DEBUG(llvm::dbgs() << "SubIndexOpLowering: gep: " << *gep << "\n"); + + rewriter.replaceOp(subViewOp, {memRefDesc}); + return success(); + } +}; + +struct SubIndexBarePtrOpLoweringOld : public BaseSubIndexOpLoweringOld { + using BaseSubIndexOpLoweringOld::BaseSubIndexOpLoweringOld; + + LogicalResult + matchAndRewrite(SubIndexOp subViewOp, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + assert(subViewOp.getSource().getType().isa() && + "Source operand should be a memref type"); + assert(subViewOp.getType().isa() && + "Result should be a memref type"); + + auto sourceMemRefType = subViewOp.getSource().getType().cast(); + auto viewMemRefType = subViewOp.getType().cast(); + if (!canBeLoweredToBarePtr(sourceMemRefType) || + !canBeLoweredToBarePtr(viewMemRefType)) + return failure(); + + const auto loc = subViewOp.getLoc(); + const auto target = transformed.getSource(); + auto idx = transformed.getIndex(); + + if (sourceMemRefType.getRank() != viewMemRefType.getRank()) { + if (sourceMemRefType.getRank() != viewMemRefType.getRank() + 1) + return failure(); + + size_t sz = 1; + for (int64_t i = 1; i < sourceMemRefType.getRank(); i++) { + if (sourceMemRefType.getShape()[i] == ShapedType::kDynamic) + return failure(); + sz *= sourceMemRefType.getShape()[i]; + } + Value cop = rewriter.create( + loc, idx.getType(), rewriter.getIntegerAttr(idx.getType(), sz)); + idx = rewriter.create(loc, idx, cop); + } + + Type sourceElemType = sourceMemRefType.getElementType(); + Type convSourceElemType = getTypeConverter()->convertType(sourceElemType); + if (!convSourceElemType) + return failure(); + Type viewElemType = viewMemRefType.getElementType(); + Type convViewElemType = getTypeConverter()->convertType(viewElemType); + Type resType = getTypeConverter()->convertType(subViewOp.getType()); + + // Handle the general (non-SYCL) case first. 
+ if (convViewElemType == + target.getType().cast().getElementType()) { + rewriter.replaceOpWithNewOp(subViewOp, resType, target, idx); + return success(); + } + assert(convSourceElemType.isa() && + "Expecting struct type"); + + // SYCL case + assert(sourceMemRefType.getRank() == viewMemRefType.getRank() && + "Expecting the input and output MemRef ranks to be the same"); + + SmallVector indices; + computeIndices(convSourceElemType.cast(), + convViewElemType, indices, subViewOp, transformed, rewriter); + assert(!indices.empty() && "Expecting a least one index"); + + // Note: MLIRScanner::InitializeValueByInitListExpr() in clang-mlir.cc, when + // a memref element type is a struct type, the return type of a + // polygeist.subindex operation should be a memref of the element type of + // the struct. + + rewriter.replaceOpWithNewOp(subViewOp, resType, target, + indices); + + return success(); + } +}; + +struct Memref2PointerOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Memref2PointerOp op, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // MemRefDescriptor sourceMemRef(operands.front()); + MemRefDescriptor targetMemRef( + transformed.getSource()); // MemRefDescriptor::undef(rewriter, loc, + // targetDescTy); + + // Offset. + Value baseOffset = targetMemRef.offset(rewriter, loc); + Value ptr = targetMemRef.alignedPtr(rewriter, loc); + Value idxs[] = {baseOffset}; + ptr = rewriter.create(loc, ptr.getType(), ptr, idxs); + assert(ptr.getType().cast().getAddressSpace() == + op.getType().getAddressSpace() && + "Expecting Memref2PointerOp source and result types to have the " + "same address space"); + ptr = rewriter.create(loc, op.getType(), ptr); + + rewriter.replaceOp(op, {ptr}); + return success(); + } +}; + +struct Pointer2MemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Pointer2MemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + // MemRefDescriptor sourceMemRef(operands.front()); + auto convertedType = getTypeConverter()->convertType(op.getType()); + assert(convertedType && "unexpected failure in memref type conversion"); + auto descr = MemRefDescriptor::undef(rewriter, loc, convertedType); + assert(adaptor.getSource() + .getType() + .cast() + .getAddressSpace() == + op.getType().cast().getMemorySpaceAsInt() && + "Expecting Pointer2MemrefOp source and result types to have the " + "same address space"); + auto ptr = rewriter.create( + op.getLoc(), descr.getElementPtrType(), adaptor.getSource()); + + // Extract all strides and offsets and verify they are static. 
+ int64_t offset; + SmallVector strides; + auto result = getStridesAndOffset(op.getType(), strides, offset); + (void)result; + assert(succeeded(result) && "unexpected failure in stride computation"); + assert(offset != ShapedType::kDynamic && "expected static offset"); + + bool first = true; + assert(!llvm::any_of(strides, [&](int64_t stride) { + if (first) { + first = false; + return false; + } + return stride == ShapedType::kDynamic; + }) && "expected static strides except first element"); + + descr.setAllocatedPtr(rewriter, loc, ptr); + descr.setAlignedPtr(rewriter, loc, ptr); + descr.setConstantOffset(rewriter, loc, offset); + + // Fill in sizes and strides + for (unsigned i = 0, e = op.getType().getRank(); i != e; ++i) { + descr.setConstantSize(rewriter, loc, i, op.getType().getDimSize(i)); + descr.setConstantStride(rewriter, loc, i, strides[i]); + } + + rewriter.replaceOp(op, {descr}); + return success(); + } +}; + +/// Lowers to a bitcast operation +struct BareMemref2PointerOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Memref2PointerOp op, OpAdaptor transformed, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(op.getSource().getType())) + return failure(); + + const auto target = transformed.getSource(); + rewriter.replaceOpWithNewOp(op, op.getType(), target); + + return success(); + } +}; + +/// Lowers to a bitcast operation +struct BarePointer2MemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Pointer2MemrefOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(op.getType())) + return failure(); + + const auto convertedType = getTypeConverter()->convertType(op.getType()); + if (!convertedType) + return failure(); + rewriter.replaceOpWithNewOp(op, convertedType, + adaptor.getSource()); + return success(); + } +}; + struct BaseSubIndexOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -103,8 +459,10 @@ struct BaseSubIndexOpLowering : public ConvertOpToLLVMPattern { assert(t.getBody().size() == 1 && "Expecting single member type"); currType = t.getBody()[0]; }) - .Case( + .Case( [&](auto t) { currType = t.getElementType(); }) + .Case( + [&](auto t) { llvm_unreachable("Pointer type not allowed"); }) .Default([&](Type t) { currType = t; assert(currType == resElemType && @@ -174,11 +532,12 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { // Handle the general (non-SYCL) case first. if (convViewElemType == - prev.getType().cast().getElementType()) { + cast(transformed.getSource().getType()).getElementType()) { auto memRefDesc = createMemRefDescriptor( loc, viewMemRefType, targetMemRef.allocatedPtr(rewriter, loc), - rewriter.create(loc, prev.getType(), prev, idxs), sizes, - strides, rewriter); + rewriter.create(loc, prev.getType(), convViewElemType, + prev, idxs), + sizes, strides, rewriter); rewriter.replaceOp(subViewOp, {memRefDesc}); return success(); @@ -200,8 +559,9 @@ struct SubIndexOpLowering : public BaseSubIndexOpLowering { // polygeist.subindex operation should be a memref of the element type of // the struct. 
auto elemPtrTy = LLVM::LLVMPointerType::get( - convViewElemType, viewMemRefType.getMemorySpaceAsInt()); - auto gep = rewriter.create(loc, elemPtrTy, prev, indices); + convViewElemType.getContext(), viewMemRefType.getMemorySpaceAsInt()); + auto gep = rewriter.create(loc, elemPtrTy, convViewElemType, + prev, indices); auto memRefDesc = createMemRefDescriptor(loc, viewMemRefType, gep, gep, sizes, strides, rewriter); LLVM_DEBUG(llvm::dbgs() << "SubIndexOpLowering: gep: " << *gep << "\n"); @@ -256,9 +616,9 @@ struct SubIndexBarePtrOpLowering : public BaseSubIndexOpLowering { Type resType = getTypeConverter()->convertType(subViewOp.getType()); // Handle the general (non-SYCL) case first. - if (convViewElemType == - target.getType().cast().getElementType()) { - rewriter.replaceOpWithNewOp(subViewOp, resType, target, idx); + if (convViewElemType == convSourceElemType) { + rewriter.replaceOpWithNewOp(subViewOp, resType, + convViewElemType, target, idx); return success(); } assert(convSourceElemType.isa() && @@ -278,8 +638,8 @@ struct SubIndexBarePtrOpLowering : public BaseSubIndexOpLowering { // polygeist.subindex operation should be a memref of the element type of // the struct. - rewriter.replaceOpWithNewOp(subViewOp, resType, target, - indices); + rewriter.replaceOpWithNewOp(subViewOp, resType, + convViewElemType, target, indices); return success(); } @@ -303,12 +663,13 @@ struct Memref2PointerOpLowering Value baseOffset = targetMemRef.offset(rewriter, loc); Value ptr = targetMemRef.alignedPtr(rewriter, loc); Value idxs[] = {baseOffset}; - ptr = rewriter.create(loc, ptr.getType(), ptr, idxs); + auto elemType = getTypeConverter()->convertType( + op.getSource().getType().getElementType()); + ptr = rewriter.create(loc, ptr.getType(), elemType, ptr, idxs); assert(ptr.getType().cast().getAddressSpace() == op.getType().getAddressSpace() && "Expecting Memref2PointerOp source and result types to have the " "same address space"); - ptr = rewriter.create(loc, op.getType(), ptr); rewriter.replaceOp(op, {ptr}); return success(); @@ -335,8 +696,7 @@ struct Pointer2MemrefOpLowering op.getType().cast().getMemorySpaceAsInt() && "Expecting Pointer2MemrefOp source and result types to have the " "same address space"); - auto ptr = rewriter.create( - op.getLoc(), descr.getElementPtrType(), adaptor.getSource()); + auto ptr = adaptor.getSource(); // Extract all strides and offsets and verify they are static. int64_t offset; @@ -396,7 +756,9 @@ struct BareMemref2PointerOpLowering return failure(); const auto target = transformed.getSource(); - rewriter.replaceOpWithNewOp(op, op.getType(), target); + // In an opaque pointer world, a bitcast is a no-op, so no need to insert + // one here. + rewriter.replaceOp(op, target); return success(); } @@ -416,8 +778,9 @@ struct BarePointer2MemrefOpLowering const auto convertedType = getTypeConverter()->convertType(op.getType()); if (!convertedType) return failure(); - rewriter.replaceOpWithNewOp(op, convertedType, - adaptor.getSource()); + // In an opaque pointer world, a bitcast is a no-op, so no need to insert + // one here. 
+ rewriter.replaceOp(op, adaptor.getSource()); return success(); } }; @@ -487,16 +850,32 @@ void populatePolygeistToLLVMConversionPatterns(LLVMTypeConverter &converter, assert(converter.getOptions().useBarePtrCallConv && "These patterns only work with bare pointer calling convention"); - patterns.add(converter); - // When adding these patterns (and other patterns changing the default - // conversion of operations on MemRef values), a higher benefit is passed - // (2), so that these patterns have a higher priority than the ones - // performing the default conversion, which should only run if the "bare - // pointer" ones fail. - patterns.add(converter, - /*benefit*/ 2); + if (converter.useOpaquePointers()) { + patterns.add(converter); + // When adding these patterns (and other patterns changing the + // default conversion of operations on MemRef values), a higher + // benefit is passed (2), so that these patterns have a higher + // priority than the ones performing the default conversion, which + // should only run if the "bare pointer" ones fail. + patterns.add(converter, + /*benefit*/ 2); + } else { + // FIXME: This 'else'-part should be removed completely when we drop typed + // pointer support. + patterns.add( + converter); + // When adding these patterns (and other patterns changing the + // default conversion of operations on MemRef values), a higher + // benefit is passed (2), so that these patterns have a higher + // priority than the ones performing the default conversion, which + // should only run if the "bare pointer" ones fail. + patterns.add(converter, + /*benefit*/ 2); + } } namespace { @@ -565,8 +944,11 @@ struct URLLVMOpLowering } }; +// FIXME: The following function and pattern with the "Old" suffix should be +// removed once we drop typed pointer support. 
+ // TODO lock this wrt module -static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { +static LLVM::LLVMFuncOp addMocCUDAFunctionOld(ModuleOp module, Type streamTy) { const char fname[] = "fake_cuda_dispatch"; MLIRContext *ctx = module.getContext(); @@ -592,7 +974,7 @@ static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { return resumeOp; } -struct AsyncOpLowering : public ConvertOpToLLVMPattern { +struct AsyncOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -775,6 +1157,214 @@ struct AsyncOpLowering : public ConvertOpToLLVMPattern { } assert(vals.size() == 3); + auto f = addMocCUDAFunctionOld(execute->getParentOfType(), + vals.back().getType()); + + rewriter.create(execute.getLoc(), f, vals); + rewriter.eraseOp(execute); + } + + return success(); + } +}; + +// TODO lock this wrt module +static LLVM::LLVMFuncOp addMocCUDAFunction(ModuleOp module, Type streamTy) { + const char fname[] = "fake_cuda_dispatch"; + + MLIRContext *ctx = module.getContext(); + auto loc = module.getLoc(); + auto moduleBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + for (auto fn : module.getBody()->getOps()) { + if (fn.getName() == fname) + return fn; + } + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + auto i8Ptr = LLVM::LLVMPointerType::get(ctx); + + auto resumeOp = moduleBuilder.create( + fname, LLVM::LLVMFunctionType::get( + voidTy, {i8Ptr, LLVM::LLVMPointerType::get(ctx), streamTy})); + resumeOp.setPrivate(); + + return resumeOp; +} + +struct AsyncOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(async::ExecuteOp execute, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ModuleOp module = execute->getParentOfType(); + + MLIRContext *ctx = module.getContext(); + Location loc = execute.getLoc(); + + auto voidTy = LLVM::LLVMVoidType::get(ctx); + Type voidPtr = LLVM::LLVMPointerType::get(ctx); + + // Make sure that all constants will be inside the outlined async function + // to reduce the number of function arguments. + Region &funcReg = execute.getRegion(); + + // Collect all outlined function inputs. + SetVector functionInputs; + + getUsedValuesDefinedAbove(execute.getRegion(), funcReg, functionInputs); + SmallVector toErase; + for (auto a : functionInputs) { + Operation *op = a.getDefiningOp(); + if (op && op->hasTrait()) + toErase.push_back(a); + } + for (auto a : toErase) { + functionInputs.remove(a); + } + + // Collect types for the outlined function inputs and outputs. + TypeConverter *converter = getTypeConverter(); + auto typesRange = llvm::map_range(functionInputs, [&](Value value) { + return converter->convertType(value.getType()); + }); + SmallVector inputTypes(typesRange.begin(), typesRange.end()); + + Type ftypes[] = {voidPtr}; + auto funcType = LLVM::LLVMFunctionType::get(voidTy, ftypes); + + // TODO: Derive outlined function name from the parent FuncOp (support + // multiple nested async.execute operations). + auto moduleBuilder = + ImplicitLocOpBuilder::atBlockEnd(loc, module.getBody()); + + static int off = 0; + off++; + auto func = moduleBuilder.create( + execute.getLoc(), + "kernelbody." + std::to_string((long long int)&execute) + "." 
+ + std::to_string(off), + funcType); + + rewriter.setInsertionPointToStart(func.addEntryBlock()); + IRMapping valueMapping; + for (Value capture : toErase) { + Operation *op = capture.getDefiningOp(); + for (auto r : + llvm::zip(op->getResults(), + rewriter.clone(*op, valueMapping)->getResults())) { + valueMapping.map(rewriter.getRemappedValue(std::get<0>(r)), + std::get<1>(r)); + } + } + // Prepare for coroutine conversion by creating the body of the function. + { + // Map from function inputs defined above the execute op to the function + // arguments. + auto arg = func.getArgument(0); + + if (functionInputs.size() == 0) { + } else if (functionInputs.size() == 1 && + converter->convertType(functionInputs[0].getType()) + .isa()) { + valueMapping.map(functionInputs[0], arg); + } else if (functionInputs.size() == 1 && + converter->convertType(functionInputs[0].getType()) + .isa()) { + valueMapping.map( + functionInputs[0], + rewriter.create( + execute.getLoc(), + converter->convertType(functionInputs[0].getType()), arg)); + } else { + SmallVector types; + for (auto v : functionInputs) + types.push_back(converter->convertType(v.getType())); + + for (auto idx : llvm::enumerate(functionInputs)) { + + mlir::Value idxs[] = { + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), + }; + auto nextTy = types[idx.index()]; + Value next = rewriter.create( + loc, LLVM::LLVMPointerType::get(ctx), nextTy, arg, idxs); + valueMapping.map(idx.value(), + rewriter.create(loc, nextTy, next)); + } + auto freef = getFreeFn(*getTypeConverter(), module); + Value args[] = {arg}; + rewriter.create(loc, freef, args); + } + + // Clone all operations from the execute operation body into the outlined + // function body. + for (Operation &op : execute.getBody()->without_terminator()) + rewriter.clone(op, valueMapping); + + rewriter.create(execute.getLoc(), ValueRange()); + } + + // Replace the original `async.execute` with a call to outlined function. 
+ { + rewriter.setInsertionPoint(execute); + SmallVector crossing; + for (auto tup : llvm::zip(functionInputs, inputTypes)) { + Value val = std::get<0>(tup); + crossing.push_back(val); + } + + SmallVector vals; + if (crossing.size() == 0) { + vals.push_back( + rewriter.create(execute.getLoc(), voidPtr)); + } else if (crossing.size() == 1 && + converter->convertType(crossing[0].getType()) + .isa()) { + vals.push_back(crossing[0]); + } else if (crossing.size() == 1 && + converter->convertType(crossing[0].getType()) + .isa()) { + vals.push_back(rewriter.create(execute.getLoc(), + voidPtr, crossing[0])); + } else { + SmallVector types; + for (auto v : crossing) + types.push_back(v.getType()); + auto ST = LLVM::LLVMStructType::getLiteral(ctx, types); + + auto mallocf = getAllocFn(*getTypeConverter(), module, getIndexType()); + + Value args[] = {rewriter.create( + loc, rewriter.getI64Type(), + rewriter.create(loc, rewriter.getIndexType(), + ST))}; + mlir::Value alloc = + rewriter.create(loc, mallocf, args).getResult(); + rewriter.setInsertionPoint(execute); + for (auto idx : llvm::enumerate(crossing)) { + + mlir::Value idxs[] = { + rewriter.create(loc, 0, 32), + rewriter.create(loc, idx.index(), 32), + }; + Value next = + rewriter.create(loc, LLVM::LLVMPointerType::get(ctx), + idx.value().getType(), alloc, idxs); + rewriter.create(loc, idx.value(), next); + } + vals.push_back(alloc); + } + vals.push_back( + rewriter.create(execute.getLoc(), func)); + for (auto dep : execute.getDependencies()) { + auto ctx = dep.getDefiningOp(); + vals.push_back(ctx.getSource()); + } + assert(vals.size() == 3); + auto f = addMocCUDAFunction(execute->getParentOfType(), vals.back().getType()); @@ -922,6 +1512,7 @@ struct ConvertPolygeistToLLVMPass LowerToLLVMOptions options(&getContext(), dataLayoutAnalysis.getAtOrAbove(m)); options.useBarePtrCallConv = true; + options.useOpaquePointers = useOpaquePointers; if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout) options.overrideIndexBitwidth(indexBitwidth); @@ -931,7 +1522,6 @@ struct ConvertPolygeistToLLVMPass LLVMTypeConverter converter(&getContext(), options, &dataLayoutAnalysis); RewritePatternSet patterns(&getContext()); - // Keep these at the top; these should be run before the rest of // function conversion patterns. populateReturnOpTypeConversionPattern(patterns, converter); @@ -964,7 +1554,8 @@ struct ConvertPolygeistToLLVMPass // Run these instead of the ones provided by the dialect to avoid lowering // memrefs to a struct. - populateBareMemRefToLLVMConversionPatterns(converter, patterns); + populateBareMemRefToLLVMConversionPatterns(converter, patterns, + useOpaquePointers); // Legality callback for operations that checks whether their operand and // results types are converted. @@ -1019,7 +1610,13 @@ struct ConvertPolygeistToLLVMPass if (i == 1) { target.addIllegalOp(); - patterns.add(converter); + if (useOpaquePointers) { + patterns.add(converter); + } else { + // FIXME: This part should be removed when we drop typed pointer + // support. 
+ patterns.add(converter); + } patterns.add(converter); } if (failed(applyPartialConversion(m, target, std::move(patterns)))) diff --git a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp index fc78e7c35b32b..56614b3fbcd1d 100644 --- a/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp +++ b/polygeist/lib/Dialect/Polygeist/Transforms/BareMemRefToLLVM.cpp @@ -27,6 +27,306 @@ struct GetGlobalMemrefOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(memref::GetGlobalOp getGlobalOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefTy = getGlobalOp.getType(); + if (!canBeLoweredToBarePtr(memrefTy)) + return failure(); + + // LLVM type for a global memref will be a multi-dimension array. For + // declarations or uninitialized global memrefs, we can potentially flatten + // this to a 1D array. However, for memref.global's with an initial value, + // we do not intend to flatten the ElementsAttribute when going from std -> + // LLVM dialect, so the LLVM type needs to me a multi-dimension array. + const auto convElemType = + typeConverter->convertType(memrefTy.getElementType()); + if (!convElemType) + return failure(); + const auto addressOf = + static_cast(rewriter.create( + getGlobalOp.getLoc(), + LLVM::LLVMPointerType::get(memrefTy.getContext(), + memrefTy.getMemorySpaceAsInt()), + adaptor.getName())); + + // Get the address of the first element in the array by creating a GEP with + // the address of the GV as the base, and (rank + 1) number of 0 indices. + rewriter.replaceOpWithNewOp( + getGlobalOp, typeConverter->convertType(memrefTy), convElemType, + addressOf, SmallVector(memrefTy.getRank() + 1, 0), + /* inbounds */ true); + + return success(); + } +}; + +/// Simply replace by the source, as we don't care about the shape. +struct ReshapeMemrefOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::ReshapeOp reshape, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr(reshape.getType()) || + !canBeLoweredToBarePtr( + reshape.getSource().getType().cast())) + return failure(); + + rewriter.replaceOp(reshape, adaptor.getSource()); + return success(); + } +}; + +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. 
+struct AllocaMemrefOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocaOp allocaOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefType = allocaOp.getType(); + if (!memrefType.hasStaticShape() || !memrefType.getLayout().isIdentity()) + return failure(); + + const auto ptrType = typeConverter->convertType(allocaOp.getType()); + if (!ptrType) + return failure(); + const auto convElemType = + typeConverter->convertType(memrefType.getElementType()); + const auto loc = allocaOp.getLoc(); + auto nullPtr = rewriter.create(loc, ptrType); + auto gepPtr = rewriter.create( + loc, ptrType, convElemType, nullPtr, + createIndexConstant(rewriter, loc, memrefType.getNumElements())); + auto sizeBytes = + rewriter.create(loc, getIndexType(), gepPtr); + + rewriter.replaceOpWithNewOp( + allocaOp, ptrType, convElemType, sizeBytes, + allocaOp.getAlignment().value_or(0)); + return success(); + } +}; + +static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, + Value input, Value alignment) { + auto one = rewriter.create(loc, alignment.getType(), 1); + auto bump = rewriter.create(loc, alignment, one); + auto bumped = rewriter.create(loc, input, bump); + auto mod = rewriter.create(loc, bumped, alignment); + return rewriter.create(loc, bumped, mod); +} + +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. +struct AllocMemrefOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto memrefType = allocOp.getType(); + const auto elementPtrType = typeConverter->convertType(memrefType); + if (!elementPtrType || !memrefType.hasStaticShape() || + !memrefType.getLayout().isIdentity()) + return failure(); + + const auto loc = allocOp.getLoc(); + SmallVector sizes; + SmallVector strides; + Value sizeBytes; + getMemRefDescriptorSizes(loc, memrefType, adaptor.getOperands(), rewriter, + sizes, strides, sizeBytes); + + const auto alignment = + llvm::transformOptional(allocOp.getAlignment(), [&](auto val) { + return createIndexConstant(rewriter, loc, val); + }); + if (alignment) { + // Adjust the allocation size to consider alignment. + sizeBytes = rewriter.create(loc, sizeBytes, *alignment); + } + + auto module = allocOp->getParentOfType(); + const auto allocFuncOp = + getAllocFn(*getTypeConverter(), module, getIndexType()); + + auto alignedPtr = static_cast( + rewriter.create(loc, allocFuncOp, sizeBytes) + .getResults() + .front()); + if (alignment) { + // Compute the aligned pointer. + const auto allocatedInt = static_cast( + rewriter.create(loc, getIndexType(), alignedPtr)); + const auto alignmentInt = + createAligned(rewriter, loc, allocatedInt, *alignment); + alignedPtr = + rewriter.create(loc, elementPtrType, alignmentInt); + } + rewriter.replaceOp(allocOp, {alignedPtr}); + return success(); + } +}; + +/// Conversion similar to the canonical one, but not extracting the allocated +/// pointer from a struct. 
+struct DeallocOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::DeallocOp deallocOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!canBeLoweredToBarePtr( + deallocOp.getMemref().getType().cast())) + return failure(); + // Insert the `free` declaration if it is not already present. + const auto freeFunc = + getFreeFn(*getTypeConverter(), deallocOp->getParentOfType()); + rewriter.replaceOpWithNewOp(deallocOp, freeFunc, + adaptor.getMemref()); + return success(); + } +}; + +/// Lowers to an identity operation. +struct CastMemrefOpLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult match(memref::CastOp castOp) const override { + const auto srcType = castOp.getOperand().getType().cast(); + const auto dstType = castOp.getType().cast(); + + // This will be replaced by an identity function, so we need input and + // output types to match. + return success(canBeLoweredToBarePtr(dstType) && + canBeLoweredToBarePtr(srcType) && + typeConverter->convertType(srcType) == + typeConverter->convertType(dstType)); + } + + void rewrite(memref::CastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(castOp, adaptor.getSource()); + } +}; + +struct MemorySpaceCastMemRefOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(memref::MemorySpaceCastOp castOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + const auto newTy = getTypeConverter()->convertType(castOp.getType()); + rewriter.replaceOpWithNewOp(castOp, newTy, + adaptor.getSource()); + return success(); + } +}; + +/// Base class for lowering operations implementing memory accesses. +struct MemAccessLowering : public ConvertToLLVMPattern { + using ConvertToLLVMPattern::ConvertToLLVMPattern; + + /// Obtains offset from a memory access indices + Value getStridedElementBarePtr(Location loc, MemRefType type, Value base, + ValueRange indices, + ConversionPatternRewriter &rewriter) const { + int64_t offset; + SmallVector strides; + LogicalResult successStrides = getStridesAndOffset(type, strides, offset); + assert(succeeded(successStrides) && "unexpected non-strided memref"); + (void)successStrides; + + auto index = + offset == 0 ? Value{} : createIndexConstant(rewriter, loc, offset); + + for (const auto &iter : llvm::enumerate(llvm::zip(indices, strides))) { + auto increment = std::get<0>(iter.value()); + const auto stride = std::get<1>(iter.value()); + if (stride != 1) { // Skip if stride is 1. + increment = rewriter.create( + loc, increment, createIndexConstant(rewriter, loc, stride)); + } + index = index ? rewriter.create(loc, index, increment) + : increment; + } + const auto elementPtrType = getTypeConverter()->convertType(type); + if (!elementPtrType) + return {}; + const auto convElemType = + getTypeConverter()->convertType(type.getElementType()); + return index ? 
rewriter.create(loc, elementPtrType, + convElemType, base, index) + : base; + } +}; + +struct LoadMemRefOpLowering : public MemAccessLowering { + LoadMemRefOpLowering(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : MemAccessLowering{memref::LoadOp::getOperationName(), + &typeConverter.getContext(), typeConverter, benefit} { + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef args, + ConversionPatternRewriter &rewriter) const override { + auto loadOp = cast(op); + if (!canBeLoweredToBarePtr(loadOp.getMemRefType())) + return failure(); + memref::LoadOp::Adaptor adaptor{args}; + const Value DataPtr = getStridedElementBarePtr( + loadOp.getLoc(), loadOp.getMemRefType(), adaptor.getMemref(), + adaptor.getIndices(), rewriter); + if (!DataPtr) + return failure(); + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(loadOp.getType()), DataPtr); + return success(); + } +}; + +struct StoreMemRefOpLowering : public MemAccessLowering { + StoreMemRefOpLowering(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : MemAccessLowering{memref::StoreOp::getOperationName(), + &typeConverter.getContext(), typeConverter, benefit} { + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef args, + ConversionPatternRewriter &rewriter) const override { + auto storeOp = cast(op); + if (!canBeLoweredToBarePtr(storeOp.getMemRefType())) + return failure(); + memref::StoreOp::Adaptor adaptor{args}; + const Value DataPtr = getStridedElementBarePtr( + storeOp.getLoc(), storeOp.getMemRefType(), adaptor.getMemref(), + adaptor.getIndices(), rewriter); + if (!DataPtr) + return failure(); + rewriter.replaceOpWithNewOp(op, adaptor.getValue(), DataPtr); + return success(); + } +}; +} // namespace + +// The following patterns are outdated and only used in case typed pointers +// should be used for the lowering. They will be removed soon. +namespace { +/// Conversion similar to the canonical one, but not inserting the obtained +/// pointer in a struct. +struct GetGlobalMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult matchAndRewrite(memref::GetGlobalOp getGlobalOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -75,7 +375,7 @@ struct GetGlobalMemrefOpLowering }; /// Simply replace by the source, as we don't care about the shape. -struct ReshapeMemrefOpLowering +struct ReshapeMemrefOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -94,7 +394,7 @@ struct ReshapeMemrefOpLowering /// Conversion similar to the canonical one, but not inserting the obtained /// pointer in a struct. -struct AllocaMemrefOpLowering +struct AllocaMemrefOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -123,8 +423,8 @@ struct AllocaMemrefOpLowering } }; -static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, - Value input, Value alignment) { +static Value createAlignedOld(ConversionPatternRewriter &rewriter, Location loc, + Value input, Value alignment) { auto one = rewriter.create(loc, alignment.getType(), 1); auto bump = rewriter.create(loc, alignment, one); auto bumped = rewriter.create(loc, input, bump); @@ -134,7 +434,8 @@ static Value createAligned(ConversionPatternRewriter &rewriter, Location loc, /// Conversion similar to the canonical one, but not inserting the obtained /// pointer in a struct. 
-struct AllocMemrefOpLowering : public ConvertOpToLLVMPattern { +struct AllocMemrefOpLoweringOld + : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -175,7 +476,7 @@ struct AllocMemrefOpLowering : public ConvertOpToLLVMPattern { const auto allocatedInt = static_cast( rewriter.create(loc, getIndexType(), alignedPtr)); const auto alignmentInt = - createAligned(rewriter, loc, allocatedInt, *alignment); + createAlignedOld(rewriter, loc, allocatedInt, *alignment); alignedPtr = rewriter.create(loc, elementPtrType, alignmentInt); } @@ -186,7 +487,7 @@ struct AllocMemrefOpLowering : public ConvertOpToLLVMPattern { /// Conversion similar to the canonical one, but not extracting the allocated /// pointer from a struct. -struct DeallocOpLowering : public ConvertOpToLLVMPattern { +struct DeallocOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult @@ -209,7 +510,7 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern { }; /// Lowers to an identity operation. -struct CastMemrefOpLowering : public ConvertOpToLLVMPattern { +struct CastMemrefOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult match(memref::CastOp castOp) const override { @@ -230,7 +531,7 @@ struct CastMemrefOpLowering : public ConvertOpToLLVMPattern { } }; -struct MemorySpaceCastMemRefOpLowering +struct MemorySpaceCastMemRefOpLoweringOld : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< memref::MemorySpaceCastOp>::ConvertOpToLLVMPattern; @@ -246,7 +547,7 @@ struct MemorySpaceCastMemRefOpLowering }; /// Base class for lowering operations implementing memory accesses. -struct MemAccessLowering : public ConvertToLLVMPattern { +struct MemAccessLoweringOld : public ConvertToLLVMPattern { using ConvertToLLVMPattern::ConvertToLLVMPattern; /// Obtains offset from a memory access indices @@ -281,9 +582,9 @@ struct MemAccessLowering : public ConvertToLLVMPattern { } }; -struct LoadMemRefOpLowering : public MemAccessLowering { - LoadMemRefOpLowering(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) +struct LoadMemRefOpLoweringOld : public MemAccessLowering { + LoadMemRefOpLoweringOld(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : MemAccessLowering{memref::LoadOp::getOperationName(), &typeConverter.getContext(), typeConverter, benefit} { } @@ -305,9 +606,9 @@ struct LoadMemRefOpLowering : public MemAccessLowering { } }; -struct StoreMemRefOpLowering : public MemAccessLowering { - StoreMemRefOpLowering(LLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) +struct StoreMemRefOpLoweringOld : public MemAccessLowering { + StoreMemRefOpLoweringOld(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) : MemAccessLowering{memref::StoreOp::getOperationName(), &typeConverter.getContext(), typeConverter, benefit} { } @@ -331,23 +632,39 @@ struct StoreMemRefOpLowering : public MemAccessLowering { } // namespace void mlir::polygeist::populateBareMemRefToLLVMConversionPatterns( - mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns) { + mlir::LLVMTypeConverter &converter, RewritePatternSet &patterns, + bool useOpaquePointers) { assert(converter.getOptions().useBarePtrCallConv && "Expecting \"bare pointer\" calling convention"); - patterns.add( - converter, 2); + if (useOpaquePointers) { + patterns.add( + converter, 2); + } else { + patterns.add(converter, 2); + } // Patterns are tried in reverse add 
order, so this is tried before the // one added by default. - converter.addConversion([&](MemRefType type) -> Optional { - if (!canBeLoweredToBarePtr(type)) - return std::nullopt; - const auto elemType = converter.convertType(type.getElementType()); - if (!elemType) - return Type{}; - return LLVM::LLVMPointerType::get(elemType, type.getMemorySpaceAsInt()); - }); + converter.addConversion( + [&, useOpaquePointers](MemRefType type) -> Optional { + if (!canBeLoweredToBarePtr(type)) + return std::nullopt; + + if (useOpaquePointers) { + return LLVM::LLVMPointerType::get(type.getContext(), + type.getMemorySpaceAsInt()); + } + + const auto elemType = converter.convertType(type.getElementType()); + if (!elemType) + return Type{}; + return LLVM::LLVMPointerType::get(elemType, type.getMemorySpaceAsInt()); + }); } diff --git a/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir b/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir new file mode 100644 index 0000000000000..dd98613cb7f18 --- /dev/null +++ b/polygeist/test/polygeist-opt/bareptrlowering-typed-pointer.mlir @@ -0,0 +1,573 @@ +// RUN: polygeist-opt %s --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr + +func.func private @ptr_ret_static(%arg0: i64) -> memref<4xi64> + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr + +func.func private @ptr_ret_dynamic(%arg0: i64) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr + +func.func private @ptr_ret_nd_static(%arg0: i64) -> memref<4x4xi64> + +// ----- + +// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr + +func.func private @ptr_ret_nd_dynamic(%arg0: i64) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr, !llvm.ptr) -> !llvm.ptr + +func.func private @ptr_args_and_ret(%arg0: memref<1xi64>, %arg1: memref) -> memref + +// ----- + +// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr {llvm.byval = i64}, !llvm.ptr {llvm.byval = i64}) -> !llvm.ptr + +func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval = i64}, + %arg1: memref {llvm.byval = i64}) -> memref + +// ----- + +gpu.module @kernels { + +// CHECK-LABEL: llvm.func @kernel( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64}, +// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} { +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + + gpu.func @kernel(%arg0: memref<1xi64> {llvm.byval = i64}, + %arg1: memref {llvm.byval = i64}) kernel { + gpu.return + } +} + +// ----- + +// CHECK-LABEL: llvm.mlir.global external @global() {addr_space = 0 : i32} : !llvm.array<3 x i64> + +memref.global @global : memref<3xi64> + +// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr> +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @get_global() -> memref<3xi64> { + %0 = memref.get_global @global : memref<3xi64> + return %0 : memref<3xi64> +} + +// ----- + +// CHECK-LABEL: llvm.mlir.global external @global_addrspace() {addr_space = 4 : i32} : !llvm.array<3 x i64> + +memref.global @global_addrspace : memref<3xi64, 4> + +// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr +// CHECK-NEXT: %[[VAL_0:.*]] = 
llvm.mlir.addressof @global_addrspace : !llvm.ptr, 4> +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr, 4>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @get_global_addrspace() -> memref<3xi64, 4> { + %0 = memref.get_global @global_addrspace : memref<3xi64, 4> + return %0 : memref<3xi64, 4> +} + +// ----- + +memref.global "private" constant @shape : memref<2xi64> = dense<[2, 2]> + +// CHECK-LABEL: llvm.func @reshape( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> { + %shape = memref.get_global @shape : memref<2xi64> + %0 = memref.reshape %arg0(%shape) : (memref<4xi32>, memref<2xi64>) -> memref<2x2xi32> + return %0 : memref<2x2xi32> +} + +// ----- + +memref.global "private" constant @shape : memref<1xindex> + +// CHECK-LABEL: llvm.func @reshape_dyn( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr>) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref { + %shape = memref.get_global @shape : memref<1xindex> + %0 = memref.reshape %arg0(%shape) : (memref<4xi32>, memref<1xindex>) -> memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca() { + %0 = memref.alloca() : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_nd() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_nd() { + %0 = memref.alloca() : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_aligned() { + %0 = memref.alloca() {alignment = 8} : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloca_nd_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: 
%[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloca_nd_aligned() { + %0 = memref.alloca() {alignment = 8} : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr + +// CHECK-LABEL: llvm.func @alloc() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.bitcast %[[VAL_5]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc() { + %0 = memref.alloc() : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_nd() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc_nd() { + %0 = memref.alloc() : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_4]], %[[VAL_5]] : i64 +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.ptrtoint %[[VAL_8]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.sub %[[VAL_5]], %[[VAL_10]] : i64 +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.add %[[VAL_9]], %[[VAL_11]] : i64 +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.urem %[[VAL_12]], %[[VAL_5]] : i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_12]], %[[VAL_13]] : i64 +// CHECK-NEXT: %[[VAL_15:.*]] = llvm.inttoptr %[[VAL_14]] : i64 to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private 
@alloc_aligned() { + %0 = memref.alloc() {alignment = 8} : memref<2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @alloc_nd_aligned() +// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(3 : index) : i64 +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(10 : index) : i64 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(2 : index) : i64 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(8 : index) : i64 +// CHECK-NEXT: %[[VAL_10:.*]] = llvm.add %[[VAL_8]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr to i64 +// CHECK-NEXT: %[[VAL_14:.*]] = llvm.mlir.constant(1 : i64) : i64 +// CHECK-NEXT: %[[VAL_15:.*]] = llvm.sub %[[VAL_9]], %[[VAL_14]] : i64 +// CHECK-NEXT: %[[VAL_16:.*]] = llvm.add %[[VAL_13]], %[[VAL_15]] : i64 +// CHECK-NEXT: %[[VAL_17:.*]] = llvm.urem %[[VAL_16]], %[[VAL_9]] : i64 +// CHECK-NEXT: %[[VAL_18:.*]] = llvm.sub %[[VAL_16]], %[[VAL_17]] : i64 +// CHECK-NEXT: %[[VAL_19:.*]] = llvm.inttoptr %[[VAL_18]] : i64 to !llvm.ptr +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @alloc_nd_aligned() { + %0 = memref.alloc() {alignment = 8} : memref<3x10x2xi32> + return +} + +// ----- + +// CHECK-LABEL: llvm.func @dealloc( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) +// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr +// CHECK-NEXT: llvm.call @free(%[[VAL_1]]) : (!llvm.ptr) -> () +// CHECK-NEXT: llvm.return +// CHECK-NEXT: } + +func.func private @dealloc(%arg0: memref) { + memref.dealloc %arg0 : memref + return +} + +// ----- + +// CHECK-LABEL: llvm.func @cast( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr +// CHECK-NEXT: } + +func.func private @cast(%arg0: memref<2xi32>) -> memref { + %0 = memref.cast %arg0 : memref<2xi32> to memref + return %0 : memref +} + +// ----- + +// CHECK-LABEL: llvm.func @load( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32 +// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr +// CHECK-NEXT: llvm.return %[[VAL_3]] : f32 +// CHECK-NEXT: } + +func.func private @load(%arg0: memref<100xf32>, %index: index) -> f32 { + %0 = memref.load %arg0[%index] : memref<100xf32> + return %0 : f32 +} + +// ----- + +// CHECK-LABEL: llvm.func @load_nd( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr, +// CHECK-SAME: %[[VAL_1:.*]]: i64, +// CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32 +// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64 +// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64 +// CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64 +// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr +// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr +// 
+// CHECK-NEXT: }
+
+func.func private @load_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: index) -> f32 {
+  %0 = memref.load %arg0[%index0, %index1] : memref<100x100xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @load_nd_dyn(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64,
+// CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_7]] : f32
+// CHECK-NEXT: }
+
+func.func private @load_nd_dyn(%arg0: memref<?x100xf32>, %index0: index, %index1: index) -> f32 {
+  %0 = memref.load %arg0[%index0, %index1] : memref<?x100xf32>
+  return %0 : f32
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64,
+// CHECK-SAME: %[[VAL_2:.*]]: f32)
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return
+// CHECK-NEXT: }
+
+func.func private @store(%arg0: memref<100xf32>, %index: index, %val: f32) {
+  memref.store %val, %arg0[%index] : memref<100xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_nd(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64,
+// CHECK-SAME: %[[VAL_3:.*]]: f32)
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return
+// CHECK-NEXT: }
+
+func.func private @store_nd(%arg0: memref<100x100xf32>, %index0: index, %index1: index, %val: f32) {
+  memref.store %val, %arg0[%index0, %index1] : memref<100x100xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @store_nd_dyn(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64,
+// CHECK-SAME: %[[VAL_3:.*]]: f32)
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return
+// CHECK-NEXT: }
+
+func.func private @store_nd_dyn(%arg0: memref<?x100xf32>, %index0: index, %index1: index, %val: f32) {
+  memref.store %val, %arg0[%index0, %index1] : memref<?x100xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @impl(!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+
+func.func private @impl(%arg0: memref<?xf32>, %arg1: index) -> memref<?xf32>
+
+// CHECK-LABEL: llvm.func @call(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @call(%arg0: memref<?xf32>, %arg1: index) -> memref<?xf32> {
+  %res = func.call @impl(%arg0, %arg1) : (memref<?xf32>, index) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(4 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mul %[[VAL_1]], %[[VAL_2]] : i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4xf32> {
+  %res = "polygeist.subindex"(%arg0 , %arg1) : (memref<4x4xf32>, index) -> memref<4xf32>
+  return %res : memref<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref_same_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4x4xf32> {
+  %res = "polygeist.subindex"(%arg0 , %arg1) : (memref<4x4xf32>, index) -> memref<4x4xf32>
+  return %res : memref<4x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref_struct(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(f32)>>) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr<struct<(f32)>>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>>) -> memref<?xf32> {
+  %c_0 = arith.constant 0 : index
+  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(f32)>>, index) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(struct<(f32)>)>>) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr<struct<(struct<(f32)>)>>, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct<(struct<(f32)>)>>) -> memref<?xf32> {
+  %c_0 = arith.constant 0 : index
+  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(struct<(f32)>)>>, index) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_ptr(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<struct<(f32)>>)>>) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr<struct<(ptr<struct<(f32)>>)>>, i64, i64, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.struct<(ptr<struct<(f32)>>)>>) -> memref<?xf32> {
+  %c_0 = arith.constant 0 : index
+  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(ptr<struct<(f32)>>)>>, index) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(array<4 x struct<(f32)>>)>>) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr<struct<(array<4 x struct<(f32)>>)>>, i64, i64) -> !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>) -> memref<?xf32> {
+  %c_0 = arith.constant 0 : index
+  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>, index) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @memref2ptr(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>) -> !llvm.ptr<i8>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<f32> to !llvm.ptr<i8>
+// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<i8>
+// CHECK-NEXT: }
+
+func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr<i8> {
+  %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr<i8>
+  return %res : !llvm.ptr<i8>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @ptr2memref(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
+// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<f32>
+// CHECK-NEXT: }
+
+func.func private @ptr2memref(%arg0: !llvm.ptr<i8>) -> memref<?xf32> {
+  %res = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr<i8>) -> memref<?xf32>
+  return %res : memref<?xf32>
+}
+
+// -----
+
+#layout = affine_map<(s0) -> (s0 - 1)>
+
+// CHECK-LABEL: llvm.func @non_bare_due_to_layout(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
+// CHECK-NEXT: llvm.return %[[VAL_6]] : i64
+// CHECK-NEXT: }
+
+func.func private @non_bare_due_to_layout(%arg0: memref<100xi64, #layout>, %arg1: index) -> i64 {
+  %res = memref.load %arg0[%arg1] : memref<100xi64, #layout>
+  return %res : i64
+}
diff --git a/polygeist/test/polygeist-opt/bareptrlowering.mlir b/polygeist/test/polygeist-opt/bareptrlowering.mlir
index a09d8f7660a03..327c2dc75c842 100644
--- a/polygeist/test/polygeist-opt/bareptrlowering.mlir
+++ b/polygeist/test/polygeist-opt/bareptrlowering.mlir
@@ -1,36 +1,36 @@
-// RUN: polygeist-opt %s --convert-polygeist-to-llvm --split-input-file | FileCheck %s
+// RUN: polygeist-opt %s --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file | FileCheck %s
 
-// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_ret_static(i64) -> !llvm.ptr
 
 func.func private @ptr_ret_static(%arg0: i64) -> memref<4xi64>
 
 // -----
 
-// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_ret_dynamic(i64) -> !llvm.ptr
 
 func.func private @ptr_ret_dynamic(%arg0: i64) -> memref<?xi64>
 
 // -----
 
-// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_ret_nd_static(i64) -> !llvm.ptr
 
 func.func private @ptr_ret_nd_static(%arg0: i64) -> memref<4x4xi64>
 
 // -----
 
-// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_ret_nd_dynamic(i64) -> !llvm.ptr
 
 func.func private @ptr_ret_nd_dynamic(%arg0: i64) -> memref<?x?xi64>
 
 // -----
 
-// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr<i64>, !llvm.ptr<i64>) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_args_and_ret(!llvm.ptr, !llvm.ptr) -> !llvm.ptr
 
 func.func private @ptr_args_and_ret(%arg0: memref<1xi64>, %arg1: memref<?xi64>) -> memref<?xi64>
 
 // -----
 
-// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr<i64> {llvm.byval = i64}, !llvm.ptr<i64> {llvm.byval = i64}) -> !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @ptr_args_and_ret_with_attrs(!llvm.ptr {llvm.byval = i64}, !llvm.ptr {llvm.byval = i64}) -> !llvm.ptr
 
 func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval = i64}, %arg1: memref<?xi64> {llvm.byval = i64}) -> memref<?xi64>
 
@@ -40,8 +40,8 @@ func.func private @ptr_args_and_ret_with_attrs(%arg0: memref<1xi64> {llvm.byval
 gpu.module @kernels {
 
 // CHECK-LABEL: llvm.func @kernel(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i64> {llvm.byval = i64},
-// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<i64> {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} {
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr {llvm.byval = i64},
+// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr {llvm.byval = i64}) attributes {gpu.kernel, workgroup_attributions = 0 : i64} {
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -57,10 +57,10 @@ gpu.module @kernels {
 
 memref.global @global : memref<3xi64>
 
-// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr<i64>
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr<array<3 x i64>>
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr<array<3 x i64>>) -> !llvm.ptr<i64>
-// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<i64>
+// CHECK-LABEL: llvm.func @get_global() -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global : !llvm.ptr
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr) -> !llvm.ptr, i64
+// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @get_global() -> memref<3xi64> {
@@ -74,10 +74,10 @@ func.func private @get_global() -> memref<3xi64> {
 
 memref.global @global_addrspace : memref<3xi64, 4>
 
-// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr<i64, 4>
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global_addrspace : !llvm.ptr<array<3 x i64>, 4>
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr<array<3 x i64>, 4>) -> !llvm.ptr<i64, 4>
-// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<i64, 4>
+// CHECK-LABEL: llvm.func @get_global_addrspace() -> !llvm.ptr<4>
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.addressof @global_addrspace : !llvm.ptr<4>
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.getelementptr inbounds %[[VAL_0]][0, 0] : (!llvm.ptr<4>) -> !llvm.ptr<4>, i64
+// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<4>
 // CHECK-NEXT: }
 
 func.func private @get_global_addrspace() -> memref<3xi64, 4> {
@@ -90,9 +90,8 @@ func.func private @get_global_addrspace() -> memref<3xi64, 4> {
 memref.global "private" constant @shape : memref<2xi64> = dense<[2, 2]>
 
 // CHECK-LABEL: llvm.func @reshape(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i32>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr<array<2 x i64>>) -> !llvm.ptr<i64>
-// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr<i32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> {
@@ -106,9 +105,8 @@ func.func private @reshape(%arg0: memref<4xi32>) -> memref<2x2xi32> {
 memref.global "private" constant @shape : memref<1xindex>
 
 // CHECK-LABEL: llvm.func @reshape_dyn(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i32>) -> !llvm.ptr<i32>
-// CHECK: %[[VAL_2:.*]] = llvm.getelementptr inbounds %{{.*}}[0, 0] : (!llvm.ptr<array<1 x i64>>) -> !llvm.ptr<i64>
-// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr<i32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref<?xi32> {
@@ -120,11 +118,11 @@ func.func private @reshape_dyn(%arg0: memref<4xi32>) -> memref<?xi32> {
 // -----
 
 // CHECK-LABEL: llvm.func @alloca()
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -136,11 +134,11 @@ func.func private @alloca() {
 // -----
 
 // CHECK-LABEL: llvm.func @alloca_nd()
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -152,11 +150,11 @@ func.func private @alloca_nd() {
 // -----
 
 // CHECK-LABEL: llvm.func @alloca_aligned()
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(2 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -168,11 +166,11 @@ func.func private @alloca_aligned() {
 // -----
 
 // CHECK-LABEL: llvm.func @alloca_nd_aligned()
-// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.null : !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(60 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.ptrtoint %[[VAL_2]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.alloca %[[VAL_3]] x i32 {alignment = 8 : i64} : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -183,16 +181,15 @@ func.func private @alloca_nd_aligned() {
 
 // -----
 
-// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr<i8>
+// CHECK-LABEL: llvm.func @malloc(i64) -> !llvm.ptr
 
 // CHECK-LABEL: llvm.func @alloc()
 // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.bitcast %[[VAL_5]] : !llvm.ptr<i8> to !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.call @malloc(%[[VAL_4]]) : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -210,11 +207,10 @@ func.func private @alloc() {
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[VAL_10:.*]] = llvm.bitcast %[[VAL_9]] : !llvm.ptr<i8> to !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_9:.*]] = llvm.call @malloc(%[[VAL_8]]) : (i64) -> !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -228,20 +224,19 @@ func.func private @alloc_nd() {
 // CHECK-LABEL: llvm.func @alloc_aligned()
 // CHECK-NEXT: %[[VAL_0:.*]] = llvm.mlir.constant(2 : index) : i64
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(1 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr<i32> to i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_2]]{{\[}}%[[VAL_0]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.ptrtoint %[[VAL_3]] : !llvm.ptr to i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_4]], %[[VAL_5]] : i64
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<i8> to !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_9:.*]] = llvm.ptrtoint %[[VAL_8]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK-NEXT: %[[VAL_11:.*]] = llvm.sub %[[VAL_5]], %[[VAL_10]] : i64
-// CHECK-NEXT: %[[VAL_12:.*]] = llvm.add %[[VAL_9]], %[[VAL_11]] : i64
-// CHECK-NEXT: %[[VAL_13:.*]] = llvm.urem %[[VAL_12]], %[[VAL_5]] : i64
-// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_12]], %[[VAL_13]] : i64
-// CHECK-NEXT: %[[VAL_15:.*]] = llvm.inttoptr %[[VAL_14]] : i64 to !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.call @malloc(%[[VAL_6]]) : (i64) -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK-NEXT: %[[VAL_10:.*]] = llvm.sub %[[VAL_5]], %[[VAL_9]] : i64
+// CHECK-NEXT: %[[VAL_11:.*]] = llvm.add %[[VAL_8]], %[[VAL_10]] : i64
+// CHECK-NEXT: %[[VAL_12:.*]] = llvm.urem %[[VAL_11]], %[[VAL_5]] : i64
+// CHECK-NEXT: %[[VAL_13:.*]] = llvm.sub %[[VAL_11]], %[[VAL_12]] : i64
+// CHECK-NEXT: %[[VAL_14:.*]] = llvm.inttoptr %[[VAL_13]] : i64 to !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -259,20 +254,19 @@ func.func private @alloc_aligned() {
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(1 : index) : i64
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(20 : index) : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mlir.constant(60 : index) : i64
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr<i32>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr<i32> to i64
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.mlir.null : !llvm.ptr
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_6]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: %[[VAL_8:.*]] = llvm.ptrtoint %[[VAL_7]] : !llvm.ptr to i64
 // CHECK-NEXT: %[[VAL_9:.*]] = llvm.mlir.constant(8 : index) : i64
 // CHECK-NEXT: %[[VAL_10:.*]] = llvm.add %[[VAL_8]], %[[VAL_9]] : i64
-// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[VAL_12:.*]] = llvm.bitcast %[[VAL_11]] : !llvm.ptr<i8> to !llvm.ptr<i32>
-// CHECK-NEXT: %[[VAL_13:.*]] = llvm.ptrtoint %[[VAL_12]] : !llvm.ptr<i32> to i64
-// CHECK-NEXT: %[[VAL_14:.*]] = llvm.mlir.constant(1 : i64) : i64
-// CHECK-NEXT: %[[VAL_15:.*]] = llvm.sub %[[VAL_9]], %[[VAL_14]] : i64
-// CHECK-NEXT: %[[VAL_16:.*]] = llvm.add %[[VAL_13]], %[[VAL_15]] : i64
-// CHECK-NEXT: %[[VAL_17:.*]] = llvm.urem %[[VAL_16]], %[[VAL_9]] : i64
-// CHECK-NEXT: %[[VAL_18:.*]] = llvm.sub %[[VAL_16]], %[[VAL_17]] : i64
-// CHECK-NEXT: %[[VAL_19:.*]] = llvm.inttoptr %[[VAL_18]] : i64 to !llvm.ptr<i32>
+// CHECK-NEXT: %[[VAL_11:.*]] = llvm.call @malloc(%[[VAL_10]]) : (i64) -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_12:.*]] = llvm.ptrtoint %[[VAL_11]] : !llvm.ptr to i64
+// CHECK-NEXT: %[[VAL_13:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK-NEXT: %[[VAL_14:.*]] = llvm.sub %[[VAL_9]], %[[VAL_13]] : i64
+// CHECK-NEXT: %[[VAL_15:.*]] = llvm.add %[[VAL_12]], %[[VAL_14]] : i64
+// CHECK-NEXT: %[[VAL_16:.*]] = llvm.urem %[[VAL_15]], %[[VAL_9]] : i64
+// CHECK-NEXT: %[[VAL_17:.*]] = llvm.sub %[[VAL_15]], %[[VAL_16]] : i64
+// CHECK-NEXT: %[[VAL_18:.*]] = llvm.inttoptr %[[VAL_17]] : i64 to !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -284,9 +278,8 @@ func.func private @alloc_nd_aligned() {
 // -----
 
 // CHECK-LABEL: llvm.func @dealloc(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i32>)
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<i32> to !llvm.ptr<i8>
-// CHECK-NEXT: llvm.call @free(%[[VAL_1]]) : (!llvm.ptr<i8>) -> ()
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr)
+// CHECK-NEXT: llvm.call @free(%[[VAL_0]]) : (!llvm.ptr) -> ()
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -298,8 +291,8 @@ func.func private @dealloc(%arg0: memref<?xi32>) {
 // -----
 
 // CHECK-LABEL: llvm.func @cast(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i32>) -> !llvm.ptr<i32>
-// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr<i32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @cast(%arg0: memref<2xi32>) -> memref<?xi32> {
@@ -310,10 +303,10 @@ func.func private @cast(%arg0: memref<2xi32>) -> memref<?xi32> {
 // -----
 
 // CHECK-LABEL: llvm.func @load(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64) -> f32
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.load %[[VAL_2]] : !llvm.ptr
 // CHECK-NEXT: llvm.return %[[VAL_3]] : f32
 // CHECK-NEXT: }
 
@@ -325,14 +318,14 @@ func.func private @load(%arg0: memref<100xf32>, %index: index) -> f32 {
 // -----
 
 // CHECK-LABEL: llvm.func @load_nd(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64,
 // CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr
 // CHECK-NEXT: llvm.return %[[VAL_7]] : f32
 // CHECK-NEXT: }
 
@@ -344,14 +337,14 @@ func.func private @load_nd(%arg0: memref<100x100xf32>, %index0: index, %index1:
 // -----
 
 // CHECK-LABEL: llvm.func @load_nd_dyn(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64,
 // CHECK-SAME: %[[VAL_2:.*]]: i64) -> f32
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mul %[[VAL_1]], %[[VAL_3]] : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.add %[[VAL_4]], %[[VAL_2]] : i64
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_5]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.load %[[VAL_6]] : !llvm.ptr
 // CHECK-NEXT: llvm.return %[[VAL_7]] : f32
 // CHECK-NEXT: }
 
@@ -363,11 +356,11 @@ func.func private @load_nd_dyn(%arg0: memref<?x100xf32>, %index0: index, %index1
 // -----
 
 // CHECK-LABEL: llvm.func @store(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64,
 // CHECK-SAME: %[[VAL_2:.*]]: f32)
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.store %[[VAL_2]], %[[VAL_3]] : f32, !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -379,14 +372,14 @@ func.func private @store(%arg0: memref<100xf32>, %index: index, %val: f32) {
 // -----
 
 // CHECK-LABEL: llvm.func @store_nd(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64,
 // CHECK-SAME: %[[VAL_3:.*]]: f32)
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64
 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : f32, !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -398,14 +391,14 @@ func.func private @store_nd(%arg0: memref<100x100xf32>, %index0: index, %index1:
 // -----
 
 // CHECK-LABEL: llvm.func @store_nd_dyn(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
 // CHECK-SAME: %[[VAL_1:.*]]: i64, %[[VAL_2:.*]]: i64,
 // CHECK-SAME: %[[VAL_3:.*]]: f32)
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.mlir.constant(100 : index) : i64
 // CHECK-NEXT: %[[VAL_5:.*]] = llvm.mul %[[VAL_1]], %[[VAL_4]] : i64
 // CHECK-NEXT: %[[VAL_6:.*]] = llvm.add %[[VAL_5]], %[[VAL_2]] : i64
-// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_7:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_6]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.store %[[VAL_3]], %[[VAL_7]] : f32, !llvm.ptr
 // CHECK-NEXT: llvm.return
 // CHECK-NEXT: }
 
@@ -416,15 +409,15 @@ func.func private @store_nd_dyn(%arg0: memref<?x100xf32>, %index0: index, %index
 
 // -----
 
-// CHECK-LABEL: llvm.func @impl(!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
+// CHECK-LABEL: llvm.func @impl(!llvm.ptr, i64) -> !llvm.ptr
 
 func.func private @impl(%arg0: memref<?xf32>, %arg1: index) -> memref<?xf32>
 
 // CHECK-LABEL: llvm.func @call(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.call @impl(%[[VAL_0]], %[[VAL_1]]) : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @call(%arg0: memref<?xf32>, %arg1: index) -> memref<?xf32> {
@@ -435,12 +428,12 @@ func.func private @call(%arg0: memref<?xf32>, %arg1: index) -> memref<?xf32> {
 // -----
 
 // CHECK-LABEL: llvm.func @subindexop_memref(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr
 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(4 : i64) : i64
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mul %[[VAL_1]], %[[VAL_2]] : i64
-// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_4:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_3]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.return %[[VAL_4]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4xf32> {
@@ -451,10 +444,10 @@ func.func private @subindexop_memref(%arg0: memref<4x4xf32>, %arg1: index) -> me
 // -----
 
 // CHECK-LABEL: llvm.func @subindexop_memref_same_dim(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>,
-// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr,
+// CHECK-SAME: %[[VAL_1:.*]]: i64) -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.return %[[VAL_2]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: index) -> memref<4x4xf32> {
@@ -465,11 +458,11 @@ func.func private @subindexop_memref_same_dim(%arg0: memref<4x4xf32>, %arg1: ind
 // -----
 
 // CHECK-LABEL: llvm.func @subindexop_memref_struct(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(f32)>>) -> !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0] : (!llvm.ptr<struct<(f32)>>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>>) -> memref<?xf32> {
@@ -481,11 +474,11 @@ func.func private @subindexop_memref_struct(%arg0: memref<4x!llvm.struct<(f32)>>
 // -----
 
 // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(struct<(f32)>)>>) -> !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, 0] : (!llvm.ptr<struct<(struct<(f32)>)>>, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct<(struct<(f32)>)>>) -> memref<?xf32> {
@@ -496,28 +489,28 @@ func.func private @subindexop_memref_nested_struct(%arg0: memref<4x!llvm.struct<
 // -----
 
-// CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_ptr(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(ptr<struct<(f32)>>)>>) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr<struct<(ptr<struct<(f32)>>)>>, i64, i64, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
-// CHECK-NEXT: }
+// CHECK-LABEL: llvm.func @subindexop_memref_nested_ptr(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.ptr
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr
+// CHECK-NEXT: }
 
-func.func private @subindexop_memref_nested_struct_ptr(%arg0: memref<4x!llvm.struct<(ptr<struct<(f32)>>)>>) -> memref<?xf32> {
+func.func private @subindexop_memref_nested_ptr(%arg0: memref<4x!llvm.struct<(ptr)>>) -> memref<?x!llvm.ptr> {
   %c_0 = arith.constant 0 : index
-  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(ptr<struct<(f32)>>)>>, index) -> memref<?xf32>
-  return %res : memref<?xf32>
+  %res = "polygeist.subindex"(%arg0, %c_0) : (memref<4x!llvm.struct<(ptr)>>, index) -> memref<?x!llvm.ptr>
+  return %res : memref<?x!llvm.ptr>
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @subindexop_memref_nested_struct_array(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<(array<4 x struct<(f32)>>)>>) -> !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
 // CHECK-NEXT: %[[VAL_1:.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK-NEXT: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], 0, %[[VAL_2]], 0] : (!llvm.ptr<struct<(array<4 x struct<(f32)>>)>>, i64, i64) -> !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr<f32>
+// CHECK-NEXT: %[[VAL_3:.*]] = llvm.getelementptr %[[VAL_0]]{{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]], %[[VAL_1]]] : (!llvm.ptr, i64, i64, i64, i64) -> !llvm.ptr, f32
+// CHECK-NEXT: llvm.return %[[VAL_3]] : !llvm.ptr
 // CHECK-NEXT: }
 
 func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.struct<(array<4x!llvm.struct<(f32)>>)>>) -> memref<?xf32> {
@@ -529,26 +522,24 @@ func.func private @subindexop_memref_nested_struct_array(%arg0: memref<4x!llvm.s
 // -----
 
 // CHECK-LABEL: llvm.func @memref2ptr(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<f32>) -> !llvm.ptr<i8>
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<f32> to !llvm.ptr<i8>
-// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<i8>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr
 // CHECK-NEXT: }
 
-func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr<i8> {
-  %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr<i8>
-  return %res : !llvm.ptr<i8>
+func.func private @memref2ptr(%arg0: memref<4xf32>) -> !llvm.ptr {
+  %res = "polygeist.memref2pointer"(%arg0) : (memref<4xf32>) -> !llvm.ptr
+  return %res : !llvm.ptr
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @ptr2memref(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<f32>
-// CHECK-NEXT: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<i8> to !llvm.ptr<f32>
-// CHECK-NEXT: llvm.return %[[VAL_1]] : !llvm.ptr<f32>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr
+// CHECK-NEXT: llvm.return %[[VAL_0]] : !llvm.ptr
 // CHECK-NEXT: }
 
-func.func private @ptr2memref(%arg0: !llvm.ptr<i8>) -> memref<?xf32> {
-  %res = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr<i8>) -> memref<?xf32>
+func.func private @ptr2memref(%arg0: !llvm.ptr) -> memref<?xf32> {
+  %res = "polygeist.pointer2memref"(%arg0) : (!llvm.ptr) -> memref<?xf32>
   return %res : memref<?xf32>
 }
 
@@ -557,13 +548,13 @@ func.func private @ptr2memref(%arg0: !llvm.ptr<i8>) -> memref<?xf32> {
 #layout = affine_map<(s0) -> (s0 - 1)>
 
 // CHECK-LABEL: llvm.func @non_bare_due_to_layout(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>,
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>,
 // CHECK-SAME: %[[VAL_1:.*]]: i64) -> i64
-// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr<i64>, ptr<i64>, i64, array<1 x i64>, array<1 x i64>)>
+// CHECK-NEXT: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_0]][1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
 // CHECK-NEXT: %[[VAL_3:.*]] = llvm.mlir.constant(-1 : index) : i64
 // CHECK-NEXT: %[[VAL_4:.*]] = llvm.add %[[VAL_3]], %[[VAL_1]] : i64
-// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr<i64>, i64) -> !llvm.ptr<i64>
-// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr<i64>
+// CHECK-NEXT: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_2]][%[[VAL_4]]] : (!llvm.ptr, i64) -> !llvm.ptr, i64
+// CHECK-NEXT: %[[VAL_6:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr
 // CHECK-NEXT: llvm.return %[[VAL_6]] : i64
 // CHECK-NEXT: }
 
diff --git a/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir b/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir
new file mode 100644
index 0000000000000..07129e4873b59
--- /dev/null
+++ b/polygeist/test/polygeist-opt/sycl/cast-typed-pointer.mlir
@@ -0,0 +1,45 @@
+// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file %s | FileCheck %s
+
+!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
+!sycl_range_1_ = !sycl.range<[1], (!sycl_array_1_)>
+
+// CHECK-LABEL: llvm.func @test1(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::range.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>> {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::range.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK: }
+
+func.func @test1(%arg0: memref<?x!sycl_range_1_>) -> memref<?x!sycl_array_1_> {
+  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_range_1_>) -> memref<?x!sycl_array_1_>
+  func.return %0 : memref<?x!sycl_array_1_>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @test2(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>> {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK: }
+
+!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
+!sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)>
+func.func @test2(%arg0: memref<?x!sycl_id_1_>) -> memref<?x!sycl_array_1_> {
+  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_id_1_>) -> memref<?x!sycl_array_1_>
+  func.return %0: memref<?x!sycl_array_1_>
+}
+
+// -----
+
+// CHECK-LABEL: llvm.func @test_addrspaces(
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>, 4>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4> {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>, 4> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4>
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4>
+// CHECK: }
+
+!sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
+!sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)>
+func.func @test_addrspaces(%arg0: memref<?x!sycl_id_1_, 4>) -> memref<?x!sycl_array_1_, 4> {
+  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_id_1_, 4>) -> memref<?x!sycl_array_1_, 4>
+  func.return %0: memref<?x!sycl_array_1_, 4>
+}
diff --git a/polygeist/test/polygeist-opt/sycl/cast.mlir b/polygeist/test/polygeist-opt/sycl/cast.mlir
index 4d52f3c1344e5..7bd28e4428911 100644
--- a/polygeist/test/polygeist-opt/sycl/cast.mlir
+++ b/polygeist/test/polygeist-opt/sycl/cast.mlir
@@ -1,45 +1,45 @@
-// RUN: polygeist-opt --convert-polygeist-to-llvm --split-input-file %s | FileCheck %s
+// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s
 
 !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
 !sycl_range_1_ = !sycl.range<[1], (!sycl_array_1_)>
 
 // CHECK-LABEL: llvm.func @test1(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::range.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>> {
-// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::range.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
-// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr
 // CHECK: }
 
 func.func @test1(%arg0: memref<?x!sycl_range_1_>) -> memref<?x!sycl_array_1_> {
-  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_range_1_>) -> memref<?x!sycl_array_1_>
+  %0 = sycl.cast %arg0 : memref<?x!sycl_range_1_> to memref<?x!sycl_array_1_>
   func.return %0 : memref<?x!sycl_array_1_>
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @test2(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>> {
-// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
-// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr to !llvm.ptr
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr
 // CHECK: }
 
 !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
 !sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)>
 func.func @test2(%arg0: memref<?x!sycl_id_1_>) -> memref<?x!sycl_array_1_> {
-  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_id_1_>) -> memref<?x!sycl_array_1_>
+  %0 = sycl.cast %arg0 : memref<?x!sycl_id_1_> to memref<?x!sycl_array_1_>
  func.return %0: memref<?x!sycl_array_1_>
 }
 
 // -----
 
 // CHECK-LABEL: llvm.func @test_addrspaces(
-// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>, 4>) -> !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4> {
-// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>, 4> to !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4>
-// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>, 4>
+// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<4>) -> !llvm.ptr<4> {
+// CHECK: %[[VAL_1:.*]] = llvm.bitcast %[[VAL_0]] : !llvm.ptr<4> to !llvm.ptr<4>
+// CHECK: llvm.return %[[VAL_1]] : !llvm.ptr<4>
 // CHECK: }
 
 !sycl_array_1_ = !sycl.array<[1], (memref<1xi64>)>
 !sycl_id_1_ = !sycl.id<[1], (!sycl_array_1_)>
 func.func @test_addrspaces(%arg0: memref<?x!sycl_id_1_, 4>) -> memref<?x!sycl_array_1_, 4> {
-  %0 = "sycl.cast"(%arg0) : (memref<?x!sycl_id_1_, 4>) -> memref<?x!sycl_array_1_, 4>
+  %0 = sycl.cast %arg0 : memref<?x!sycl_id_1_, 4> to memref<?x!sycl_array_1_, 4>
  func.return %0: memref<?x!sycl_array_1_, 4>
 }
diff --git a/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir b/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir
new file mode 100644
index 0000000000000..98c3c057a4eaa
--- /dev/null
+++ b/polygeist/test/polygeist-opt/sycl/subindex-typed-pointer.mlir
@@ -0,0 +1,68 @@
+// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=0' --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @test_1
+// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr<struct<([[SYCLIDSTRUCT:struct<"class.sycl::_V1::id.1", .*]])>>, i64) -> !llvm.ptr<[[SYCLIDSTRUCT]], {{.*}}
+// CHECK-NEXT: llvm.return [[GEP]]
+
+!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
+func.func @test_1(%arg0: memref<?x!llvm.struct<(!sycl_id_1_)>>) -> memref<?x!sycl_id_1_> {
+  %c0 = arith.constant 0 : index
+  %0 = "polygeist.subindex"(%arg0, %c0) : (memref<?x!llvm.struct<(!sycl_id_1_)>>, index) -> memref<?x!sycl_id_1_>
+  return %0 : memref<?x!sycl_id_1_>
+}
+
+// -----
+
+// CHECK-LABEL: @test_2
+// CHECK: llvm.return %{{.*}} : !llvm.ptr<struct<"class.sycl::_V1::detail::AccessorImplDevice.1", ({{.*}})>>
+!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
+!sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
+!sycl_accessor_impl_device_1_ = !sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>
+!sycl_accessor_1_ = !sycl.accessor<[1, i32, read_write, global_buffer], (!sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>, !llvm.struct<(ptr<i32, 1>)>)>
+
+func.func @test_2(%arg0: memref<?x!sycl_accessor_1_>) -> memref<?x!sycl_accessor_impl_device_1_> {
+  %c0 = arith.constant 0 : index
+  %0 = "polygeist.subindex"(%arg0, %c0) : (memref<?x!sycl_accessor_1_>, index) -> memref<?x!sycl_accessor_impl_device_1_>
+  return %0 : memref<?x!sycl_accessor_impl_device_1_>
+}
+
+// -----
+
+// CHECK: llvm.func @test_3([[A0:.*]]: !llvm.ptr<struct<(i32)>>) -> !llvm.ptr<i32> {
+// CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], 0] : (!llvm.ptr<struct<(i32)>>, i64) -> !llvm.ptr<i32>
+// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr<i32>
+
+func.func @test_3(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref<?xi32> {
+  %c0 = arith.constant 0 : index
+  %0 = "polygeist.subindex"(%arg0, %c0) : (memref<?x!llvm.struct<(i32)>>, index) -> memref<?xi32>
+  return %0 : memref<?xi32>
+}
+
+// -----
+
+// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr<struct<([[SYCLIDSTRUCT:struct<"class.sycl::_V1::id.1", \(struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>\)>]])>>, [[A5:%.*]]: i64) -> !llvm.ptr<struct<(struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>)>> {
+// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr<struct<([[SYCLIDSTRUCT]])>>, i64) -> !llvm.ptr<struct<([[SYCLIDSTRUCT]])>>
+// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr<struct<([[SYCLIDSTRUCT]])>>
+
+!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
+func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: index) -> memref<?x!llvm.struct<(!sycl_id_1_)>> {
+  %0 = "polygeist.subindex"(%arg0, %arg1) : (memref<1x!llvm.struct<(!sycl_id_1_)>>, index) -> memref<?x!llvm.struct<(!sycl_id_1_)>>
+  return %0 : memref<?x!llvm.struct<(!sycl_id_1_)>>
+}
+
+// -----
+
+// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<[[ARRTYPE:struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>]], 4>) -> !llvm.ptr<i64, 4> {
+// CHECK-DAG: [[ZERO1:%.*]] = llvm.mlir.constant(0 : index) : i64
+// CHECK-DAG: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], 0, [[ZERO1]]] : (!llvm.ptr<[[ARRTYPE]], 4>, i64, i64) -> !llvm.ptr<i64, 4>
+
+!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
+func.func @test_5(%arg0: memref<?x!sycl.array<[1], (memref<1xi64, 4>)>, 4>) -> memref<1xi64, 4> {
+  %c0 = arith.constant 0 : index
+  %0 = "polygeist.subindex"(%arg0, %c0) : (memref<?x!sycl.array<[1], (memref<1xi64, 4>)>, 4>, index) -> memref<1xi64, 4>
+  return %0 : memref<1xi64, 4>
+}
diff --git a/polygeist/test/polygeist-opt/sycl/subindex.mlir b/polygeist/test/polygeist-opt/sycl/subindex.mlir
index 4fb19075221e0..a917b5a6e919c 100644
--- a/polygeist/test/polygeist-opt/sycl/subindex.mlir
+++ b/polygeist/test/polygeist-opt/sycl/subindex.mlir
@@ -1,8 +1,8 @@
-// RUN: polygeist-opt --convert-polygeist-to-llvm --split-input-file %s | FileCheck %s
+// RUN: polygeist-opt --convert-polygeist-to-llvm='use-opaque-pointers=1' --split-input-file %s | FileCheck %s
 
 // CHECK-LABEL: @test_1
 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr<struct<([[SYCLIDSTRUCT:struct<"class.sycl::_V1::id.1", .*]])>>, i64) -> !llvm.ptr<[[SYCLIDSTRUCT]], {{.*}}
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr %{{.*}}[[[ZERO]], 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::id.1", {{.*}}
 // CHECK-NEXT: llvm.return [[GEP]]
 
 !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
@@ -15,7 +15,8 @@ func.func @test_1(%arg0: memref<?x!llvm.struct<(!sycl_id_1_)>>) -> memref
 // -----
 
 // CHECK-LABEL: @test_2
-// CHECK: llvm.return %{{.*}} : !llvm.ptr<struct<"class.sycl::_V1::detail::AccessorImplDevice.1", ({{.*}})>>
+// CHECK: [[GEP:%.*]] = llvm.getelementptr {{.*}} : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<"class.sycl::_V1::detail::AccessorImplDevice.1", {{.*}}
+// CHECK-NEXT: llvm.return [[GEP]]
 
 !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
 !sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
@@ -30,10 +31,11 @@ func.func @test_2(%arg0: memref<?x!sycl_accessor_1_>) -> memref
 
 // -----
 
-// CHECK: llvm.func @test_3([[A0:.*]]: !llvm.ptr<struct<(i32)>>) -> !llvm.ptr<i32> {
+// CHECK: llvm.func @test_3([[A0:.*]]: !llvm.ptr) -> !llvm.ptr {
+// CHECK: [[IDX_ZERO:%.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], 0] : (!llvm.ptr<struct<(i32)>>, i64) -> !llvm.ptr<i32>
-// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr<i32>
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO]], [[IDX_ZERO]]] : (!llvm.ptr, i64, i64) -> !llvm.ptr, i32
+// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr
 
 func.func @test_3(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref<?xi32> {
@@ -43,9 +45,9 @@ func.func @test_3(%arg0: memref<?x!llvm.struct<(i32)>>) -> memref<?xi32> {
 
 // -----
 
-// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr<struct<([[SYCLIDSTRUCT:struct<"class.sycl::_V1::id.1", \(struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>\)>]])>>, [[A5:%.*]]: i64) -> !llvm.ptr<struct<(struct<"class.sycl::_V1::id.1", (struct<"class.sycl::_V1::detail::array.1", (array<1 x i64>)>)>)>> {
-// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr<struct<([[SYCLIDSTRUCT]])>>, i64) -> !llvm.ptr<struct<([[SYCLIDSTRUCT]])>>
-// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr<struct<([[SYCLIDSTRUCT]])>>
+// CHECK: llvm.func @test_4([[A0:%.*]]: !llvm.ptr, [[A5:%.*]]: i64) -> !llvm.ptr {
+// CHECK: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[A5]]] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(struct<"class.sycl::_V1::id.1", {{.*}})>
+// CHECK-NEXT: llvm.return [[GEP]] : !llvm.ptr
 
 !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
 func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: index) -> memref<?x!llvm.struct<(!sycl_id_1_)>> {
@@ -55,10 +57,10 @@ func.func @test_4(%arg0: memref<1x!llvm.struct<(!sycl_id_1_)>>, %arg1: index) ->
 
 // -----
 
-// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<[[ARRTYPE:struct<"class.sycl::_V1::detail::array.1", \(array<1 x i64>\)>]], 4>) -> !llvm.ptr<i64, 4> {
+// CHECK: llvm.func @test_5([[A0:%.*]]: !llvm.ptr<4>) -> !llvm.ptr<4> {
 // CHECK-DAG: [[ZERO1:%.*]] = llvm.mlir.constant(0 : index) : i64
 // CHECK-DAG: [[ZERO2:%.*]] = llvm.mlir.constant(0 : i64) : i64
-// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], 0, [[ZERO1]]] : (!llvm.ptr<[[ARRTYPE]], 4>, i64, i64) -> !llvm.ptr<i64, 4>
+// CHECK-NEXT: [[GEP:%.*]] = llvm.getelementptr [[A0]][[[ZERO2]], [[ZERO2]], [[ZERO1]]] : (!llvm.ptr<4>, i64, i64, i64) -> !llvm.ptr<4>, i64
 
 !sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
 func.func @test_5(%arg0: memref<?x!sycl.array<[1], (memref<1xi64, 4>)>, 4>) -> memref<1xi64, 4> {