From 8f7475d665ee4d89429372ff7e9374ab848a1fc2 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 7 Aug 2020 15:55:00 +0300 Subject: [PATCH 001/259] code --- mlir-compiler/CMakeLists.txt | 49 ++++ mlir-compiler/lowering.cpp | 430 ++++++++++++++++++++++++++++++++++ mlir-compiler/lowering.hpp | 5 + mlir-compiler/module.cpp | 18 ++ mlir-compiler/type_parser.cpp | 24 ++ mlir-compiler/type_parser.hpp | 5 + numba/core/lowering.py | 7 +- 7 files changed, 537 insertions(+), 1 deletion(-) create mode 100644 mlir-compiler/CMakeLists.txt create mode 100644 mlir-compiler/lowering.cpp create mode 100644 mlir-compiler/lowering.hpp create mode 100644 mlir-compiler/module.cpp create mode 100644 mlir-compiler/type_parser.cpp create mode 100644 mlir-compiler/type_parser.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt new file mode 100644 index 00000000000..5dcf303e482 --- /dev/null +++ b/mlir-compiler/CMakeLists.txt @@ -0,0 +1,49 @@ +cmake_minimum_required(VERSION 3.5) + +project(mlir_compiler LANGUAGES CXX) + +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +find_package(pybind11 REQUIRED) + +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") + +include(HandleLLVMOptions) + +set(SOURCES_LIST + lowering.cpp + module.cpp + type_parser.cpp + ) +set(HEADERS_LIST + lowering.hpp + type_parser.hpp + ) + +pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST}) + +if (MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /EHsc) +endif () + +target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) + +target_link_libraries(${PROJECT_NAME} + PRIVATE + MLIRSupport + MLIRLLVMIR + MLIRStandardOps + MLIRTargetLLVMIR + ) + +target_include_directories(${PROJECT_NAME} + PRIVATE + ./include + ${LLVM_INCLUDE_DIRS} + ${MLIR_INCLUDE_DIRS} + ) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp new file mode 100644 index 00000000000..c4d7db505b4 --- /dev/null +++ b/mlir-compiler/lowering.cpp @@ -0,0 +1,430 @@ +#include "lowering.hpp" + +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include + +#include + +#include + +#include "type_parser.hpp" + +#include + +namespace py = pybind11; +namespace mllvm = mlir::LLVM; +namespace +{ +[[noreturn]] void report_error(const llvm::Twine& msg) +{ + auto str = msg.str(); + throw std::exception(str.c_str()); +} + +template +std::string to_str(T& obj) +{ + std::string ret; + llvm::raw_string_ostream stream(ret); + obj.print(stream); + stream.flush(); + return ret; +} + +mllvm::LLVMDialect& get_dialect(mlir::MLIRContext& ctx) +{ + auto dialect = ctx.getRegisteredDialect(); + assert(nullptr != dialect); + return *dialect; +} + +std::vector> get_blocks(const py::object& func) +{ + std::vector> ret; + auto blocks = func.attr("blocks").cast(); + ret.reserve(blocks.size()); + for (auto it : blocks) + { + ret.push_back({it.first.cast(), it.second}); + } + return ret; +} + +py::list get_body(const py::handle& block) +{ + return block.attr("body").cast(); +} + +struct scoped_goto_block +{ + scoped_goto_block(mlir::OpBuilder& b, mlir::Block* new_block): + builder(b), + old_block(b.getBlock()) + { + builder.setInsertionPointToEnd(new_block); + } + + ~scoped_goto_block() + { + builder.setInsertionPointToEnd(old_block); + } + + mlir::OpBuilder& builder; + mlir::Block* old_block = nullptr; +}; + +struct inst_handles +{ + inst_handles() + { + auto mod = py::module::import("numba.core.ir"); + Assign = mod.attr("Assign"); + Del = mod.attr("Del"); + Return = mod.attr("Return"); + + Arg = mod.attr("Arg"); + Const = mod.attr("Const"); + Global = mod.attr("Global"); + Expr = mod.attr("Expr"); + + auto ops = py::module::import("operator"); + + add = ops.attr("add"); + gt = ops.attr("gt"); + } + + py::handle Assign; + py::handle Del; + py::handle Return; + + py::handle Arg; + py::handle Const; + py::handle Global; + py::handle Expr; + + py::handle add; + py::handle gt; +}; + +struct type_cache +{ + using Type = mllvm::LLVMType; + + Type get_type(mllvm::LLVMDialect& dialect, llvm::StringRef str) + { + assert(!str.empty()); + auto s = str.str(); + auto it = typemap.find(s); + if (typemap.end() != it) + { + return it->second; + } + auto type = parse_type(dialect, str); + typemap[s] = type; + return type; + } + +private: + std::unordered_map typemap; +}; + +struct lowerer +{ + lowerer(): + dialect(get_dialect(ctx)), + builder(&ctx) + { + + } + + void lower(const py::object& compilation_context, const py::object& func_ir) + { + auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); + auto typ = get_func_type(compilation_context["fntype"]); + func = builder.create(builder.getUnknownLoc(), "test", typ); + lower_func_body(func_ir); + mod.push_back(func); + mod.dump(); + auto llvmmod = mlir::translateModuleToLLVMIR(mod); + llvmmod->dump(); + } +private: + mlir::MLIRContext ctx; + mllvm::LLVMDialect& dialect; + mlir::OpBuilder builder; + mllvm::LLVMFuncOp func; + mlir::Block::BlockArgListType fnargs; + mlir::Block* entry_bb = nullptr; + std::vector blocks; + std::vector locals; + std::unordered_map vars; + inst_handles insts; + type_cache types; + + void lower_func_body(const py::object& func_ir) + { + entry_bb = func.addEntryBlock(); + assert(func.getNumArguments() >= 2); + fnargs = func.getArguments().slice(2); + auto ir_blocks = get_blocks(func_ir); + assert(!ir_blocks.empty()); + blocks.resize(ir_blocks.size()); + std::generate(blocks.begin(), blocks.end(), [&](){ return func.addBlock(); }); + + std::size_t i = 0; + for (auto& it : ir_blocks) + { + lower_block(blocks[i], it.second); + ++i; + } + + builder.setInsertionPointToEnd(entry_bb); + builder.create(builder.getUnknownLoc(), mlir::None, blocks.front()); + } + + void lower_block(mlir::Block* bb, const py::handle& ir_block) + { + assert(nullptr != bb); + builder.setInsertionPointToEnd(bb); + for (auto it : get_body(ir_block)) + { + lower_inst(it); + } + } + + void lower_inst(const py::handle& inst) + { + if (py::isinstance(inst, insts.Assign)) + { + auto name = inst.attr("target").attr("name"); + auto val = lower_assign(inst, name); + storevar(val, inst, name); + } + else if (py::isinstance(inst, insts.Del)) + { + delvar(inst.attr("value")); + } + else if (py::isinstance(inst, insts.Return)) + { + retvar(inst.attr("value").attr("name")); + } + else + { + report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); + } + } + + mllvm::LLVMType get_ll_type(const py::handle& name) + { + return mllvm::LLVMType::getInt64Ty(&dialect); // TODO + } + + mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) + { + // TODO unhardcode + if (op.is(insts.add)) + { + return builder.create(builder.getUnknownLoc(), lhs, rhs); + } + if (op.is(insts.gt)) + { + assert(lhs.getType() == rhs.getType()); + if (lhs.getType().cast().isIntegerTy()) + { + return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::sgt, lhs, rhs); + } + } + + report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); + } + + mlir::Value lower_binop(const py::handle& expr, const py::handle& op) + { + auto lhs_name = expr.attr("lhs").attr("name"); + auto rhs_name = expr.attr("rhs").attr("name"); + auto lhs = loadvar(lhs_name); + auto rhs = loadvar(rhs_name); + // TODO casts + return resolve_op(lhs, rhs, op); + } + + mlir::Value lower_expr(const py::handle& expr) + { + auto op = expr.attr("op").cast(); + if (op == "binop") + { + return lower_binop(expr, expr.attr("fn")); + } + else if (op == "cast") + { + auto val = loadvar(expr.attr("value").attr("name")); + // TODO cast + return val; + } + else + { + report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); + } + } + + mlir::Value get_const_val(const py::handle& val) + { + std::cout << "asdasd " << py::str(val).cast() << std::endl; + std::cout << "asdasd " << py::str(val.get_type()).cast() << std::endl; +// if (py::isinstance(val)) + { + auto b = val.cast(); + auto mlir_type = mllvm::LLVMType::getInt1Ty(&dialect); + auto value = builder.getBoolAttr(b); + return builder.create(builder.getUnknownLoc(), mlir_type, value); + } + report_error(llvm::Twine("get_const_val unhandled type") + py::str(val).cast()); + } + + mlir::Value lower_assign(const py::handle& inst, const py::handle& name) + { + auto value = inst.attr("value"); + if (py::isinstance(value, insts.Arg)) + { + auto index = value.attr("index").cast(); + // TODO: incref + // TODO: cast + return fnargs[index]; + } + if (py::isinstance(value, insts.Const) || py::isinstance(value, insts.Global)) + { + // TODO unhardcode + // TODO incref + auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); + auto val = builder.getI64IntegerAttr(value.attr("value").cast()); + return builder.create(builder.getUnknownLoc(), mlir_type, val); +// return get_const_val(value.attr("value")); + } + if(py::isinstance(value, insts.Expr)) + { + return lower_expr(value); + } + report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); + } + + void alloca_var(const py::handle& name) + { + auto name_str = name.cast(); + if (0 == vars.count(name_str)) + { + scoped_goto_block s(builder, entry_bb); + auto size_type = mllvm::LLVMType::getIntNTy(&dialect, 64); + auto size_val = builder.getI64IntegerAttr(/*TODO*/1); + auto size = builder.create(builder.getUnknownLoc(), size_type, size_val); + auto type = get_ll_type(name); + auto ptype = type.getPointerTo(); + auto op = builder.create(builder.getUnknownLoc(), ptype, size, /*align*/0); + auto null = zero_val(type); + builder.create(builder.getUnknownLoc(), null, op); + vars[name_str] = op; + } + } + + mlir::Value get_var(const py::handle& name) + { + auto it = vars.find(name.cast()); + assert(vars.end() != it); + return it->second; + } + + mlir::Value loadvar(const py::handle& name) + { + auto type = get_ll_type(name); + return builder.create(builder.getUnknownLoc(), type, get_var(name)); + } + + void storevar(mlir::Value val, const py::handle& inst, const py::handle& name) + { + alloca_var(name); + auto old = loadvar(name); + // TODO decref old + auto ptr = get_var(name); + builder.create(builder.getUnknownLoc(), val, ptr); + } + + mlir::Value zero_val(mllvm::LLVMType type) + { + if (type.isPointerTy()) + { + return builder.create(builder.getUnknownLoc(), type); + } + else if (type.isIntegerTy()) + { + return builder.create(builder.getUnknownLoc(), type, builder.getI64IntegerAttr(0)); + } + else + { + report_error(llvm::Twine("zero_val unhandled type ") + to_str(type)); + } + } + + void delvar(const py::handle& name) + { + alloca_var(name); + auto ptr = get_var(name); + // TODO decref + + // TODO + auto type = get_ll_type(name); + auto null = zero_val(type); + builder.create(builder.getUnknownLoc(), null, ptr); + } + + void retvar(const py::handle& name) + { + alloca_var(name); + auto val = loadvar(name); + // TODO casts + + auto ret_ptr = func.getArgument(0); + builder.create(builder.getUnknownLoc(), val, ret_ptr); + + auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 32); + mlir::Value ret = builder.create(builder.getUnknownLoc(), mlir_type, builder.getI32IntegerAttr(0)); + builder.create(builder.getUnknownLoc(), ret); + } + + mllvm::LLVMType parse_type(llvm::StringRef str) + { + return types.get_type(dialect, str); + } + + mllvm::LLVMType get_func_type(const py::handle& typedesc) + { + auto get_type = [&](const auto& h) { + return parse_type(py::str(h).cast()); + }; + auto p_func = typedesc(); + using Type = mllvm::LLVMType; + auto ret = get_type(p_func.attr("return_type")); + llvm::SmallVector args; + for (auto arg : p_func.attr("args")) + { + args.push_back(get_type(arg)); + } + return Type::getFunctionTy(ret, args, false); + } +}; +} + +void lower_function(const py::object& compilation_context, const py::object& func_ir) +{ + mlir::registerDialect(); + lowerer().lower(compilation_context, func_ir); +} diff --git a/mlir-compiler/lowering.hpp b/mlir-compiler/lowering.hpp new file mode 100644 index 00000000000..803a58aa692 --- /dev/null +++ b/mlir-compiler/lowering.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include + +void lower_function(const pybind11::object& compilation_context, const pybind11::object& func_ir); diff --git a/mlir-compiler/module.cpp b/mlir-compiler/module.cpp new file mode 100644 index 00000000000..aed279bd89f --- /dev/null +++ b/mlir-compiler/module.cpp @@ -0,0 +1,18 @@ +#include + +#include "lowering.hpp" + +namespace py = pybind11; + +namespace +{ +void lower_normal_function(py::object compilation_context, py::object func_ir) +{ + lower_function(compilation_context, func_ir); +} +} + +PYBIND11_MODULE(mlir_compiler, m) +{ + m.def("lower_normal_function", &lower_normal_function, "todo"); +} diff --git a/mlir-compiler/type_parser.cpp b/mlir-compiler/type_parser.cpp new file mode 100644 index 00000000000..edbb38e5953 --- /dev/null +++ b/mlir-compiler/type_parser.cpp @@ -0,0 +1,24 @@ +#include "type_parser.hpp" + +#include + +namespace +{ +[[noreturn]] void report_error(const llvm::Twine& msg) +{ + auto str = msg.str(); + throw std::exception(str.c_str()); +} +} + +mlir::LLVM::LLVMType parse_type(mlir::LLVM::LLVMDialect& dialect, llvm::StringRef str) +{ + assert(!str.empty()); + auto mlir_type = (std::string("!llvm<\"") + str + "\">").str(); + auto res = mlir::parseType(mlir_type, dialect.getContext()).dyn_cast_or_null(); + if (mlir::Type() == res) + { + report_error(llvm::Twine("cannot parse type: \"") + str + "\""); + } + return res; +} diff --git a/mlir-compiler/type_parser.hpp b/mlir-compiler/type_parser.hpp new file mode 100644 index 00000000000..6de1c61eff5 --- /dev/null +++ b/mlir-compiler/type_parser.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include + +mlir::LLVM::LLVMType parse_type(mlir::LLVM::LLVMDialect& dialect, llvm::StringRef str); diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 1c9c19cd3b1..46fd5474c7b 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -183,6 +183,12 @@ def lower_normal_function(self, fndesc): """ Lower non-generator *fndesc*. """ + print('lower_normal_function',self.func_ir) + ctx = {} + ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) + import mlir_compiler + mlir_compiler.lower_normal_function(ctx, self.func_ir) + # self.func_ir.dump() self.setup_function(fndesc) # Init argument values @@ -278,7 +284,6 @@ def pre_block(self, block): from numba.core.unsafe import eh super(Lower, self).pre_block(block) - if block == self.firstblk: # create slots for all the vars, irrespective of whether they are # initialized, SSA will pick this up and warn users about using From 6840f691a8f583c3cf07c7f022f8f5c4328d2f72 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 7 Aug 2020 15:55:41 +0300 Subject: [PATCH 002/259] gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 340ae2678b8..7ffccb559d5 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ .nfs* tags MANIFEST +CMakeLists.txt.user build/ docs/_build/ From 189dc26a366cd8d2ab847998dbc8c6308f4440ed Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 14 Aug 2020 19:40:05 +0300 Subject: [PATCH 003/259] some work --- mlir-compiler/lowering.cpp | 52 ++++++++++++++++++++++++++++---------- mlir-compiler/lowering.hpp | 2 +- mlir-compiler/module.cpp | 4 +-- numba/core/lowering.py | 8 +++++- 4 files changed, 48 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index c4d7db505b4..08c2c6085ba 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -20,6 +20,8 @@ #include "type_parser.hpp" +#include + #include namespace py = pybind11; @@ -32,6 +34,16 @@ namespace throw std::exception(str.c_str()); } +std::string serialize_mod(const llvm::Module& mod) +{ + std::string ret; + llvm::raw_string_ostream stream(ret); +// mod.print(stream, nullptr); + llvm::WriteBitcodeToFile(mod, stream); + stream.flush(); + return ret; +} + template std::string to_str(T& obj) { @@ -148,16 +160,17 @@ struct lowerer } - void lower(const py::object& compilation_context, const py::object& func_ir) + py::bytes lower(const py::object& compilation_context, const py::object& func_ir) { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); auto typ = get_func_type(compilation_context["fntype"]); func = builder.create(builder.getUnknownLoc(), "test", typ); lower_func_body(func_ir); mod.push_back(func); - mod.dump(); +// mod.dump(); auto llvmmod = mlir::translateModuleToLLVMIR(mod); - llvmmod->dump(); +// llvmmod->dump(); + return py::bytes(serialize_mod(*llvmmod)); } private: mlir::MLIRContext ctx; @@ -282,14 +295,25 @@ struct lowerer { std::cout << "asdasd " << py::str(val).cast() << std::endl; std::cout << "asdasd " << py::str(val.get_type()).cast() << std::endl; -// if (py::isinstance(val)) + if (py::isinstance(val)) { - auto b = val.cast(); - auto mlir_type = mllvm::LLVMType::getInt1Ty(&dialect); - auto value = builder.getBoolAttr(b); + auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); + auto value = builder.getI64IntegerAttr(val.cast()); return builder.create(builder.getUnknownLoc(), mlir_type, value); } - report_error(llvm::Twine("get_const_val unhandled type") + py::str(val).cast()); +// if (py::isinstance(val)) +// { +// auto b = val.cast(); +// auto mlir_type = mllvm::LLVMType::getInt1Ty(&dialect); +// auto value = builder.getBoolAttr(b); +// return builder.create(builder.getUnknownLoc(), mlir_type, value); +// } + + // assume it is a PyObject* + auto mlir_type = mllvm::LLVMType::getInt8Ty(&dialect).getPointerTo(); + return builder.create(builder.getUnknownLoc(), mlir_type); + +// report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); } mlir::Value lower_assign(const py::handle& inst, const py::handle& name) @@ -306,10 +330,10 @@ struct lowerer { // TODO unhardcode // TODO incref - auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); - auto val = builder.getI64IntegerAttr(value.attr("value").cast()); - return builder.create(builder.getUnknownLoc(), mlir_type, val); -// return get_const_val(value.attr("value")); +// auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); +// auto val = builder.getI64IntegerAttr(value.attr("value").cast()); +// return builder.create(builder.getUnknownLoc(), mlir_type, val); + return get_const_val(value.attr("value")); } if(py::isinstance(value, insts.Expr)) { @@ -423,8 +447,8 @@ struct lowerer }; } -void lower_function(const py::object& compilation_context, const py::object& func_ir) +py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) { mlir::registerDialect(); - lowerer().lower(compilation_context, func_ir); + return lowerer().lower(compilation_context, func_ir); } diff --git a/mlir-compiler/lowering.hpp b/mlir-compiler/lowering.hpp index 803a58aa692..d5ee4e5d9b2 100644 --- a/mlir-compiler/lowering.hpp +++ b/mlir-compiler/lowering.hpp @@ -2,4 +2,4 @@ #include -void lower_function(const pybind11::object& compilation_context, const pybind11::object& func_ir); +pybind11::bytes lower_function(const pybind11::object& compilation_context, const pybind11::object& func_ir); diff --git a/mlir-compiler/module.cpp b/mlir-compiler/module.cpp index aed279bd89f..fb262dc6e6b 100644 --- a/mlir-compiler/module.cpp +++ b/mlir-compiler/module.cpp @@ -6,9 +6,9 @@ namespace py = pybind11; namespace { -void lower_normal_function(py::object compilation_context, py::object func_ir) +py::bytes lower_normal_function(py::object compilation_context, py::object func_ir) { - lower_function(compilation_context, func_ir); + return lower_function(compilation_context, func_ir); } } diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 46fd5474c7b..c72a11beaa5 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -187,7 +187,13 @@ def lower_normal_function(self, fndesc): ctx = {} ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) import mlir_compiler - mlir_compiler.lower_normal_function(ctx, self.func_ir) + mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) + import llvmlite.binding as llvm + mod = llvm.parse_bitcode(mod_ir) + # print(mod) + func = mod.get_function('test'); + print(func); + # self.func_ir.dump() self.setup_function(fndesc) From b1eb79e264879d7c6e123e14d25cbac098a417fb Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 19 Aug 2020 14:50:20 +0300 Subject: [PATCH 004/259] func name --- mlir-compiler/lowering.cpp | 3 ++- numba/core/lowering.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 08c2c6085ba..f50abb782c9 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -164,7 +164,8 @@ struct lowerer { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); auto typ = get_func_type(compilation_context["fntype"]); - func = builder.create(builder.getUnknownLoc(), "test", typ); + auto name = compilation_context["fnname"]().cast(); + func = builder.create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); // mod.dump(); diff --git a/numba/core/lowering.py b/numba/core/lowering.py index c72a11beaa5..da1eac8cb1c 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -186,13 +186,14 @@ def lower_normal_function(self, fndesc): print('lower_normal_function',self.func_ir) ctx = {} ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) + ctx['fnname'] = lambda: fndesc.mangled_name import mlir_compiler mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) import llvmlite.binding as llvm mod = llvm.parse_bitcode(mod_ir) - # print(mod) - func = mod.get_function('test'); - print(func); + print(mod) + # func = mod.get_function('test'); + # print(func); # self.func_ir.dump() self.setup_function(fndesc) From e84bd3b6459ce0fe9b40e8427c680d8baf750534 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 19 Aug 2020 15:10:49 +0300 Subject: [PATCH 005/259] use mlir function code --- numba/core/lowering.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index da1eac8cb1c..7fc75aa20a9 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -171,7 +171,7 @@ def lower(self): self.context.post_lowering(self.module, self.library) # Materialize LLVM Module - self.library.add_ir_module(self.module) + # self.library.add_ir_module(self.module) def extract_function_arguments(self): self.fnargs = self.call_conv.decode_arguments(self.builder, @@ -198,13 +198,15 @@ def lower_normal_function(self, fndesc): # self.func_ir.dump() self.setup_function(fndesc) - # Init argument values - self.extract_function_arguments() - entry_block_tail = self.lower_function_body() + self.library.add_llvm_module(mod); - # Close tail of entry block - self.builder.position_at_end(entry_block_tail) - self.builder.branch(self.blkmap[self.firstblk]) + # # Init argument values + # self.extract_function_arguments() + # entry_block_tail = self.lower_function_body() + + # # Close tail of entry block + # self.builder.position_at_end(entry_block_tail) + # self.builder.branch(self.blkmap[self.firstblk]) def lower_function_body(self): """ From 10643f3af0aab67d98880396fc21524fd49f44f1 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 19 Aug 2020 20:50:56 +0300 Subject: [PATCH 006/259] some simple tests --- mlir-compiler/lowering.cpp | 2 -- mlir-compiler/test.py | 31 +++++++++++++++++++++++++++++++ numba/core/lowering.py | 4 ++-- 3 files changed, 33 insertions(+), 4 deletions(-) create mode 100644 mlir-compiler/test.py diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index f50abb782c9..7664879e1a5 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -294,8 +294,6 @@ struct lowerer mlir::Value get_const_val(const py::handle& val) { - std::cout << "asdasd " << py::str(val).cast() << std::endl; - std::cout << "asdasd " << py::str(val.get_type()).cast() << std::endl; if (py::isinstance(val)) { auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py new file mode 100644 index 00000000000..a8a0bb074b8 --- /dev/null +++ b/mlir-compiler/test.py @@ -0,0 +1,31 @@ +import numba + +@numba.njit +def sum1(a): + return a + 42 + +def sum2(a, b): + return a + b + +@numba.njit +def bar(a, b): + if a > b: + return a + else: + return b + +def test(func, result, params): + print('test', func.__name__, '... ', end='') + try: + res = func(*params) + if (res != result): + raise Exception(f'Invalid value "{res}", expected "{result}"') + print('SUCCESS') + except Exception as e: + print(e) + print('FAILED') + + +test(sum1, 47, (5,)) +test(sum2, 7, (3,4)) +#print(bar(5,6)) \ No newline at end of file diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 7fc75aa20a9..d59b97c5cea 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -183,7 +183,7 @@ def lower_normal_function(self, fndesc): """ Lower non-generator *fndesc*. """ - print('lower_normal_function',self.func_ir) + # print('lower_normal_function',self.func_ir) ctx = {} ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) ctx['fnname'] = lambda: fndesc.mangled_name @@ -191,7 +191,7 @@ def lower_normal_function(self, fndesc): mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) import llvmlite.binding as llvm mod = llvm.parse_bitcode(mod_ir) - print(mod) + # print(mod) # func = mod.get_function('test'); # print(func); From 66925032b2ada9205a80ec9bc1f593ae02d8822b Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 21 Aug 2020 18:42:12 +0300 Subject: [PATCH 007/259] flag --- numba/core/lowering.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index d59b97c5cea..49408f34c55 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -11,6 +11,7 @@ from numba.core.funcdesc import default_mangler from numba.core.environment import Environment +_use_mlir = True _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) @@ -170,8 +171,9 @@ def lower(self): # Run target specific post lowering transformation self.context.post_lowering(self.module, self.library) - # Materialize LLVM Module - # self.library.add_ir_module(self.module) + if not _use_mlir: + # Materialize LLVM Module + self.library.add_ir_module(self.module) def extract_function_arguments(self): self.fnargs = self.call_conv.decode_arguments(self.builder, @@ -198,15 +200,16 @@ def lower_normal_function(self, fndesc): # self.func_ir.dump() self.setup_function(fndesc) - self.library.add_llvm_module(mod); - - # # Init argument values - # self.extract_function_arguments() - # entry_block_tail = self.lower_function_body() + if _use_mlir: + self.library.add_llvm_module(mod); + else: + # Init argument values + self.extract_function_arguments() + entry_block_tail = self.lower_function_body() - # # Close tail of entry block - # self.builder.position_at_end(entry_block_tail) - # self.builder.branch(self.blkmap[self.firstblk]) + # Close tail of entry block + self.builder.position_at_end(entry_block_tail) + self.builder.branch(self.blkmap[self.firstblk]) def lower_function_body(self): """ From ee9354c69b6c361541334701fd031058681446c7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 21 Aug 2020 18:43:53 +0300 Subject: [PATCH 008/259] flag --- numba/core/lowering.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 49408f34c55..32299989799 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -185,19 +185,15 @@ def lower_normal_function(self, fndesc): """ Lower non-generator *fndesc*. """ - # print('lower_normal_function',self.func_ir) - ctx = {} - ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) - ctx['fnname'] = lambda: fndesc.mangled_name - import mlir_compiler - mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) - import llvmlite.binding as llvm - mod = llvm.parse_bitcode(mod_ir) - # print(mod) - # func = mod.get_function('test'); - # print(func); - - # self.func_ir.dump() + if _use_mlir: + ctx = {} + ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) + ctx['fnname'] = lambda: fndesc.mangled_name + import mlir_compiler + mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) + import llvmlite.binding as llvm + mod = llvm.parse_bitcode(mod_ir) + self.setup_function(fndesc) if _use_mlir: From 56309dd5a5033d37fdbd7f96e64be7116dd6bb10 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 21 Aug 2020 19:13:05 +0300 Subject: [PATCH 009/259] eq --- mlir-compiler/lowering.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 7664879e1a5..59ba965e8c8 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -113,6 +113,8 @@ struct inst_handles auto ops = py::module::import("operator"); add = ops.attr("add"); + + eq = ops.attr("eq"); gt = ops.attr("gt"); } @@ -126,6 +128,8 @@ struct inst_handles py::handle Expr; py::handle add; + + py::handle eq; py::handle gt; }; @@ -251,6 +255,14 @@ struct lowerer { return builder.create(builder.getUnknownLoc(), lhs, rhs); } + if (op.is(insts.eq)) + { + assert(lhs.getType() == rhs.getType()); + if (lhs.getType().cast().isIntegerTy()) + { + return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::eq, lhs, rhs); + } + } if (op.is(insts.gt)) { assert(lhs.getType() == rhs.getType()); From 590ca8f6a090f2cc1ba542ee7782d07fafc2f20e Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 24 Aug 2020 20:31:48 +0300 Subject: [PATCH 010/259] lower cond branches --- mlir-compiler/lowering.cpp | 51 ++++++++++++++++++++++++++++++++------ mlir-compiler/test.py | 7 +++--- numba/core/lowering.py | 1 + 3 files changed, 49 insertions(+), 10 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 59ba965e8c8..6e8861bfd08 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -104,6 +104,7 @@ struct inst_handles Assign = mod.attr("Assign"); Del = mod.attr("Del"); Return = mod.attr("Return"); + Branch = mod.attr("Branch"); Arg = mod.attr("Arg"); Const = mod.attr("Const"); @@ -121,6 +122,7 @@ struct inst_handles py::handle Assign; py::handle Del; py::handle Return; + py::handle Branch; py::handle Arg; py::handle Const; @@ -167,12 +169,14 @@ struct lowerer py::bytes lower(const py::object& compilation_context, const py::object& func_ir) { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); + var_type_resolver = compilation_context["get_var_type"]; auto typ = get_func_type(compilation_context["fntype"]); auto name = compilation_context["fnname"]().cast(); func = builder.create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); // mod.dump(); + assert(mlir::succeeded(mod.verify())); auto llvmmod = mlir::translateModuleToLLVMIR(mod); // llvmmod->dump(); return py::bytes(serialize_mod(*llvmmod)); @@ -185,10 +189,12 @@ struct lowerer mlir::Block::BlockArgListType fnargs; mlir::Block* entry_bb = nullptr; std::vector blocks; + std::unordered_map blocks_map; std::vector locals; std::unordered_map vars; inst_handles insts; type_cache types; + py::handle var_type_resolver; void lower_func_body(const py::object& func_ir) { @@ -197,14 +203,16 @@ struct lowerer fnargs = func.getArguments().slice(2); auto ir_blocks = get_blocks(func_ir); assert(!ir_blocks.empty()); - blocks.resize(ir_blocks.size()); - std::generate(blocks.begin(), blocks.end(), [&](){ return func.addBlock(); }); + blocks.reserve(ir_blocks.size()); + for (std::size_t i = 0; i < ir_blocks.size(); ++i) + { + blocks.push_back(func.addBlock()); + blocks_map[ir_blocks[i].first] = blocks.back(); + } - std::size_t i = 0; - for (auto& it : ir_blocks) + for (std::size_t i = 0; i < ir_blocks.size(); ++i) { - lower_block(blocks[i], it.second); - ++i; + lower_block(blocks[i], ir_blocks[i].second); } builder.setInsertionPointToEnd(entry_bb); @@ -237,6 +245,10 @@ struct lowerer { retvar(inst.attr("value").attr("name")); } + else if (py::isinstance(inst, insts.Branch)) + { + branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); + } else { report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); @@ -245,7 +257,7 @@ struct lowerer mllvm::LLVMType get_ll_type(const py::handle& name) { - return mllvm::LLVMType::getInt64Ty(&dialect); // TODO + return parse_type(py::str(var_type_resolver(name)).cast()); } mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) @@ -285,6 +297,17 @@ struct lowerer return resolve_op(lhs, rhs, op); } + mlir::Value lower_call(const py::handle& expr) + { + auto args = expr.attr("args").cast(); + auto vararg = expr.attr("vararg"); + auto kws = expr.attr("kws"); + // TODO fold args + + // TODO: hardcode for bool + return loadvar(args[0].attr("name")); + } + mlir::Value lower_expr(const py::handle& expr) { auto op = expr.attr("op").cast(); @@ -298,6 +321,10 @@ struct lowerer // TODO cast return val; } + else if (op == "call") + { + return lower_call(expr); + } else { report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); @@ -435,6 +462,16 @@ struct lowerer builder.create(builder.getUnknownLoc(), ret); } + void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) + { + auto c = loadvar(cond); + auto tr_block = blocks_map.find(tr.cast())->second; + auto fl_block = blocks_map.find(fl.cast())->second; + // TODO: casts + + builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); + } + mllvm::LLVMType parse_type(llvm::StringRef str) { return types.get_type(dialect, str); diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index a8a0bb074b8..7b2f1ee3cb5 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -8,14 +8,14 @@ def sum2(a, b): return a + b @numba.njit -def bar(a, b): +def cond(a, b): if a > b: return a else: return b def test(func, result, params): - print('test', func.__name__, '... ', end='') + print('test', func.__name__, params, '... ', end='') try: res = func(*params) if (res != result): @@ -28,4 +28,5 @@ def test(func, result, params): test(sum1, 47, (5,)) test(sum2, 7, (3,4)) -#print(bar(5,6)) \ No newline at end of file +test(cond, 6, (5,6)) +test(cond, 8, (8,7)) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 32299989799..93294b91748 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -189,6 +189,7 @@ def lower_normal_function(self, fndesc): ctx = {} ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) ctx['fnname'] = lambda: fndesc.mangled_name + ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) import mlir_compiler mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) import llvmlite.binding as llvm From ad468e3ab5ec516a261a42e3007b4d225232ca1f Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 24 Aug 2020 20:38:06 +0300 Subject: [PATCH 011/259] tests --- mlir-compiler/test.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 7b2f1ee3cb5..3c7cb3391fc 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -1,23 +1,25 @@ import numba -@numba.njit def sum1(a): return a + 42 def sum2(a, b): return a + b -@numba.njit def cond(a, b): if a > b: return a else: return b -def test(func, result, params): + + +def test(func, params): print('test', func.__name__, params, '... ', end='') + result = func(*params) + wrapped = numba.njit()(func) try: - res = func(*params) + res = wrapped(*params) if (res != result): raise Exception(f'Invalid value "{res}", expected "{result}"') print('SUCCESS') @@ -26,7 +28,7 @@ def test(func, result, params): print('FAILED') -test(sum1, 47, (5,)) -test(sum2, 7, (3,4)) -test(cond, 6, (5,6)) -test(cond, 8, (8,7)) +test(sum1, (5,)) +test(sum2, (3,4)) +test(cond, (5,6)) +test(cond, (8,7)) From 9012ffc3aad261ad5c57678a36633300c6d3e71c Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 24 Aug 2020 20:40:25 +0300 Subject: [PATCH 012/259] throw error on module validation failure --- mlir-compiler/lowering.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 6e8861bfd08..42a9df49110 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -176,7 +176,10 @@ struct lowerer lower_func_body(func_ir); mod.push_back(func); // mod.dump(); - assert(mlir::succeeded(mod.verify())); + if (mlir::failed(mod.verify())) + { + report_error("MLIR module validation failed"); + } auto llvmmod = mlir::translateModuleToLLVMIR(mod); // llvmmod->dump(); return py::bytes(serialize_mod(*llvmmod)); From 862009bf6ad4b541d26237ed94e90bca4cd6b1ed Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 28 Aug 2020 19:59:01 +0300 Subject: [PATCH 013/259] var --- mlir-compiler/lowering.cpp | 23 +++++++++++++++++------ mlir-compiler/test.py | 14 +++++++++++++- 2 files changed, 30 insertions(+), 7 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 42a9df49110..34a98aba395 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -107,9 +107,10 @@ struct inst_handles Branch = mod.attr("Branch"); Arg = mod.attr("Arg"); + Expr = mod.attr("Expr"); + Var = mod.attr("Var"); Const = mod.attr("Const"); Global = mod.attr("Global"); - Expr = mod.attr("Expr"); auto ops = py::module::import("operator"); @@ -125,9 +126,10 @@ struct inst_handles py::handle Branch; py::handle Arg; + py::handle Expr; + py::handle Var; py::handle Const; py::handle Global; - py::handle Expr; py::handle add; @@ -367,6 +369,18 @@ struct lowerer // TODO: cast return fnargs[index]; } + if(py::isinstance(value, insts.Expr)) + { + return lower_expr(value); + } + if(py::isinstance(value, insts.Var)) + { + auto var = loadvar(value.attr("name")); + + // TODO: cast + // TODO: incref + return var; + } if (py::isinstance(value, insts.Const) || py::isinstance(value, insts.Global)) { // TODO unhardcode @@ -376,10 +390,7 @@ struct lowerer // return builder.create(builder.getUnknownLoc(), mlir_type, val); return get_const_val(value.attr("value")); } - if(py::isinstance(value, insts.Expr)) - { - return lower_expr(value); - } + report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); } diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 3c7cb3391fc..d5c52f3ca56 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -12,6 +12,16 @@ def cond(a, b): else: return b +def var(a): + c = 1 + c = c + a + return c + +def loop(n): + res = 0 + for i in range(n): + res += i + return res def test(func, params): @@ -26,9 +36,11 @@ def test(func, params): except Exception as e: print(e) print('FAILED') - + test(sum1, (5,)) test(sum2, (3,4)) test(cond, (5,6)) test(cond, (8,7)) +test(var, (8,)) +#test(loop, (8,)) From 6fb5c73c68e4c371368e3ecd4ffa1ea1c51e3b50 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 28 Aug 2020 20:09:46 +0300 Subject: [PATCH 014/259] jump --- mlir-compiler/lowering.cpp | 12 ++++++++++++ mlir-compiler/test.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 34a98aba395..451118ec7ce 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -105,6 +105,7 @@ struct inst_handles Del = mod.attr("Del"); Return = mod.attr("Return"); Branch = mod.attr("Branch"); + Jump = mod.attr("Jump"); Arg = mod.attr("Arg"); Expr = mod.attr("Expr"); @@ -124,6 +125,7 @@ struct inst_handles py::handle Del; py::handle Return; py::handle Branch; + py::handle Jump; py::handle Arg; py::handle Expr; @@ -254,6 +256,10 @@ struct lowerer { branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); } + else if (py::isinstance(inst, insts.Jump)) + { + jump(inst.attr("target")); + } else { report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); @@ -486,6 +492,12 @@ struct lowerer builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); } + void jump(const py::handle& target) + { + auto block = blocks_map.find(target.cast())->second; + builder.create(builder.getUnknownLoc(), mlir::None, block); + } + mllvm::LLVMType parse_type(llvm::StringRef str) { return types.get_type(dialect, str); diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index d5c52f3ca56..990730f2c3c 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -17,6 +17,13 @@ def var(a): c = c + a return c +def jump(a, b): + c = 3 + if a > 5: + c = c + a + c = c + b + return c + def loop(n): res = 0 for i in range(n): @@ -43,4 +50,6 @@ def test(func, params): test(cond, (5,6)) test(cond, (8,7)) test(var, (8,)) +test(jump, (1,8)) +test(jump, (7,8)) #test(loop, (8,)) From cd13cb70a5241aaeb8d5096c9c5abfa854b37776 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 28 Aug 2020 20:13:03 +0300 Subject: [PATCH 015/259] test --- mlir-compiler/test.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 990730f2c3c..a1a10cdd375 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -24,6 +24,11 @@ def jump(a, b): c = c + b return c +sum2_jit = numba.njit()(sum2) + +def call(a, b, c): + return sum2_jit(a, sum2_jit(b, c)) + def loop(n): res = 0 for i in range(n): @@ -52,4 +57,5 @@ def test(func, params): test(var, (8,)) test(jump, (1,8)) test(jump, (7,8)) +#test(call, (1,2,3)) #test(loop, (8,)) From 7a4a46045bfe79e0520eaf40f4af91a5eef38ce8 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 28 Aug 2020 20:17:59 +0300 Subject: [PATCH 016/259] tests --- mlir-compiler/test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index a1a10cdd375..634a544ce5e 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -29,6 +29,10 @@ def jump(a, b): def call(a, b, c): return sum2_jit(a, sum2_jit(b, c)) +def tuple(a,b,c): + t = (a,b,c) + return t[0] + t[1] + t[2] + def loop(n): res = 0 for i in range(n): @@ -58,4 +62,5 @@ def test(func, params): test(jump, (1,8)) test(jump, (7,8)) #test(call, (1,2,3)) +#test(tuple, (1,2,3)) #test(loop, (8,)) From 967942137282e2c444f6a970c43717e86ed1ef29 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 10 Sep 2020 21:00:41 +0300 Subject: [PATCH 017/259] initial infra --- mlir-compiler/CMakeLists.txt | 11 +++++- mlir-compiler/include/plier/CMakeLists.txt | 11 ++++++ mlir-compiler/include/plier/PlierOps.td | 39 ++++++++++++++++++++++ mlir-compiler/include/plier/dialect.hpp | 34 +++++++++++++++++++ mlir-compiler/lowering.cpp | 33 ++++++++++-------- mlir-compiler/test.py | 4 +++ 6 files changed, 118 insertions(+), 14 deletions(-) create mode 100644 mlir-compiler/include/plier/CMakeLists.txt create mode 100644 mlir-compiler/include/plier/PlierOps.td create mode 100644 mlir-compiler/include/plier/dialect.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 5dcf303e482..bb7c2057341 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -12,9 +12,13 @@ find_package(MLIR REQUIRED CONFIG) list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") - +include(TableGen) +include(AddLLVM) +include(AddMLIR) include(HandleLLVMOptions) +add_subdirectory(include/plier) + set(SOURCES_LIST lowering.cpp module.cpp @@ -23,6 +27,8 @@ set(SOURCES_LIST set(HEADERS_LIST lowering.hpp type_parser.hpp + include/plier/dialect.hpp + include/plier/PlierOps.td ) pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST}) @@ -46,4 +52,7 @@ target_include_directories(${PROJECT_NAME} ./include ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} + ${PROJECT_BINARY_DIR}/include ) + +add_dependencies(${PROJECT_NAME} MLIRPlierOpsIncGen) diff --git a/mlir-compiler/include/plier/CMakeLists.txt b/mlir-compiler/include/plier/CMakeLists.txt new file mode 100644 index 00000000000..2966b672b41 --- /dev/null +++ b/mlir-compiler/include/plier/CMakeLists.txt @@ -0,0 +1,11 @@ +include_directories(${MLIR_INCLUDE_DIRS}) +set(dialect PlierOps) +set(dialect_namespace plier) +set(LLVM_TARGET_DEFINITIONS ${dialect}.td) +mlir_tablegen(${dialect}Enums.h.inc -gen-enum-decls) +mlir_tablegen(${dialect}Enums.cpp.inc -gen-enum-defs) +mlir_tablegen(${dialect}.h.inc -gen-op-decls) +mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) +mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) +add_public_tablegen_target(MLIR${dialect}IncGen) +add_dependencies(mlir-headers MLIR${dialect}IncGen) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td new file mode 100644 index 00000000000..f56d5169cf1 --- /dev/null +++ b/mlir-compiler/include/plier/PlierOps.td @@ -0,0 +1,39 @@ +#ifndef PLIER_OPS +#define PLIER_OPS + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def Plier_Dialect : Dialect { + let name = "plier"; + let cppNamespace = "plier"; +} + +def Plier_Type : DialectType()">, "type">, + BuildableType<"$_builder.getType<::plier::Type>()"> { +} + +class Plier_Op traits = []> : + Op; + +def CastOp : Plier_Op<"cast", []> { + let arguments = (ins + Plier_Type:$value); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value obj"> + ]; +} + + +def DelOp : Plier_Op<"del", []> { + let arguments = (ins + Plier_Type:$value); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value obj"> + ]; +} + +#endif // PLIER_OPS diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp new file mode 100644 index 00000000000..4d24b92337e --- /dev/null +++ b/mlir-compiler/include/plier/dialect.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Function.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "proto/PlierOpsEnums.h.inc" + +namespace plier +{ +using namespace mlir; // TODO: remove +#include "proto/PlierOpsDialect.h.inc" +#define GET_OP_CLASSES +#include "proto/PlierOps.h.inc" +} + +namespace plier +{ + +void register_dialect(); + +namespace types +{ +enum Kind +{ + // Dialect types. +// PyState = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_3_TYPE, +// PyObject, +}; +} + +} diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 451118ec7ce..f6508646bbf 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -161,11 +161,27 @@ struct type_cache std::unordered_map typemap; }; -struct lowerer +struct lowerer_base +{ + lowerer_base(): builder(&ctx) {} + +protected: + mlir::MLIRContext ctx; + mlir::OpBuilder builder; + mlir::Block::BlockArgListType fnargs; + mlir::Block* entry_bb = nullptr; + std::vector blocks; + std::unordered_map blocks_map; + std::vector locals; + std::unordered_map vars; + inst_handles insts; + type_cache types; +}; + +struct lowerer : public lowerer_base { lowerer(): - dialect(get_dialect(ctx)), - builder(&ctx) + dialect(get_dialect(ctx)) { } @@ -189,18 +205,9 @@ struct lowerer return py::bytes(serialize_mod(*llvmmod)); } private: - mlir::MLIRContext ctx; + mllvm::LLVMDialect& dialect; - mlir::OpBuilder builder; mllvm::LLVMFuncOp func; - mlir::Block::BlockArgListType fnargs; - mlir::Block* entry_bb = nullptr; - std::vector blocks; - std::unordered_map blocks_map; - std::vector locals; - std::unordered_map vars; - inst_handles insts; - type_cache types; py::handle var_type_resolver; void lower_func_body(const py::object& func_ir) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 634a544ce5e..8b832ae9eb3 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -1,5 +1,8 @@ import numba +def ret(a): + return a + def sum1(a): return a + 42 @@ -54,6 +57,7 @@ def test(func, params): print('FAILED') +test(ret, (7,)) test(sum1, (5,)) test(sum2, (3,4)) test(cond, (5,6)) From 3c003382124e5c19b71536580f824b53c149df68 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 11 Sep 2020 19:42:52 +0300 Subject: [PATCH 018/259] some work --- mlir-compiler/CMakeLists.txt | 1 + mlir-compiler/dialect.cpp | 61 +++++++ mlir-compiler/include/plier/PlierOps.td | 43 ++++- mlir-compiler/include/plier/dialect.hpp | 20 ++- mlir-compiler/lowering.cpp | 219 ++++++++++++++++++++++-- mlir-compiler/type_parser.cpp | 4 +- mlir-compiler/type_parser.hpp | 2 +- numba/core/lowering.py | 1 + 8 files changed, 324 insertions(+), 27 deletions(-) create mode 100644 mlir-compiler/dialect.cpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index bb7c2057341..a563f5f39ca 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -20,6 +20,7 @@ include(HandleLLVMOptions) add_subdirectory(include/plier) set(SOURCES_LIST + dialect.cpp lowering.cpp module.cpp type_parser.cpp diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp new file mode 100644 index 00000000000..e8a34ec3408 --- /dev/null +++ b/mlir-compiler/dialect.cpp @@ -0,0 +1,61 @@ +#include "plier/dialect.hpp" + +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Builders.h" + +#include "mlir/Dialect/StandardOps/IR/Ops.h" + +namespace plier +{ + +void register_dialect() +{ + mlir::registerDialect(); + mlir::registerDialect(); +} + +PlierDialect::PlierDialect(mlir::MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "plier/PlierOps.cpp.inc" + >(); + addTypes(); +} + +mlir::Type PlierDialect::parseType(mlir::DialectAsmParser &parser) const { + parser.emitError(parser.getNameLoc(), "unknown type"); + return mlir::Type(); +} + +void PlierDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { + switch (type.getKind()) { + case plier::types::PyType: + os << "PyType"; + return; + default: + llvm_unreachable("unexpected type kind"); + } +} + +void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + unsigned index, mlir::StringRef name) { + ArgOp::build(builder, state, PyType::get(state.getContext()), llvm::APInt(32, index), name); +} + +void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value val) { + CastOp::build(builder, state, PyType::get(state.getContext()), val); +} + +void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Attribute val) { + ConstOp::build(builder, state, PyType::get(state.getContext()), val); +} + +#define GET_OP_CLASSES +#include "plier/PlierOps.cpp.inc" + +} +#include "plier/PlierOpsEnums.cpp.inc" diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index f56d5169cf1..ba35833ebcf 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -9,31 +9,56 @@ def Plier_Dialect : Dialect { let cppNamespace = "plier"; } -def Plier_Type : DialectType()">, "type">, - BuildableType<"$_builder.getType<::plier::Type>()"> { +def Plier_PyType : DialectType()">, "pytype">, + BuildableType<"$_builder.getType<::plier::PyType>()"> { } class Plier_Op traits = []> : Op; -def CastOp : Plier_Op<"cast", []> { +def ArgOp : Plier_Op<"arg", []> { let arguments = (ins - Plier_Type:$value); + UI32Attr:$index, + StrAttr:$name); + + let results = (outs Plier_PyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value obj"> + OpBuilder<"OpBuilder &b, OperationState &state, unsigned index, StringRef name"> ]; } +def ConstOp : Plier_Op<"const", []> { + let arguments = (ins + AnyAttr:$val); -def DelOp : Plier_Op<"del", []> { + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Attribute val"> + ]; +} + +def CastOp : Plier_Op<"cast", []> { let arguments = (ins - Plier_Type:$value); + Plier_PyType:$value); + + let results = (outs Plier_PyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value obj"> + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> ]; } + +def DelOp : Plier_Op<"del", []> { + let arguments = (ins + Plier_PyType:$value); + +// let builders = [ +// OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> +// ]; +} + #endif // PLIER_OPS diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 4d24b92337e..95a330894aa 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -6,14 +6,14 @@ #include "mlir/IR/Function.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "proto/PlierOpsEnums.h.inc" +#include "plier/PlierOpsEnums.h.inc" namespace plier { using namespace mlir; // TODO: remove -#include "proto/PlierOpsDialect.h.inc" +#include "plier/PlierOpsDialect.h.inc" #define GET_OP_CLASSES -#include "proto/PlierOps.h.inc" +#include "plier/PlierOps.h.inc" } namespace plier @@ -26,9 +26,19 @@ namespace types enum Kind { // Dialect types. -// PyState = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_3_TYPE, -// PyObject, + PyType = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_3_TYPE, }; } +class PyType : public mlir::Type::TypeBase { +public: + using Base::Base; + static bool kindof(unsigned kind) { return kind == types::PyType; } + static PyType get(mlir::MLIRContext *context) + { + return Base::get(context, types::PyType); + } +}; + + } diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index f6508646bbf..acf41072d24 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -22,6 +22,9 @@ #include +#include +#include "plier/dialect.hpp" + #include namespace py = pybind11; @@ -54,9 +57,10 @@ std::string to_str(T& obj) return ret; } -mllvm::LLVMDialect& get_dialect(mlir::MLIRContext& ctx) +template +T& get_dialect(mlir::MLIRContext& ctx) { - auto dialect = ctx.getRegisteredDialect(); + auto dialect = ctx.getRegisteredDialect(); assert(nullptr != dialect); return *dialect; } @@ -143,7 +147,7 @@ struct type_cache { using Type = mllvm::LLVMType; - Type get_type(mllvm::LLVMDialect& dialect, llvm::StringRef str) + Type get_type(mlir::MLIRContext& context, llvm::StringRef str) { assert(!str.empty()); auto s = str.str(); @@ -152,7 +156,7 @@ struct type_cache { return it->second; } - auto type = parse_type(dialect, str); + auto type = parse_type(context, str); typemap[s] = type; return type; } @@ -169,19 +173,16 @@ struct lowerer_base mlir::MLIRContext ctx; mlir::OpBuilder builder; mlir::Block::BlockArgListType fnargs; - mlir::Block* entry_bb = nullptr; std::vector blocks; std::unordered_map blocks_map; - std::vector locals; std::unordered_map vars; inst_handles insts; - type_cache types; }; struct lowerer : public lowerer_base { lowerer(): - dialect(get_dialect(ctx)) + dialect(get_dialect(ctx)) { } @@ -205,9 +206,10 @@ struct lowerer : public lowerer_base return py::bytes(serialize_mod(*llvmmod)); } private: - mllvm::LLVMDialect& dialect; mllvm::LLVMFuncOp func; + mlir::Block* entry_bb = nullptr; + type_cache types; py::handle var_type_resolver; void lower_func_body(const py::object& func_ir) @@ -507,7 +509,7 @@ struct lowerer : public lowerer_base mllvm::LLVMType parse_type(llvm::StringRef str) { - return types.get_type(dialect, str); + return types.get_type(ctx, str); } mllvm::LLVMType get_func_type(const py::handle& typedesc) @@ -526,10 +528,207 @@ struct lowerer : public lowerer_base return Type::getFunctionTy(ret, args, false); } }; + +struct plier_lowerer : public lowerer_base +{ + plier_lowerer(): + dialect(get_dialect(ctx)) + { + + } + + py::bytes lower(const py::object& compilation_context, const py::object& func_ir) + { + auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); + auto name = compilation_context["fnname"]().cast(); + auto typ = get_func_type(compilation_context["fndesc"]); + func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); + lower_func_body(func_ir); + mod.push_back(func); + mod.dump(); + if (mlir::failed(mod.verify())) + { + report_error("MLIR module validation failed"); + } +// var_type_resolver = compilation_context["get_var_type"]; +// auto typ = get_func_type(compilation_context["fntype"]); +// func = builder.create(builder.getUnknownLoc(), name, typ); +// +// auto llvmmod = mlir::translateModuleToLLVMIR(mod); +// // llvmmod->dump(); +// return py::bytes(serialize_mod(*llvmmod)); + return {}; + } +private: + plier::PlierDialect& dialect; + mlir::FuncOp func; + + void lower_func_body(const py::object& func_ir) + { + auto ir_blocks = get_blocks(func_ir); + assert(!ir_blocks.empty()); + blocks.reserve(ir_blocks.size()); + for (std::size_t i = 0; i < ir_blocks.size(); ++i) + { + auto block = (0 == i ? func.addEntryBlock() : func.addBlock()); + blocks.push_back(block); + blocks_map[ir_blocks[i].first] = block; + } + fnargs = func.getArguments(); + + for (std::size_t i = 0; i < ir_blocks.size(); ++i) + { + lower_block(blocks[i], ir_blocks[i].second); + } + } + + void lower_block(mlir::Block* bb, const py::handle& ir_block) + { + assert(nullptr != bb); + vars.clear(); + builder.setInsertionPointToEnd(bb); + for (auto it : get_body(ir_block)) + { + lower_inst(it); + } + } + + void lower_inst(const py::handle& inst) + { + if (py::isinstance(inst, insts.Assign)) + { + auto name = inst.attr("target").attr("name"); + auto val = lower_assign(inst, name); + storevar(val, inst, name); + } + else if (py::isinstance(inst, insts.Del)) + { + delvar(inst.attr("value")); + } + else if (py::isinstance(inst, insts.Return)) + { + retvar(inst.attr("value").attr("name")); + } +// else if (py::isinstance(inst, insts.Branch)) +// { +// branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); +// } +// else if (py::isinstance(inst, insts.Jump)) +// { +// jump(inst.attr("target")); +// } + else + { + report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); + } + } + + mlir::Value lower_assign(const py::handle& inst, const py::handle& name) + { + auto value = inst.attr("value"); + if (py::isinstance(value, insts.Arg)) + { + auto index = value.attr("index").cast(); + return builder.create(builder.getUnknownLoc(), index, name.cast()); + } + if(py::isinstance(value, insts.Expr)) + { + return lower_expr(value); + } + if (py::isinstance(value, insts.Const)) + { + auto val = get_const_val(value.attr("value")); + return builder.create(builder.getUnknownLoc(), val); + } + + report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); + } + + mlir::Value lower_expr(const py::handle& expr) + { + auto op = expr.attr("op").cast(); + if (op == "binop") + { + return lower_binop(expr, expr.attr("fn")); + } + if (op == "cast") + { + auto val = loadvar(expr.attr("value").attr("name")); + return builder.create(builder.getUnknownLoc(), val); + } +// if (op == "call") +// { +// return lower_call(expr); +// } + report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); + } + + mlir::Value lower_binop(const py::handle& expr, const py::handle& op) + { + auto lhs_name = expr.attr("lhs").attr("name"); + auto rhs_name = expr.attr("rhs").attr("name"); + auto lhs = loadvar(lhs_name); + auto rhs = loadvar(rhs_name); +// return resolve_op(lhs, rhs, op); + } + + void storevar(mlir::Value val, const py::handle& inst, const py::handle& name) + { + auto name_str = name.cast(); + vars[name_str] = val; + } + + mlir::Value loadvar(const py::handle& name) + { + auto it = vars.find(name.cast()); + assert(vars.end() != it); + return it->second; + } + + void delvar(const py::handle& name) + { + auto var = loadvar(name); + builder.create(builder.getUnknownLoc(), var); + vars.erase(name.cast()); + } + + void retvar(const py::handle& name) + { + auto var = loadvar(name); + builder.create(builder.getUnknownLoc(), var); + } + + mlir::Attribute get_const_val(const py::handle& val) + { + if (py::isinstance(val)) + { + return builder.getI64IntegerAttr(val.cast()); + } + report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); + } + + mlir::FunctionType get_func_type(const py::handle& typedesc) + { + auto get_type = [&](const auto& h) { +// return parse_type(py::str(h).cast()); + return plier::PyType::get(&ctx); + }; + auto p_func = typedesc(); + auto ret = get_type(p_func.attr("restype")); + llvm::SmallVector args; +// for (auto arg : p_func.attr("argtypes")) +// { +// args.push_back(get_type(arg)); +// } + return mlir::FunctionType::get(args, {ret}, &ctx); + } +}; } py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) { mlir::registerDialect(); + plier::register_dialect(); + plier_lowerer().lower(compilation_context, func_ir); return lowerer().lower(compilation_context, func_ir); } diff --git a/mlir-compiler/type_parser.cpp b/mlir-compiler/type_parser.cpp index edbb38e5953..d10d58692b6 100644 --- a/mlir-compiler/type_parser.cpp +++ b/mlir-compiler/type_parser.cpp @@ -11,11 +11,11 @@ namespace } } -mlir::LLVM::LLVMType parse_type(mlir::LLVM::LLVMDialect& dialect, llvm::StringRef str) +mlir::LLVM::LLVMType parse_type(mlir::MLIRContext& context, llvm::StringRef str) { assert(!str.empty()); auto mlir_type = (std::string("!llvm<\"") + str + "\">").str(); - auto res = mlir::parseType(mlir_type, dialect.getContext()).dyn_cast_or_null(); + auto res = mlir::parseType(mlir_type, &context).dyn_cast_or_null(); if (mlir::Type() == res) { report_error(llvm::Twine("cannot parse type: \"") + str + "\""); diff --git a/mlir-compiler/type_parser.hpp b/mlir-compiler/type_parser.hpp index 6de1c61eff5..6d95aeef219 100644 --- a/mlir-compiler/type_parser.hpp +++ b/mlir-compiler/type_parser.hpp @@ -2,4 +2,4 @@ #include -mlir::LLVM::LLVMType parse_type(mlir::LLVM::LLVMDialect& dialect, llvm::StringRef str); +mlir::LLVM::LLVMType parse_type(mlir::MLIRContext& context, llvm::StringRef str); diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 93294b91748..be8ee491141 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -187,6 +187,7 @@ def lower_normal_function(self, fndesc): """ if _use_mlir: ctx = {} + ctx['fndesc'] = lambda: fndesc ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) ctx['fnname'] = lambda: fndesc.mangled_name ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) From 3f336e1e22d73c549d24de02c173c1384fcb8d29 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 16 Sep 2020 21:24:50 +0300 Subject: [PATCH 019/259] work on plier dialect --- mlir-compiler/dialect.cpp | 39 +++++++++-- mlir-compiler/include/plier/PlierOps.td | 39 +++++++++++ mlir-compiler/lowering.cpp | 90 +++++++++++++++++++++---- 3 files changed, 152 insertions(+), 16 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index e8a34ec3408..5ef051019b7 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -41,7 +41,24 @@ void PlierDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, unsigned index, mlir::StringRef name) { - ArgOp::build(builder, state, PyType::get(state.getContext()), llvm::APInt(32, index), name); + ArgOp::build(builder, state, PyType::get(state.getContext()), + llvm::APInt(32, index), name); +} + +void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + + mlir::Attribute val) { + ConstOp::build(builder, state, PyType::get(state.getContext()), val); +} + +void GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::StringRef name) { + GlobalOp::build(builder, state, PyType::get(state.getContext()), name); +} + +void BinOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value lhs, mlir::Value rhs, mlir::StringRef op) { + BinOp::build(builder, state, PyType::get(state.getContext()), lhs, rhs, op); } void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, @@ -49,9 +66,23 @@ void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, CastOp::build(builder, state, PyType::get(state.getContext()), val); } -void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Attribute val) { - ConstOp::build(builder, state, PyType::get(state.getContext()), val); +void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func, + mlir::ValueRange args, + mlir::ArrayRef> kwargs) { + auto ctx = builder.getContext(); + mlir::SmallVector all_args; + all_args.reserve(args.size() + kwargs.size()); + std::copy(args.begin(), args.end(), std::back_inserter(all_args)); + auto kw_start = llvm::APInt(32, all_args.size()); + mlir::SmallVector kw_names; + kw_names.reserve(kwargs.size()); + for (auto& a : kwargs) + { + kw_names.push_back(mlir::StringAttr::get(a.first, ctx)); + all_args.push_back(a.second); + } + PyCallOp::build(builder, state, PyType::get(state.getContext()), func, + all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); } #define GET_OP_CLASSES diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index ba35833ebcf..ef0eddf6297 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -40,6 +40,30 @@ def ConstOp : Plier_Op<"const", []> { ]; } +def GlobalOp : Plier_Op<"global", []> { + let arguments = (ins + StrAttr:$name); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, StringRef name"> + ]; +} + +def BinOp : Plier_Op<"binop", []> { + let arguments = (ins + Plier_PyType:$rhs, + Plier_PyType:$lhs, + StrAttr:$op); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value rhs, ::mlir::Value lhs, StringRef op"> + ]; +} + def CastOp : Plier_Op<"cast", []> { let arguments = (ins Plier_PyType:$value); @@ -51,6 +75,21 @@ def CastOp : Plier_Op<"cast", []> { ]; } +def PyCallOp : Plier_Op<"call", []> { + let arguments = (ins + Plier_PyType:$func, + Variadic:$args, + UI32Attr:$kw_start, + ArrayAttr:$kw_names); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value func, " + "::mlir::ValueRange args, " + "::mlir::ArrayRef> kwargs"> + ]; +} def DelOp : Plier_Op<"del", []> { let arguments = (ins diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index acf41072d24..ee9974cb1d4 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -585,7 +585,7 @@ struct plier_lowerer : public lowerer_base void lower_block(mlir::Block* bb, const py::handle& ir_block) { assert(nullptr != bb); - vars.clear(); +// vars.clear(); builder.setInsertionPointToEnd(bb); for (auto it : get_body(ir_block)) { @@ -609,10 +609,10 @@ struct plier_lowerer : public lowerer_base { retvar(inst.attr("value").attr("name")); } -// else if (py::isinstance(inst, insts.Branch)) -// { -// branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); -// } + else if (py::isinstance(inst, insts.Branch)) + { + branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); + } // else if (py::isinstance(inst, insts.Jump)) // { // jump(inst.attr("target")); @@ -629,7 +629,8 @@ struct plier_lowerer : public lowerer_base if (py::isinstance(value, insts.Arg)) { auto index = value.attr("index").cast(); - return builder.create(builder.getUnknownLoc(), index, name.cast()); + return builder.create(builder.getUnknownLoc(), index, + name.cast()); } if(py::isinstance(value, insts.Expr)) { @@ -640,6 +641,12 @@ struct plier_lowerer : public lowerer_base auto val = get_const_val(value.attr("value")); return builder.create(builder.getUnknownLoc(), val); } + if (py::isinstance(value, insts.Global)) + { + auto name = value.attr("name").cast(); + return builder.create(builder.getUnknownLoc(), + name); + } report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); } @@ -656,20 +663,71 @@ struct plier_lowerer : public lowerer_base auto val = loadvar(expr.attr("value").attr("name")); return builder.create(builder.getUnknownLoc(), val); } -// if (op == "call") -// { -// return lower_call(expr); -// } + if (op == "call") + { + return lower_call(expr); + } report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } + mlir::Value lower_call(const py::handle& expr) + { + auto func = loadvar(expr.attr("func").attr("name")); + auto args = expr.attr("args").cast(); + auto kws = expr.attr("kws").cast(); + auto vararg = expr.attr("vararg"); +// std::cout << py::str(args).cast() << std::endl; +// std::cout << py::str(kws).cast() << std::endl; +// std::cout << py::str(vararg).cast() << std::endl; + + mlir::SmallVector args_list; + mlir::SmallVector, 8> kwargs_list; + for (auto a : args) + { + args_list.push_back(loadvar(a.attr("name"))); + } + for (auto a : kws) + { + auto item = a.cast(); + auto name = item[0]; + auto val_name = item[1].attr("name"); + kwargs_list.push_back({name.cast(), loadvar(val_name)}); + } + + return builder.create(builder.getUnknownLoc(), func, + args_list, kwargs_list); + } + mlir::Value lower_binop(const py::handle& expr, const py::handle& op) { auto lhs_name = expr.attr("lhs").attr("name"); auto rhs_name = expr.attr("rhs").attr("name"); auto lhs = loadvar(lhs_name); auto rhs = loadvar(rhs_name); -// return resolve_op(lhs, rhs, op); + return resolve_op(lhs, rhs, op); + } + + mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) + { + // TODO unhardcode + if (op.is(insts.add)) + { + return builder.create(builder.getUnknownLoc(), lhs, rhs, "+"); + } +// if (op.is(insts.eq)) +// { +// assert(lhs.getType() == rhs.getType()); +// if (lhs.getType().cast().isIntegerTy()) +// { +// return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::eq, lhs, rhs); +// } +// } + if (op.is(insts.gt)) + { + return builder.create(builder.getUnknownLoc(), lhs, rhs, ">"); + } + + report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); } void storevar(mlir::Value val, const py::handle& inst, const py::handle& name) @@ -689,7 +747,7 @@ struct plier_lowerer : public lowerer_base { auto var = loadvar(name); builder.create(builder.getUnknownLoc(), var); - vars.erase(name.cast()); +// vars.erase(name.cast()); } void retvar(const py::handle& name) @@ -698,6 +756,14 @@ struct plier_lowerer : public lowerer_base builder.create(builder.getUnknownLoc(), var); } + void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) + { + auto c = loadvar(cond); + auto tr_block = blocks_map.find(tr.cast())->second; + auto fl_block = blocks_map.find(fl.cast())->second; + builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); + } + mlir::Attribute get_const_val(const py::handle& val) { if (py::isinstance(val)) From 035c7199a6a5b747889c535fc1f6c857cb4c4a74 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 17 Sep 2020 18:10:59 +0300 Subject: [PATCH 020/259] some work --- mlir-compiler/dialect.cpp | 5 +++++ mlir-compiler/include/plier/PlierOps.td | 12 ++++++++++++ mlir-compiler/lowering.cpp | 23 ++++++++++++++++++----- 3 files changed, 35 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index 5ef051019b7..069c93deb2d 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -66,6 +66,11 @@ void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, CastOp::build(builder, state, PyType::get(state.getContext()), val); } +void AssignOp::build(OpBuilder &builder, OperationState &state, + mlir::Value value, StringRef name) { + AssignOp::build(builder, state, PyType::get(state.getContext()), value, name); +} + void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func, mlir::ValueRange args, mlir::ArrayRef> kwargs) { diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index ef0eddf6297..7fad777f873 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -75,6 +75,18 @@ def CastOp : Plier_Op<"cast", []> { ]; } +def AssignOp : Plier_Op<"assign", []> { + let arguments = (ins + Plier_PyType:$value, + StrAttr:$name); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, StringRef name"> + ]; +} + def PyCallOp : Plier_Op<"call", []> { let arguments = (ins Plier_PyType:$func, diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index ee9974cb1d4..c16037d7671 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -613,10 +613,10 @@ struct plier_lowerer : public lowerer_base { branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); } -// else if (py::isinstance(inst, insts.Jump)) -// { -// jump(inst.attr("target")); -// } + else if (py::isinstance(inst, insts.Jump)) + { + jump(inst.attr("target")); + } else { report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); @@ -636,6 +636,13 @@ struct plier_lowerer : public lowerer_base { return lower_expr(value); } + if(py::isinstance(value, insts.Var)) + { + auto var = loadvar(value.attr("name")); + return builder.create( + builder.getUnknownLoc(), var, + value.attr("name").cast()); + } if (py::isinstance(value, insts.Const)) { auto val = get_const_val(value.attr("value")); @@ -761,7 +768,13 @@ struct plier_lowerer : public lowerer_base auto c = loadvar(cond); auto tr_block = blocks_map.find(tr.cast())->second; auto fl_block = blocks_map.find(fl.cast())->second; - builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); + builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); + } + + void jump(const py::handle& target) + { + auto block = blocks_map.find(target.cast())->second; + builder.create(builder.getUnknownLoc(), mlir::None, block); } mlir::Attribute get_const_val(const py::handle& val) From 4475330575835db4a9a1fee1706bbb8e47dc4a31 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 23 Sep 2020 22:45:24 +0300 Subject: [PATCH 021/259] lowering of numba phi nodes --- mlir-compiler/lowering.cpp | 173 +++++++++++++++++++++++++++++-------- numba/core/compiler.py | 5 +- numba/core/lowering.py | 2 +- numba/core/typed_passes.py | 19 ++++ 4 files changed, 161 insertions(+), 38 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index c16037d7671..bdd825f7345 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -175,7 +175,6 @@ struct lowerer_base mlir::Block::BlockArgListType fnargs; std::vector blocks; std::unordered_map blocks_map; - std::unordered_map vars; inst_handles insts; }; @@ -210,6 +209,7 @@ struct lowerer : public lowerer_base mllvm::LLVMFuncOp func; mlir::Block* entry_bb = nullptr; type_cache types; + std::unordered_map vars; py::handle var_type_resolver; void lower_func_body(const py::object& func_ir) @@ -540,8 +540,9 @@ struct plier_lowerer : public lowerer_base py::bytes lower(const py::object& compilation_context, const py::object& func_ir) { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); - auto name = compilation_context["fnname"]().cast(); - auto typ = get_func_type(compilation_context["fndesc"]); +// auto name = compilation_context["fnname"]().cast(); + auto name = "test"; + auto typ = get_func_type(/*compilation_context["fndesc"]*/nullptr); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); @@ -562,6 +563,19 @@ struct plier_lowerer : public lowerer_base private: plier::PlierDialect& dialect; mlir::FuncOp func; + std::unordered_map vars_map; + struct BlockInfo + { + struct PhiDesc + { + mlir::Block* dest_block = nullptr; + std::string var_name; + unsigned arg_index = 0; + }; + llvm::SmallVector outgoing_phi_nodes; + }; + + std::unordered_map block_infos; void lower_func_body(const py::object& func_ir) { @@ -580,12 +594,12 @@ struct plier_lowerer : public lowerer_base { lower_block(blocks[i], ir_blocks[i].second); } + fixup_phis(); } void lower_block(mlir::Block* bb, const py::handle& ir_block) { assert(nullptr != bb); -// vars.clear(); builder.setInsertionPointToEnd(bb); for (auto it : get_body(ir_block)) { @@ -597,9 +611,9 @@ struct plier_lowerer : public lowerer_base { if (py::isinstance(inst, insts.Assign)) { - auto name = inst.attr("target").attr("name"); - auto val = lower_assign(inst, name); - storevar(val, inst, name); + auto target = inst.attr("target"); + auto val = lower_assign(inst, target); + storevar(val, target); } else if (py::isinstance(inst, insts.Del)) { @@ -607,11 +621,11 @@ struct plier_lowerer : public lowerer_base } else if (py::isinstance(inst, insts.Return)) { - retvar(inst.attr("value").attr("name")); + retvar(inst.attr("value")); } else if (py::isinstance(inst, insts.Branch)) { - branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); + branch(inst.attr("cond"), inst.attr("truebr"), inst.attr("falsebr")); } else if (py::isinstance(inst, insts.Jump)) { @@ -623,14 +637,14 @@ struct plier_lowerer : public lowerer_base } } - mlir::Value lower_assign(const py::handle& inst, const py::handle& name) + mlir::Value lower_assign(const py::handle& inst, const py::handle& target) { auto value = inst.attr("value"); if (py::isinstance(value, insts.Arg)) { auto index = value.attr("index").cast(); return builder.create(builder.getUnknownLoc(), index, - name.cast()); + target.attr("name").cast()); } if(py::isinstance(value, insts.Expr)) { @@ -638,10 +652,11 @@ struct plier_lowerer : public lowerer_base } if(py::isinstance(value, insts.Var)) { - auto var = loadvar(value.attr("name")); - return builder.create( - builder.getUnknownLoc(), var, - value.attr("name").cast()); + auto var = loadvar(value); + return var; +// return builder.create( +// builder.getUnknownLoc(), var, +// value.attr("name").cast()); } if (py::isinstance(value, insts.Const)) { @@ -667,19 +682,47 @@ struct plier_lowerer : public lowerer_base } if (op == "cast") { - auto val = loadvar(expr.attr("value").attr("name")); + auto val = loadvar(expr.attr("value")); return builder.create(builder.getUnknownLoc(), val); } if (op == "call") { return lower_call(expr); } + if (op == "phi") + { + return lower_phi(expr); + } report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } + mlir::Value lower_phi(const py::handle& inst) + { + auto incoming_vals = inst.attr("incoming_values").cast(); + auto incoming_blocks = inst.attr("incoming_blocks").cast(); + assert(incoming_vals.size() == incoming_blocks.size()); + + auto current_block = builder.getBlock(); + assert(nullptr != current_block); + + auto arg_index = current_block->getNumArguments(); + auto arg = current_block->addArgument(plier::PyType::get(&ctx)); + + auto count = incoming_vals.size(); + for (std::size_t i = 0; i < count; ++i) + { + auto var = incoming_vals[i].attr("name").cast(); + auto block = blocks_map.find(incoming_blocks[i].cast())->second; + block_infos[block].outgoing_phi_nodes.push_back({current_block, std::move(var), arg_index}); + } + + return arg; + } + + mlir::Value lower_call(const py::handle& expr) { - auto func = loadvar(expr.attr("func").attr("name")); + auto func = loadvar(expr.attr("func")); auto args = expr.attr("args").cast(); auto kws = expr.attr("kws").cast(); auto vararg = expr.attr("vararg"); @@ -691,13 +734,13 @@ struct plier_lowerer : public lowerer_base mlir::SmallVector, 8> kwargs_list; for (auto a : args) { - args_list.push_back(loadvar(a.attr("name"))); + args_list.push_back(loadvar(a)); } for (auto a : kws) { auto item = a.cast(); auto name = item[0]; - auto val_name = item[1].attr("name"); + auto val_name = item[1]; kwargs_list.push_back({name.cast(), loadvar(val_name)}); } @@ -707,8 +750,8 @@ struct plier_lowerer : public lowerer_base mlir::Value lower_binop(const py::handle& expr, const py::handle& op) { - auto lhs_name = expr.attr("lhs").attr("name"); - auto rhs_name = expr.attr("rhs").attr("name"); + auto lhs_name = expr.attr("lhs"); + auto rhs_name = expr.attr("rhs"); auto lhs = loadvar(lhs_name); auto rhs = loadvar(rhs_name); return resolve_op(lhs, rhs, op); @@ -737,29 +780,27 @@ struct plier_lowerer : public lowerer_base report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); } - void storevar(mlir::Value val, const py::handle& inst, const py::handle& name) + void storevar(mlir::Value val, const py::handle& inst) { - auto name_str = name.cast(); - vars[name_str] = val; + vars_map[inst.attr("name").cast()] = val; } - mlir::Value loadvar(const py::handle& name) + mlir::Value loadvar(const py::handle& inst) { - auto it = vars.find(name.cast()); - assert(vars.end() != it); + auto it = vars_map.find(inst.attr("name").cast()); + assert(vars_map.end() != it); return it->second; } - void delvar(const py::handle& name) + void delvar(const py::handle& inst) { - auto var = loadvar(name); + auto var = loadvar(inst); builder.create(builder.getUnknownLoc(), var); -// vars.erase(name.cast()); } - void retvar(const py::handle& name) + void retvar(const py::handle& inst) { - auto var = loadvar(name); + auto var = loadvar(inst); builder.create(builder.getUnknownLoc(), var); } @@ -792,8 +833,9 @@ struct plier_lowerer : public lowerer_base // return parse_type(py::str(h).cast()); return plier::PyType::get(&ctx); }; - auto p_func = typedesc(); - auto ret = get_type(p_func.attr("restype")); +// auto p_func = typedesc(); +// auto ret = get_type(p_func.attr("restype")); + auto ret = plier::PyType::get(&ctx); llvm::SmallVector args; // for (auto arg : p_func.attr("argtypes")) // { @@ -801,6 +843,65 @@ struct plier_lowerer : public lowerer_base // } return mlir::FunctionType::get(args, {ret}, &ctx); } + + void fixup_phis() + { + auto build_arg_list = [&](mlir::Block* block, auto& outgoing_phi_nodes, auto& list) + { + for (auto& o : outgoing_phi_nodes) + { + if (o.dest_block == block) + { + if (list.size() <= o.arg_index) + { + list.resize(o.arg_index + 1); + } + auto it = vars_map.find(o.var_name); + assert(vars_map.end() != it); + list[o.arg_index] = it->second; + } + } + }; + for (auto& bb : func) + { + auto it = block_infos.find(&bb); + if (block_infos.end() != it) + { + auto& info = it->second; + auto term = bb.getTerminator(); + if (nullptr == term) + { + report_error("broken ir: block withoout terminator"); + } + builder.setInsertionPointToEnd(&bb); + + if (auto op = mlir::dyn_cast(term)) + { + auto dest = op.getDest(); + mlir::SmallVector args; + build_arg_list(dest, info.outgoing_phi_nodes, args); + op.erase(); + builder.create(builder.getUnknownLoc(), dest, args); + } + else if (auto op = mlir::dyn_cast(term)) + { + auto true_dest = op.trueDest(); + auto false_dest = op.falseDest(); + auto cond = op.getCondition(); + mlir::SmallVector true_args; + mlir::SmallVector false_args; + build_arg_list(true_dest, info.outgoing_phi_nodes, true_args); + build_arg_list(false_dest, info.outgoing_phi_nodes, false_args); + op.erase(); + builder.create(builder.getUnknownLoc(), cond, true_dest, true_args, false_dest, false_args); + } + else + { + report_error(llvm::Twine("Unhandled terminator: ") + term->getName().getStringRef()); + } + } + } + } }; } @@ -808,6 +909,6 @@ py::bytes lower_function(const py::object& compilation_context, const py::object { mlir::registerDialect(); plier::register_dialect(); - plier_lowerer().lower(compilation_context, func_ir); - return lowerer().lower(compilation_context, func_ir); + return plier_lowerer().lower(compilation_context, func_ir); +// return lowerer().lower(compilation_context, func_ir); } diff --git a/numba/core/compiler.py b/numba/core/compiler.py index 6275a5c48ee..dd811a39066 100644 --- a/numba/core/compiler.py +++ b/numba/core/compiler.py @@ -28,7 +28,8 @@ NopythonRewrites, PreParforPass, ParforPass, DumpParforDiagnostics, IRLegalization, NoPythonBackend, - InlineOverloads, PreLowerStripPhis) + InlineOverloads, PreLowerStripPhis, + MlirBackend) from numba.core.object_mode_passes import (ObjectModeFrontEnd, ObjectModeBackEnd, CompileInterpMode) @@ -503,6 +504,8 @@ def define_typed_pipeline(state, name="typed"): pm.add_pass(NopythonTypeInference, "nopython frontend") pm.add_pass(AnnotateTypes, "annotate types") + pm.add_pass(MlirBackend, "mlir backend") + # strip phis pm.add_pass(PreLowerStripPhis, "remove phis nodes") diff --git a/numba/core/lowering.py b/numba/core/lowering.py index be8ee491141..5cbb2a1583d 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -11,7 +11,7 @@ from numba.core.funcdesc import default_mangler from numba.core.environment import Environment -_use_mlir = True +_use_mlir = False _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 3756a3e7c22..794a941f8c6 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -463,6 +463,25 @@ def run_pass(self, state): ) return True +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackend(LoweringPass): + + _name = "mlir_backend" + + def __init__(self): + # LoweringPass.__init__(self) + pass + + def run_pass(self, state): + ctx = {} + # ctx['fndesc'] = lambda: fndesc + # ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) + # ctx['fnname'] = lambda: fndesc.mangled_name + # ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) + import mlir_compiler + mlir_compiler.lower_normal_function(ctx, state.func_ir) + return True + @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): From c44ea4eaa85d9ea1b36d1b9c5e7dbb46acddd78e Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 24 Sep 2020 14:55:54 +0300 Subject: [PATCH 022/259] tuples --- mlir-compiler/dialect.cpp | 20 +++++++++---- mlir-compiler/include/plier/PlierOps.td | 37 +++++++++++++++++-------- mlir-compiler/lowering.cpp | 28 +++++++++++++++++++ mlir-compiler/test.py | 4 +-- 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index 069c93deb2d..fc672d66320 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -66,11 +66,6 @@ void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, CastOp::build(builder, state, PyType::get(state.getContext()), val); } -void AssignOp::build(OpBuilder &builder, OperationState &state, - mlir::Value value, StringRef name) { - AssignOp::build(builder, state, PyType::get(state.getContext()), value, name); -} - void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func, mlir::ValueRange args, mlir::ArrayRef> kwargs) { @@ -90,6 +85,21 @@ void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); } +void BuildTupleOp::build(OpBuilder &builder, OperationState &state, + ::mlir::ValueRange args) +{ + BuildTupleOp::build(builder, state, PyType::get(state.getContext()), args); +} + +void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, + ::mlir::Value value, ::mlir::Value index_var, + unsigned int index) +{ + StaticGetItemOp::build(builder, state, PyType::get(state.getContext()), + value, index_var, llvm::APInt(32, index)); +} + + #define GET_OP_CLASSES #include "plier/PlierOps.cpp.inc" diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 7fad777f873..5f519c798bf 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -75,18 +75,6 @@ def CastOp : Plier_Op<"cast", []> { ]; } -def AssignOp : Plier_Op<"assign", []> { - let arguments = (ins - Plier_PyType:$value, - StrAttr:$name); - - let results = (outs Plier_PyType); - - let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, StringRef name"> - ]; -} - def PyCallOp : Plier_Op<"call", []> { let arguments = (ins Plier_PyType:$func, @@ -103,6 +91,31 @@ def PyCallOp : Plier_Op<"call", []> { ]; } +def BuildTupleOp : Plier_Op<"build_tuple", []> { + let arguments = (ins + Variadic:$args); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::ValueRange args"> +]; +} + +def StaticGetItemOp : Plier_Op<"static_getitem", []> { + let arguments = (ins + Plier_PyType:$value, + Plier_PyType:$index_var, + UI32Attr:$index); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, " + "::mlir::Value index_var, unsigned index"> +]; +} + def DelOp : Plier_Op<"del", []> { let arguments = (ins Plier_PyType:$value); diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index bdd825f7345..b600cda1de4 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -693,9 +693,37 @@ struct plier_lowerer : public lowerer_base { return lower_phi(expr); } + if (op == "build_tuple") + { + return lower_build_tuple(expr); + } + if (op == "static_getitem") + { + return lower_static_getitem(expr); + } report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } + mlir::Value lower_static_getitem(const py::handle& inst) + { + auto value = loadvar(inst.attr("value")); + auto index_var = loadvar(inst.attr("index_var")); + auto index = inst.attr("index").cast(); + return builder.create(builder.getUnknownLoc(), + value, index_var, index); + } + + mlir::Value lower_build_tuple(const py::handle& inst) + { + auto items = inst.attr("items").cast(); + mlir::SmallVector args; + for (auto item : items) + { + args.push_back(loadvar(item)); + } + return builder.create(builder.getUnknownLoc(), args); + } + mlir::Value lower_phi(const py::handle& inst) { auto incoming_vals = inst.attr("incoming_values").cast(); diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 8b832ae9eb3..c1b753c1143 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -65,6 +65,6 @@ def test(func, params): test(var, (8,)) test(jump, (1,8)) test(jump, (7,8)) -#test(call, (1,2,3)) -#test(tuple, (1,2,3)) +test(call, (1,2,3)) +test(tuple, (1,2,3)) #test(loop, (8,)) From cf7884c05b223e95b58f446fa6af8eecc0cd4748 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 24 Sep 2020 21:45:31 +0300 Subject: [PATCH 023/259] some work --- mlir-compiler/dialect.cpp | 24 ++++++++++++++ mlir-compiler/include/plier/PlierOps.td | 44 +++++++++++++++++++++++++ mlir-compiler/lowering.cpp | 26 +++++++++++++-- 3 files changed, 92 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index fc672d66320..8581256e208 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -99,6 +99,30 @@ void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, value, index_var, llvm::APInt(32, index)); } +void GetiterOp::build(OpBuilder &builder, OperationState &state, + ::mlir::Value value) +{ + GetiterOp::build(builder, state, PyType::get(state.getContext()), value); +} + +void IternextOp::build(OpBuilder &builder, OperationState &state, + ::mlir::Value value) +{ + IternextOp::build(builder, state, PyType::get(state.getContext()), value); +} + +void PairfirstOp::build(OpBuilder &builder, OperationState &state, + ::mlir::Value value) +{ + PairfirstOp::build(builder, state, PyType::get(state.getContext()), value); +} + +void PairsecondOp::build(OpBuilder &builder, OperationState &state, + ::mlir::Value value) +{ + PairsecondOp::build(builder, state, PyType::get(state.getContext()), value); +} + #define GET_OP_CLASSES #include "plier/PlierOps.cpp.inc" diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 5f519c798bf..e35a94f7986 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -116,6 +116,50 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { ]; } +def GetiterOp : Plier_Op<"getiter", []> { + let arguments = (ins + Plier_PyType:$value); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> +]; +} + +def IternextOp : Plier_Op<"iternext", []> { + let arguments = (ins + Plier_PyType:$value); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> +]; +} + +def PairfirstOp : Plier_Op<"pair_first", []> { + let arguments = (ins + Plier_PyType:$value); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> +]; +} + +def PairsecondOp : Plier_Op<"pair_second", []> { + let arguments = (ins + Plier_PyType:$value); + + let results = (outs Plier_PyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> +]; +} + def DelOp : Plier_Op<"del", []> { let arguments = (ins Plier_PyType:$value); diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index b600cda1de4..b985d9679f3 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -682,8 +682,7 @@ struct plier_lowerer : public lowerer_base } if (op == "cast") { - auto val = loadvar(expr.attr("value")); - return builder.create(builder.getUnknownLoc(), val); + return lower_simple(expr); } if (op == "call") { @@ -701,9 +700,32 @@ struct plier_lowerer : public lowerer_base { return lower_static_getitem(expr); } + if (op == "getiter") + { + return lower_simple(expr); + } + if (op == "iternext") + { + return lower_simple(expr); + } + if (op == "pair_first") + { + return lower_simple(expr); + } + if (op == "pair_second") + { + return lower_simple(expr); + } report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } + template + mlir::Value lower_simple(const py::handle& inst) + { + auto value = loadvar(inst.attr("value")); + return builder.create(builder.getUnknownLoc(), value); + } + mlir::Value lower_static_getitem(const py::handle& inst) { auto value = loadvar(inst.attr("value")); From 1d5452c2657ffeb10456aa95008051a3aedadf83 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 25 Sep 2020 18:20:41 +0300 Subject: [PATCH 024/259] test --- mlir-compiler/test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index c1b753c1143..065f2688e92 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -39,7 +39,7 @@ def tuple(a,b,c): def loop(n): res = 0 for i in range(n): - res += i + res = res + i return res @@ -67,4 +67,4 @@ def test(func, params): test(jump, (7,8)) test(call, (1,2,3)) test(tuple, (1,2,3)) -#test(loop, (8,)) +test(loop, (8,)) From 51d467b339193bcec884235ad224463174e976cf Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 15:22:56 +0300 Subject: [PATCH 025/259] make type stateful --- mlir-compiler/dialect.cpp | 36 ++++++++++++++++++++++++- mlir-compiler/include/plier/dialect.hpp | 18 +++++++++---- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index 8581256e208..c241dafaa7a 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -9,6 +9,30 @@ namespace plier { +namespace detail +{ +struct PyTypeStorage : public mlir::TypeStorage +{ + using KeyTy = mlir::StringRef; + + PyTypeStorage(mlir::StringRef name): name(name) {} + + bool operator==(const KeyTy& key) const + { + return key == name; + } + + static PyTypeStorage* construct(mlir::TypeStorageAllocator& allocator, + const KeyTy& key) + { + return new(allocator.allocate()) + PyTypeStorage(allocator.copyInto(key)); + } + + mlir::StringRef name; +}; +} + void register_dialect() { mlir::registerDialect(); @@ -32,13 +56,23 @@ mlir::Type PlierDialect::parseType(mlir::DialectAsmParser &parser) const { void PlierDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { switch (type.getKind()) { case plier::types::PyType: - os << "PyType"; + os << "PyType<" << type.cast().getName() << ">"; return; default: llvm_unreachable("unexpected type kind"); } } +PyType PyType::get(MLIRContext* context, llvm::StringRef name) +{ + return Base::get(context, types::PyType, name); +} + +llvm::StringRef PyType::getName() const +{ + return getImpl()->name; +} + void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, unsigned index, mlir::StringRef name) { ArgOp::build(builder, state, PyType::get(state.getContext()), diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 95a330894aa..2e0bacac6ed 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -19,6 +19,11 @@ using namespace mlir; // TODO: remove namespace plier { +namespace detail +{ +struct PyTypeStorage; +} + void register_dialect(); namespace types @@ -30,14 +35,17 @@ enum Kind }; } -class PyType : public mlir::Type::TypeBase { +class PyType : public mlir::Type::TypeBase +{ public: using Base::Base; + static bool kindof(unsigned kind) { return kind == types::PyType; } - static PyType get(mlir::MLIRContext *context) - { - return Base::get(context, types::PyType); - } + + static PyType get(mlir::MLIRContext *context, mlir::StringRef name = {}); + + mlir::StringRef getName() const; }; From e7a12e4367de0e75bb9cbd28baf367ca2b4ccc52 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 15:23:04 +0300 Subject: [PATCH 026/259] test counter --- mlir-compiler/test.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 065f2688e92..048911e5614 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -1,5 +1,8 @@ import numba +_tests_total = 0 +_tests_passes = 0 + def ret(a): return a @@ -44,6 +47,9 @@ def loop(n): def test(func, params): + global _tests_total + global _tests_passes + _tests_total += 1 print('test', func.__name__, params, '... ', end='') result = func(*params) wrapped = numba.njit()(func) @@ -52,6 +58,7 @@ def test(func, params): if (res != result): raise Exception(f'Invalid value "{res}", expected "{result}"') print('SUCCESS') + _tests_passes += 1 except Exception as e: print(e) print('FAILED') @@ -68,3 +75,5 @@ def test(func, params): test(call, (1,2,3)) test(tuple, (1,2,3)) test(loop, (8,)) + +print(f'Tests passed: {_tests_passes}/{_tests_total}') From 2f2b56bd3d958920520df96b2710fc1a8537a6e0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 19:54:43 +0300 Subject: [PATCH 027/259] pass inferred types to plier --- mlir-compiler/lowering.cpp | 6 ++++++ numba/core/typed_passes.py | 1 + 2 files changed, 7 insertions(+) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index b985d9679f3..9ea1eb0731d 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -540,6 +540,7 @@ struct plier_lowerer : public lowerer_base py::bytes lower(const py::object& compilation_context, const py::object& func_ir) { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); + typemap = compilation_context["typemap"]; // auto name = compilation_context["fnname"]().cast(); auto name = "test"; auto typ = get_func_type(/*compilation_context["fndesc"]*/nullptr); @@ -558,6 +559,7 @@ struct plier_lowerer : public lowerer_base // auto llvmmod = mlir::translateModuleToLLVMIR(mod); // // llvmmod->dump(); // return py::bytes(serialize_mod(*llvmmod)); + return {}; } private: @@ -574,6 +576,7 @@ struct plier_lowerer : public lowerer_base }; llvm::SmallVector outgoing_phi_nodes; }; + py::handle typemap; std::unordered_map block_infos; @@ -833,6 +836,8 @@ struct plier_lowerer : public lowerer_base void storevar(mlir::Value val, const py::handle& inst) { vars_map[inst.attr("name").cast()] = val; + auto type = typemap(inst); + val.setType(plier::PyType::get(&ctx, py::str(type).cast())); } mlir::Value loadvar(const py::handle& inst) @@ -952,6 +957,7 @@ struct plier_lowerer : public lowerer_base } } } + }; } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 794a941f8c6..7d924018360 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -474,6 +474,7 @@ def __init__(self): def run_pass(self, state): ctx = {} + ctx['typemap'] = lambda op: state.typemap[op.name] # ctx['fndesc'] = lambda: fndesc # ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) # ctx['fnname'] = lambda: fndesc.mangled_name From 7e5ecaa3272e433c9f84d5636cecee1891dd2b10 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 21:00:14 +0300 Subject: [PATCH 028/259] some cleanup --- mlir-compiler/lowering.cpp | 16 ++-------------- numba/core/typed_passes.py | 2 +- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 9ea1eb0731d..020872be348 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -541,8 +541,7 @@ struct plier_lowerer : public lowerer_base { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); typemap = compilation_context["typemap"]; -// auto name = compilation_context["fnname"]().cast(); - auto name = "test"; + auto name = compilation_context["fnname"]().cast(); auto typ = get_func_type(/*compilation_context["fndesc"]*/nullptr); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); @@ -552,13 +551,6 @@ struct plier_lowerer : public lowerer_base { report_error("MLIR module validation failed"); } -// var_type_resolver = compilation_context["get_var_type"]; -// auto typ = get_func_type(compilation_context["fntype"]); -// func = builder.create(builder.getUnknownLoc(), name, typ); -// -// auto llvmmod = mlir::translateModuleToLLVMIR(mod); -// // llvmmod->dump(); -// return py::bytes(serialize_mod(*llvmmod)); return {}; } @@ -655,11 +647,7 @@ struct plier_lowerer : public lowerer_base } if(py::isinstance(value, insts.Var)) { - auto var = loadvar(value); - return var; -// return builder.create( -// builder.getUnknownLoc(), var, -// value.attr("name").cast()); + return loadvar(value); } if (py::isinstance(value, insts.Const)) { diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 7d924018360..a694da53f84 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -477,7 +477,7 @@ def run_pass(self, state): ctx['typemap'] = lambda op: state.typemap[op.name] # ctx['fndesc'] = lambda: fndesc # ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) - # ctx['fnname'] = lambda: fndesc.mangled_name + ctx['fnname'] = lambda: state.func_ir.func_id.func_qualname # ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) import mlir_compiler mlir_compiler.lower_normal_function(ctx, state.func_ir) From 33f720e976046100be0945192ac014949cfb5df8 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 21:09:02 +0300 Subject: [PATCH 029/259] some refac --- mlir-compiler/lowering.cpp | 59 +++++++++++++------------------------- 1 file changed, 20 insertions(+), 39 deletions(-) diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 020872be348..69f3171c122 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -667,45 +667,25 @@ struct plier_lowerer : public lowerer_base mlir::Value lower_expr(const py::handle& expr) { auto op = expr.attr("op").cast(); - if (op == "binop") - { - return lower_binop(expr, expr.attr("fn")); - } - if (op == "cast") - { - return lower_simple(expr); - } - if (op == "call") - { - return lower_call(expr); - } - if (op == "phi") - { - return lower_phi(expr); - } - if (op == "build_tuple") - { - return lower_build_tuple(expr); - } - if (op == "static_getitem") - { - return lower_static_getitem(expr); - } - if (op == "getiter") - { - return lower_simple(expr); - } - if (op == "iternext") - { - return lower_simple(expr); - } - if (op == "pair_first") - { - return lower_simple(expr); - } - if (op == "pair_second") + using func_t = mlir::Value (plier_lowerer::*)(const py::handle&); + const std::pair handlers[] = { + {"binop", &plier_lowerer::lower_binop}, + {"cast", &plier_lowerer::lower_simple}, + {"call", &plier_lowerer::lower_call}, + {"phi", &plier_lowerer::lower_phi}, + {"build_tuple", &plier_lowerer::lower_build_tuple}, + {"static_getitem", &plier_lowerer::lower_static_getitem}, + {"getiter", &plier_lowerer::lower_simple}, + {"iternext", &plier_lowerer::lower_simple}, + {"pair_first", &plier_lowerer::lower_simple}, + {"pair_second", &plier_lowerer::lower_simple}, + }; + for (auto& h : handlers) { - return lower_simple(expr); + if (h.first == op) + { + return (this->*h.second)(expr); + } } report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } @@ -789,8 +769,9 @@ struct plier_lowerer : public lowerer_base args_list, kwargs_list); } - mlir::Value lower_binop(const py::handle& expr, const py::handle& op) + mlir::Value lower_binop(const py::handle& expr) { + auto op = expr.attr("fn"); auto lhs_name = expr.attr("lhs"); auto rhs_name = expr.attr("rhs"); auto lhs = loadvar(lhs_name); From 73125ac783b986692a55174e51a611d5f18c94bc Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 21:32:15 +0300 Subject: [PATCH 030/259] remove old lowering --- mlir-compiler/dialect.cpp | 6 - mlir-compiler/include/plier/dialect.hpp | 2 - mlir-compiler/lowering.cpp | 356 +----------------------- 3 files changed, 2 insertions(+), 362 deletions(-) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/dialect.cpp index c241dafaa7a..cf44b0292d5 100644 --- a/mlir-compiler/dialect.cpp +++ b/mlir-compiler/dialect.cpp @@ -33,12 +33,6 @@ struct PyTypeStorage : public mlir::TypeStorage }; } -void register_dialect() -{ - mlir::registerDialect(); - mlir::registerDialect(); -} - PlierDialect::PlierDialect(mlir::MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 2e0bacac6ed..164da6fd1da 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -24,8 +24,6 @@ namespace detail struct PyTypeStorage; } -void register_dialect(); - namespace types { enum Kind diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/lowering.cpp index 69f3171c122..b66531e9966 100644 --- a/mlir-compiler/lowering.cpp +++ b/mlir-compiler/lowering.cpp @@ -178,357 +178,6 @@ struct lowerer_base inst_handles insts; }; -struct lowerer : public lowerer_base -{ - lowerer(): - dialect(get_dialect(ctx)) - { - - } - - py::bytes lower(const py::object& compilation_context, const py::object& func_ir) - { - auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); - var_type_resolver = compilation_context["get_var_type"]; - auto typ = get_func_type(compilation_context["fntype"]); - auto name = compilation_context["fnname"]().cast(); - func = builder.create(builder.getUnknownLoc(), name, typ); - lower_func_body(func_ir); - mod.push_back(func); -// mod.dump(); - if (mlir::failed(mod.verify())) - { - report_error("MLIR module validation failed"); - } - auto llvmmod = mlir::translateModuleToLLVMIR(mod); -// llvmmod->dump(); - return py::bytes(serialize_mod(*llvmmod)); - } -private: - mllvm::LLVMDialect& dialect; - mllvm::LLVMFuncOp func; - mlir::Block* entry_bb = nullptr; - type_cache types; - std::unordered_map vars; - py::handle var_type_resolver; - - void lower_func_body(const py::object& func_ir) - { - entry_bb = func.addEntryBlock(); - assert(func.getNumArguments() >= 2); - fnargs = func.getArguments().slice(2); - auto ir_blocks = get_blocks(func_ir); - assert(!ir_blocks.empty()); - blocks.reserve(ir_blocks.size()); - for (std::size_t i = 0; i < ir_blocks.size(); ++i) - { - blocks.push_back(func.addBlock()); - blocks_map[ir_blocks[i].first] = blocks.back(); - } - - for (std::size_t i = 0; i < ir_blocks.size(); ++i) - { - lower_block(blocks[i], ir_blocks[i].second); - } - - builder.setInsertionPointToEnd(entry_bb); - builder.create(builder.getUnknownLoc(), mlir::None, blocks.front()); - } - - void lower_block(mlir::Block* bb, const py::handle& ir_block) - { - assert(nullptr != bb); - builder.setInsertionPointToEnd(bb); - for (auto it : get_body(ir_block)) - { - lower_inst(it); - } - } - - void lower_inst(const py::handle& inst) - { - if (py::isinstance(inst, insts.Assign)) - { - auto name = inst.attr("target").attr("name"); - auto val = lower_assign(inst, name); - storevar(val, inst, name); - } - else if (py::isinstance(inst, insts.Del)) - { - delvar(inst.attr("value")); - } - else if (py::isinstance(inst, insts.Return)) - { - retvar(inst.attr("value").attr("name")); - } - else if (py::isinstance(inst, insts.Branch)) - { - branch(inst.attr("cond").attr("name"), inst.attr("truebr"), inst.attr("falsebr")); - } - else if (py::isinstance(inst, insts.Jump)) - { - jump(inst.attr("target")); - } - else - { - report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); - } - } - - mllvm::LLVMType get_ll_type(const py::handle& name) - { - return parse_type(py::str(var_type_resolver(name)).cast()); - } - - mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) - { - // TODO unhardcode - if (op.is(insts.add)) - { - return builder.create(builder.getUnknownLoc(), lhs, rhs); - } - if (op.is(insts.eq)) - { - assert(lhs.getType() == rhs.getType()); - if (lhs.getType().cast().isIntegerTy()) - { - return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::eq, lhs, rhs); - } - } - if (op.is(insts.gt)) - { - assert(lhs.getType() == rhs.getType()); - if (lhs.getType().cast().isIntegerTy()) - { - return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::sgt, lhs, rhs); - } - } - - report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); - } - - mlir::Value lower_binop(const py::handle& expr, const py::handle& op) - { - auto lhs_name = expr.attr("lhs").attr("name"); - auto rhs_name = expr.attr("rhs").attr("name"); - auto lhs = loadvar(lhs_name); - auto rhs = loadvar(rhs_name); - // TODO casts - return resolve_op(lhs, rhs, op); - } - - mlir::Value lower_call(const py::handle& expr) - { - auto args = expr.attr("args").cast(); - auto vararg = expr.attr("vararg"); - auto kws = expr.attr("kws"); - // TODO fold args - - // TODO: hardcode for bool - return loadvar(args[0].attr("name")); - } - - mlir::Value lower_expr(const py::handle& expr) - { - auto op = expr.attr("op").cast(); - if (op == "binop") - { - return lower_binop(expr, expr.attr("fn")); - } - else if (op == "cast") - { - auto val = loadvar(expr.attr("value").attr("name")); - // TODO cast - return val; - } - else if (op == "call") - { - return lower_call(expr); - } - else - { - report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); - } - } - - mlir::Value get_const_val(const py::handle& val) - { - if (py::isinstance(val)) - { - auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); - auto value = builder.getI64IntegerAttr(val.cast()); - return builder.create(builder.getUnknownLoc(), mlir_type, value); - } -// if (py::isinstance(val)) -// { -// auto b = val.cast(); -// auto mlir_type = mllvm::LLVMType::getInt1Ty(&dialect); -// auto value = builder.getBoolAttr(b); -// return builder.create(builder.getUnknownLoc(), mlir_type, value); -// } - - // assume it is a PyObject* - auto mlir_type = mllvm::LLVMType::getInt8Ty(&dialect).getPointerTo(); - return builder.create(builder.getUnknownLoc(), mlir_type); - -// report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); - } - - mlir::Value lower_assign(const py::handle& inst, const py::handle& name) - { - auto value = inst.attr("value"); - if (py::isinstance(value, insts.Arg)) - { - auto index = value.attr("index").cast(); - // TODO: incref - // TODO: cast - return fnargs[index]; - } - if(py::isinstance(value, insts.Expr)) - { - return lower_expr(value); - } - if(py::isinstance(value, insts.Var)) - { - auto var = loadvar(value.attr("name")); - - // TODO: cast - // TODO: incref - return var; - } - if (py::isinstance(value, insts.Const) || py::isinstance(value, insts.Global)) - { - // TODO unhardcode - // TODO incref -// auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 64); -// auto val = builder.getI64IntegerAttr(value.attr("value").cast()); -// return builder.create(builder.getUnknownLoc(), mlir_type, val); - return get_const_val(value.attr("value")); - } - - report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); - } - - void alloca_var(const py::handle& name) - { - auto name_str = name.cast(); - if (0 == vars.count(name_str)) - { - scoped_goto_block s(builder, entry_bb); - auto size_type = mllvm::LLVMType::getIntNTy(&dialect, 64); - auto size_val = builder.getI64IntegerAttr(/*TODO*/1); - auto size = builder.create(builder.getUnknownLoc(), size_type, size_val); - auto type = get_ll_type(name); - auto ptype = type.getPointerTo(); - auto op = builder.create(builder.getUnknownLoc(), ptype, size, /*align*/0); - auto null = zero_val(type); - builder.create(builder.getUnknownLoc(), null, op); - vars[name_str] = op; - } - } - - mlir::Value get_var(const py::handle& name) - { - auto it = vars.find(name.cast()); - assert(vars.end() != it); - return it->second; - } - - mlir::Value loadvar(const py::handle& name) - { - auto type = get_ll_type(name); - return builder.create(builder.getUnknownLoc(), type, get_var(name)); - } - - void storevar(mlir::Value val, const py::handle& inst, const py::handle& name) - { - alloca_var(name); - auto old = loadvar(name); - // TODO decref old - auto ptr = get_var(name); - builder.create(builder.getUnknownLoc(), val, ptr); - } - - mlir::Value zero_val(mllvm::LLVMType type) - { - if (type.isPointerTy()) - { - return builder.create(builder.getUnknownLoc(), type); - } - else if (type.isIntegerTy()) - { - return builder.create(builder.getUnknownLoc(), type, builder.getI64IntegerAttr(0)); - } - else - { - report_error(llvm::Twine("zero_val unhandled type ") + to_str(type)); - } - } - - void delvar(const py::handle& name) - { - alloca_var(name); - auto ptr = get_var(name); - // TODO decref - - // TODO - auto type = get_ll_type(name); - auto null = zero_val(type); - builder.create(builder.getUnknownLoc(), null, ptr); - } - - void retvar(const py::handle& name) - { - alloca_var(name); - auto val = loadvar(name); - // TODO casts - - auto ret_ptr = func.getArgument(0); - builder.create(builder.getUnknownLoc(), val, ret_ptr); - - auto mlir_type = mllvm::LLVMType::getIntNTy(&dialect, 32); - mlir::Value ret = builder.create(builder.getUnknownLoc(), mlir_type, builder.getI32IntegerAttr(0)); - builder.create(builder.getUnknownLoc(), ret); - } - - void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) - { - auto c = loadvar(cond); - auto tr_block = blocks_map.find(tr.cast())->second; - auto fl_block = blocks_map.find(fl.cast())->second; - // TODO: casts - - builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); - } - - void jump(const py::handle& target) - { - auto block = blocks_map.find(target.cast())->second; - builder.create(builder.getUnknownLoc(), mlir::None, block); - } - - mllvm::LLVMType parse_type(llvm::StringRef str) - { - return types.get_type(ctx, str); - } - - mllvm::LLVMType get_func_type(const py::handle& typedesc) - { - auto get_type = [&](const auto& h) { - return parse_type(py::str(h).cast()); - }; - auto p_func = typedesc(); - using Type = mllvm::LLVMType; - auto ret = get_type(p_func.attr("return_type")); - llvm::SmallVector args; - for (auto arg : p_func.attr("args")) - { - args.push_back(get_type(arg)); - } - return Type::getFunctionTy(ret, args, false); - } -}; - struct plier_lowerer : public lowerer_base { plier_lowerer(): @@ -932,8 +581,7 @@ struct plier_lowerer : public lowerer_base py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) { - mlir::registerDialect(); - plier::register_dialect(); + mlir::registerDialect(); + mlir::registerDialect(); return plier_lowerer().lower(compilation_context, func_ir); -// return lowerer().lower(compilation_context, func_ir); } From 7c353191ce77d0ee0bec28c453ef60bed683cd53 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 21:34:45 +0300 Subject: [PATCH 031/259] change structure --- mlir-compiler/CMakeLists.txt | 12 ++++++------ mlir-compiler/{ => src}/dialect.cpp | 0 mlir-compiler/{ => src}/lowering.cpp | 0 mlir-compiler/{ => src}/lowering.hpp | 0 mlir-compiler/{ => src}/module.cpp | 0 mlir-compiler/{ => src}/type_parser.cpp | 0 mlir-compiler/{ => src}/type_parser.hpp | 0 7 files changed, 6 insertions(+), 6 deletions(-) rename mlir-compiler/{ => src}/dialect.cpp (100%) rename mlir-compiler/{ => src}/lowering.cpp (100%) rename mlir-compiler/{ => src}/lowering.hpp (100%) rename mlir-compiler/{ => src}/module.cpp (100%) rename mlir-compiler/{ => src}/type_parser.cpp (100%) rename mlir-compiler/{ => src}/type_parser.hpp (100%) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index a563f5f39ca..50e58ac581c 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -20,14 +20,14 @@ include(HandleLLVMOptions) add_subdirectory(include/plier) set(SOURCES_LIST - dialect.cpp - lowering.cpp - module.cpp - type_parser.cpp + src/dialect.cpp + src/lowering.cpp + src/module.cpp + src/type_parser.cpp ) set(HEADERS_LIST - lowering.hpp - type_parser.hpp + src/lowering.hpp + src/type_parser.hpp include/plier/dialect.hpp include/plier/PlierOps.td ) diff --git a/mlir-compiler/dialect.cpp b/mlir-compiler/src/dialect.cpp similarity index 100% rename from mlir-compiler/dialect.cpp rename to mlir-compiler/src/dialect.cpp diff --git a/mlir-compiler/lowering.cpp b/mlir-compiler/src/lowering.cpp similarity index 100% rename from mlir-compiler/lowering.cpp rename to mlir-compiler/src/lowering.cpp diff --git a/mlir-compiler/lowering.hpp b/mlir-compiler/src/lowering.hpp similarity index 100% rename from mlir-compiler/lowering.hpp rename to mlir-compiler/src/lowering.hpp diff --git a/mlir-compiler/module.cpp b/mlir-compiler/src/module.cpp similarity index 100% rename from mlir-compiler/module.cpp rename to mlir-compiler/src/module.cpp diff --git a/mlir-compiler/type_parser.cpp b/mlir-compiler/src/type_parser.cpp similarity index 100% rename from mlir-compiler/type_parser.cpp rename to mlir-compiler/src/type_parser.cpp diff --git a/mlir-compiler/type_parser.hpp b/mlir-compiler/src/type_parser.hpp similarity index 100% rename from mlir-compiler/type_parser.hpp rename to mlir-compiler/src/type_parser.hpp From 3780e11b0194fa5a974e75d6ac8b9f10a5b194ce Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:41:59 +0300 Subject: [PATCH 032/259] some refac --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/lowering.cpp | 82 ++++++++++++++-------------------- mlir-compiler/src/utils.cpp | 11 +++++ mlir-compiler/src/utils.hpp | 8 ++++ 4 files changed, 54 insertions(+), 49 deletions(-) create mode 100644 mlir-compiler/src/utils.cpp create mode 100644 mlir-compiler/src/utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 50e58ac581c..0868b675266 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,10 +24,12 @@ set(SOURCES_LIST src/lowering.cpp src/module.cpp src/type_parser.cpp + src/utils.cpp ) set(HEADERS_LIST src/lowering.hpp src/type_parser.hpp + src/utils.hpp include/plier/dialect.hpp include/plier/PlierOps.td ) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index b66531e9966..83630d239a8 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -2,50 +2,33 @@ #include #include -#include #include #include #include -#include - #include #include #include -#include - -#include - -#include "type_parser.hpp" - -#include - #include #include "plier/dialect.hpp" -#include +#include "utils.hpp" namespace py = pybind11; -namespace mllvm = mlir::LLVM; namespace { -[[noreturn]] void report_error(const llvm::Twine& msg) -{ - auto str = msg.str(); - throw std::exception(str.c_str()); -} -std::string serialize_mod(const llvm::Module& mod) -{ - std::string ret; - llvm::raw_string_ostream stream(ret); -// mod.print(stream, nullptr); - llvm::WriteBitcodeToFile(mod, stream); - stream.flush(); - return ret; -} +//std::string serialize_mod(const llvm::Module& mod) +//{ +// std::string ret; +// llvm::raw_string_ostream stream(ret); +//// mod.print(stream, nullptr); +// llvm::WriteBitcodeToFile(mod, stream); +// stream.flush(); +// return ret; +//} template std::string to_str(T& obj) @@ -143,27 +126,27 @@ struct inst_handles py::handle gt; }; -struct type_cache -{ - using Type = mllvm::LLVMType; +//struct type_cache +//{ +// using Type = mllvm::LLVMType; - Type get_type(mlir::MLIRContext& context, llvm::StringRef str) - { - assert(!str.empty()); - auto s = str.str(); - auto it = typemap.find(s); - if (typemap.end() != it) - { - return it->second; - } - auto type = parse_type(context, str); - typemap[s] = type; - return type; - } +// Type get_type(mlir::MLIRContext& context, llvm::StringRef str) +// { +// assert(!str.empty()); +// auto s = str.str(); +// auto it = typemap.find(s); +// if (typemap.end() != it) +// { +// return it->second; +// } +// auto type = parse_type(context, str); +// typemap[s] = type; +// return type; +// } -private: - std::unordered_map typemap; -}; +//private: +// std::unordered_map typemap; +//}; struct lowerer_base { @@ -186,7 +169,7 @@ struct plier_lowerer : public lowerer_base } - py::bytes lower(const py::object& compilation_context, const py::object& func_ir) + mlir::ModuleOp lower(const py::object& compilation_context, const py::object& func_ir) { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); typemap = compilation_context["typemap"]; @@ -201,7 +184,7 @@ struct plier_lowerer : public lowerer_base report_error("MLIR module validation failed"); } - return {}; + return mod; } private: plier::PlierDialect& dialect; @@ -583,5 +566,6 @@ py::bytes lower_function(const py::object& compilation_context, const py::object { mlir::registerDialect(); mlir::registerDialect(); - return plier_lowerer().lower(compilation_context, func_ir); + auto mod = plier_lowerer().lower(compilation_context, func_ir); + return {}; } diff --git a/mlir-compiler/src/utils.cpp b/mlir-compiler/src/utils.cpp new file mode 100644 index 00000000000..6b760cc1125 --- /dev/null +++ b/mlir-compiler/src/utils.cpp @@ -0,0 +1,11 @@ +#include "utils.hpp" + +#include + +#include "llvm/ADT/Twine.h" + +void report_error(const llvm::Twine& msg) +{ + auto str = msg.str(); + throw std::exception(str.c_str()); +} diff --git a/mlir-compiler/src/utils.hpp b/mlir-compiler/src/utils.hpp new file mode 100644 index 00000000000..adb98e2457a --- /dev/null +++ b/mlir-compiler/src/utils.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace llvm +{ +class Twine; +} + +[[noreturn]] void report_error(const llvm::Twine& msg); From 450b8f3dccf208287c36fd6dc4b36cdc1010360a Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 033/259] refac --- mlir-compiler/src/lowering.cpp | 48 +++++++++++++++++----------------- mlir-compiler/src/lowering.hpp | 9 +++++-- 2 files changed, 31 insertions(+), 26 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 83630d239a8..b769acccd7e 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -30,15 +30,15 @@ namespace // return ret; //} -template -std::string to_str(T& obj) -{ - std::string ret; - llvm::raw_string_ostream stream(ret); - obj.print(stream); - stream.flush(); - return ret; -} +//template +//std::string to_str(T& obj) +//{ +// std::string ret; +// llvm::raw_string_ostream stream(ret); +// obj.print(stream); +// stream.flush(); +// return ret; +//} template T& get_dialect(mlir::MLIRContext& ctx) @@ -65,23 +65,23 @@ py::list get_body(const py::handle& block) return block.attr("body").cast(); } -struct scoped_goto_block -{ - scoped_goto_block(mlir::OpBuilder& b, mlir::Block* new_block): - builder(b), - old_block(b.getBlock()) - { - builder.setInsertionPointToEnd(new_block); - } +//struct scoped_goto_block +//{ +// scoped_goto_block(mlir::OpBuilder& b, mlir::Block* new_block): +// builder(b), +// old_block(b.getBlock()) +// { +// builder.setInsertionPointToEnd(new_block); +// } - ~scoped_goto_block() - { - builder.setInsertionPointToEnd(old_block); - } +// ~scoped_goto_block() +// { +// builder.setInsertionPointToEnd(old_block); +// } - mlir::OpBuilder& builder; - mlir::Block* old_block = nullptr; -}; +// mlir::OpBuilder& builder; +// mlir::Block* old_block = nullptr; +//}; struct inst_handles { diff --git a/mlir-compiler/src/lowering.hpp b/mlir-compiler/src/lowering.hpp index d5ee4e5d9b2..ce709920840 100644 --- a/mlir-compiler/src/lowering.hpp +++ b/mlir-compiler/src/lowering.hpp @@ -1,5 +1,10 @@ #pragma once -#include +namespace pybind11 +{ +class bytes; +class object; +} -pybind11::bytes lower_function(const pybind11::object& compilation_context, const pybind11::object& func_ir); +pybind11::bytes lower_function(const pybind11::object& compilation_context, + const pybind11::object& func_ir); From af98ac8618123f67802f3ad080b98c267e6f8a4e Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 034/259] update tests --- mlir-compiler/test.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 048911e5614..2bb322eacc2 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -2,6 +2,7 @@ _tests_total = 0 _tests_passes = 0 +_failed_tests = [] def ret(a): return a @@ -49,8 +50,10 @@ def loop(n): def test(func, params): global _tests_total global _tests_passes + global _failed_tests _tests_total += 1 - print('test', func.__name__, params, '... ', end='') + test_name = f'{func.__name__} {params}' + print('test', test_name, '... ', end='') result = func(*params) wrapped = numba.njit()(func) try: @@ -62,6 +65,7 @@ def test(func, params): except Exception as e: print(e) print('FAILED') + _failed_tests.append(test_name) test(ret, (7,)) @@ -77,3 +81,7 @@ def test(func, params): test(loop, (8,)) print(f'Tests passed: {_tests_passes}/{_tests_total}') +if (len(_failed_tests) != 0): + print('Failed:') + for t in _failed_tests: + print(t) From 6981432529e493b98c7535acb5cd6c823104e9b7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 035/259] basic compiler infrastructure --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/include/plier/PlierOps.td | 2 +- mlir-compiler/src/compiler.cpp | 41 ++++++++++++ mlir-compiler/src/compiler.hpp | 27 ++++++++ mlir-compiler/src/dialect.cpp | 8 +-- mlir-compiler/src/lowering.cpp | 87 ++++++++++++++++--------- 6 files changed, 132 insertions(+), 35 deletions(-) create mode 100644 mlir-compiler/src/compiler.cpp create mode 100644 mlir-compiler/src/compiler.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 0868b675266..29df83a6c22 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -20,6 +20,7 @@ include(HandleLLVMOptions) add_subdirectory(include/plier) set(SOURCES_LIST + src/compiler.cpp src/dialect.cpp src/lowering.cpp src/module.cpp @@ -27,6 +28,7 @@ set(SOURCES_LIST src/utils.cpp ) set(HEADERS_LIST + src/compiler.hpp src/lowering.hpp src/type_parser.hpp src/utils.hpp diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index e35a94f7986..f7535c771cf 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -68,7 +68,7 @@ def CastOp : Plier_Op<"cast", []> { let arguments = (ins Plier_PyType:$value); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp new file mode 100644 index 00000000000..fe9a727541b --- /dev/null +++ b/mlir-compiler/src/compiler.cpp @@ -0,0 +1,41 @@ +#include "compiler.hpp" + +#include +#include +#include + +#include "utils.hpp" + +class CompilerContext::CompilerContextImpl +{ +public: + CompilerContextImpl(mlir::MLIRContext& ctx): + pm(&ctx) + { + auto& funcPm = pm.nest(); + // TODO + } + + mlir::PassManager& get_pm() { return pm; } +private: + mlir::PassManager pm; +}; + +CompilerContext::CompilerContext(mlir::MLIRContext& ctx): + impl(std::make_unique(ctx)) +{ + +} + +CompilerContext::~CompilerContext() +{ + +} + +void CompilerContext::run(mlir::ModuleOp module) +{ + if (mlir::failed(impl->get_pm().run(module))) + { + report_error("Compiler pipeline failed"); + } +} diff --git a/mlir-compiler/src/compiler.hpp b/mlir-compiler/src/compiler.hpp new file mode 100644 index 00000000000..67f6a9591a6 --- /dev/null +++ b/mlir-compiler/src/compiler.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +namespace mlir +{ +class MLIRContext; +class ModuleOp; +} + +class CompilerContext +{ +public: + class CompilerContextImpl; + + CompilerContext(mlir::MLIRContext& ctx); + ~CompilerContext(); + + CompilerContext(CompilerContext&&) = default; + + void run(mlir::ModuleOp module); + +private: + std::unique_ptr impl; +}; + +void run_compiler(CompilerContext& context, mlir::ModuleOp module); diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index cf44b0292d5..00aedb6f409 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -1,10 +1,10 @@ #include "plier/dialect.hpp" -#include "mlir/IR/DialectImplementation.h" -#include "mlir/IR/StandardTypes.h" -#include "mlir/IR/Builders.h" +#include +#include +#include -#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include namespace plier { diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index b769acccd7e..69c0ac1f5c3 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -12,8 +12,10 @@ #include #include + #include "plier/dialect.hpp" +#include "compiler.hpp" #include "utils.hpp" namespace py = pybind11; @@ -30,15 +32,15 @@ namespace // return ret; //} -//template -//std::string to_str(T& obj) -//{ -// std::string ret; -// llvm::raw_string_ostream stream(ret); -// obj.print(stream); -// stream.flush(); -// return ret; -//} +template +std::string to_str(T& obj) +{ + std::string ret; + llvm::raw_string_ostream stream(ret); + obj.print(stream); + stream.flush(); + return ret; +} template T& get_dialect(mlir::MLIRContext& ctx) @@ -150,10 +152,10 @@ struct inst_handles struct lowerer_base { - lowerer_base(): builder(&ctx) {} + lowerer_base(mlir::MLIRContext& context): ctx(context), builder(&ctx) {} protected: - mlir::MLIRContext ctx; + mlir::MLIRContext& ctx; mlir::OpBuilder builder; mlir::Block::BlockArgListType fnargs; std::vector blocks; @@ -163,7 +165,8 @@ struct lowerer_base struct plier_lowerer : public lowerer_base { - plier_lowerer(): + plier_lowerer(mlir::MLIRContext& context): + lowerer_base(context), dialect(get_dialect(ctx)) { @@ -178,12 +181,6 @@ struct plier_lowerer : public lowerer_base func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); - mod.dump(); - if (mlir::failed(mod.verify())) - { - report_error("MLIR module validation failed"); - } - return mod; } private: @@ -200,10 +197,17 @@ struct plier_lowerer : public lowerer_base }; llvm::SmallVector outgoing_phi_nodes; }; + py::handle current_instr; py::handle typemap; std::unordered_map block_infos; + plier::PyType get_type(const py::handle& inst) const + { + auto type = typemap(inst); + return plier::PyType::get(&ctx, py::str(type).cast()); + } + void lower_func_body(const py::object& func_ir) { auto ir_blocks = get_blocks(func_ir); @@ -230,7 +234,9 @@ struct plier_lowerer : public lowerer_base builder.setInsertionPointToEnd(bb); for (auto it : get_body(ir_block)) { + current_instr = it; lower_inst(it); + current_instr = nullptr; } } @@ -349,17 +355,17 @@ struct plier_lowerer : public lowerer_base return builder.create(builder.getUnknownLoc(), args); } - mlir::Value lower_phi(const py::handle& inst) + mlir::Value lower_phi(const py::handle& expr) { - auto incoming_vals = inst.attr("incoming_values").cast(); - auto incoming_blocks = inst.attr("incoming_blocks").cast(); + auto incoming_vals = expr.attr("incoming_values").cast(); + auto incoming_blocks = expr.attr("incoming_blocks").cast(); assert(incoming_vals.size() == incoming_blocks.size()); auto current_block = builder.getBlock(); assert(nullptr != current_block); auto arg_index = current_block->getNumArguments(); - auto arg = current_block->addArgument(plier::PyType::get(&ctx)); + auto arg = current_block->addArgument(get_type(current_instr.attr("target"))); auto count = incoming_vals.size(); for (std::size_t i = 0; i < count; ++i) @@ -437,8 +443,7 @@ struct plier_lowerer : public lowerer_base void storevar(mlir::Value val, const py::handle& inst) { vars_map[inst.attr("name").cast()] = val; - auto type = typemap(inst); - val.setType(plier::PyType::get(&ctx, py::str(type).cast())); + val.setType(get_type(inst)); } mlir::Value loadvar(const py::handle& inst) @@ -458,6 +463,19 @@ struct plier_lowerer : public lowerer_base { auto var = loadvar(inst); builder.create(builder.getUnknownLoc(), var); + auto func_type = func.getType(); + auto ret_type = func_type.getResult(0); + auto new_ret_type = var.getType(); + if (ret_type != new_ret_type) + { + auto def_type = plier::PyType::get(&ctx); + if (ret_type != def_type) + { + report_error(llvm::Twine("Conflicting return types: ") + to_str(ret_type) + " and " + to_str(new_ret_type)); + } + auto new_func_type = mlir::FunctionType::get(func_type.getInputs(), new_ret_type, &ctx); + func.setType(new_func_type); + } } void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) @@ -465,7 +483,8 @@ struct plier_lowerer : public lowerer_base auto c = loadvar(cond); auto tr_block = blocks_map.find(tr.cast())->second; auto fl_block = blocks_map.find(fl.cast())->second; - builder.create(builder.getUnknownLoc(), c, tr_block, fl_block); + auto cond_val = builder.create(builder.getUnknownLoc(), mlir::IntegerType::get(1, &ctx), c); + builder.create(builder.getUnknownLoc(), cond_val, tr_block, fl_block); } void jump(const py::handle& target) @@ -508,13 +527,16 @@ struct plier_lowerer : public lowerer_base { if (o.dest_block == block) { - if (list.size() <= o.arg_index) + auto arg_index = o.arg_index; + if (list.size() <= arg_index) { - list.resize(o.arg_index + 1); + list.resize(arg_index + 1); } auto it = vars_map.find(o.var_name); assert(vars_map.end() != it); - list[o.arg_index] = it->second; + auto arg_type = block->getArgument(arg_index).getType(); + auto val = builder.create(builder.getUnknownLoc(), arg_type, it->second); + list[arg_index] = val; } } }; @@ -527,7 +549,7 @@ struct plier_lowerer : public lowerer_base auto term = bb.getTerminator(); if (nullptr == term) { - report_error("broken ir: block withoout terminator"); + report_error("broken ir: block without terminator"); } builder.setInsertionPointToEnd(&bb); @@ -566,6 +588,11 @@ py::bytes lower_function(const py::object& compilation_context, const py::object { mlir::registerDialect(); mlir::registerDialect(); - auto mod = plier_lowerer().lower(compilation_context, func_ir); + mlir::MLIRContext context; + auto mod = plier_lowerer(context).lower(compilation_context, func_ir); + mod.dump(); + CompilerContext compiler(context); + compiler.run(mod); +// mod.dump(); return {}; } From 7fd5f80f570d2a9a29777f2a7fe2cc062e36cf16 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 036/259] cast folding --- mlir-compiler/CMakeLists.txt | 1 + mlir-compiler/include/plier/PlierOps.td | 1 + mlir-compiler/src/compiler.cpp | 16 +++++++++++----- mlir-compiler/src/dialect.cpp | 11 +++++++++++ mlir-compiler/src/lowering.cpp | 2 +- 5 files changed, 25 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 29df83a6c22..6c7d9dd9ab5 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -50,6 +50,7 @@ target_link_libraries(${PROJECT_NAME} MLIRLLVMIR MLIRStandardOps MLIRTargetLLVMIR + MLIRTransforms ) target_include_directories(${PROJECT_NAME} diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index f7535c771cf..d201e70805f 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -69,6 +69,7 @@ def CastOp : Plier_Op<"cast", []> { Plier_PyType:$value); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index fe9a727541b..6c6e7fe6a52 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -2,7 +2,9 @@ #include #include +#include #include +#include #include "utils.hpp" @@ -12,11 +14,18 @@ class CompilerContext::CompilerContextImpl CompilerContextImpl(mlir::MLIRContext& ctx): pm(&ctx) { + pm.addNestedPass(mlir::createCanonicalizerPass()); auto& funcPm = pm.nest(); // TODO } - mlir::PassManager& get_pm() { return pm; } + void run(mlir::ModuleOp& module) + { + if (mlir::failed(pm.run(module))) + { + report_error("Compiler pipeline failed"); + } + } private: mlir::PassManager pm; }; @@ -34,8 +43,5 @@ CompilerContext::~CompilerContext() void CompilerContext::run(mlir::ModuleOp module) { - if (mlir::failed(impl->get_pm().run(module))) - { - report_error("Compiler pipeline failed"); - } + impl->run(module); } diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 00aedb6f409..86d46230ff7 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -94,6 +94,17 @@ void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, CastOp::build(builder, state, PyType::get(state.getContext()), val); } +mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) +{ + auto op_type = getOperand().getType(); + auto ret_type = getType(); + if (op_type == ret_type && op_type != PyType::get(getContext())) + { + return getOperand(); + } + return nullptr; +} + void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func, mlir::ValueRange args, mlir::ArrayRef> kwargs) { diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 69c0ac1f5c3..1f91ce7c6ff 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -593,6 +593,6 @@ py::bytes lower_function(const py::object& compilation_context, const py::object mod.dump(); CompilerContext compiler(context); compiler.run(mod); -// mod.dump(); + mod.dump(); return {}; } From 8905592e9480d8f35d729f18df54a6dbca637d25 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 037/259] refac --- mlir-compiler/include/plier/dialect.hpp | 3 +- mlir-compiler/src/dialect.cpp | 44 +++++++++++++++++-------- mlir-compiler/src/lowering.cpp | 6 ++-- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 164da6fd1da..b649b6ffcc7 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -41,7 +41,8 @@ class PyType : public mlir::Type::TypeBasename; @@ -69,36 +75,40 @@ llvm::StringRef PyType::getName() const void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, unsigned index, mlir::StringRef name) { - ArgOp::build(builder, state, PyType::get(state.getContext()), + ArgOp::build(builder, state, PyType::getUndefined(state.getContext()), llvm::APInt(32, index), name); } void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Attribute val) { - ConstOp::build(builder, state, PyType::get(state.getContext()), val); + ConstOp::build(builder, state, PyType::getUndefined(state.getContext()), + val); } void GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::StringRef name) { - GlobalOp::build(builder, state, PyType::get(state.getContext()), name); + GlobalOp::build(builder, state, PyType::getUndefined(state.getContext()), + name); } void BinOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value lhs, mlir::Value rhs, mlir::StringRef op) { - BinOp::build(builder, state, PyType::get(state.getContext()), lhs, rhs, op); + BinOp::build(builder, state, PyType::getUndefined(state.getContext()), lhs, + rhs, op); } void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value val) { - CastOp::build(builder, state, PyType::get(state.getContext()), val); + CastOp::build(builder, state, PyType::getUndefined(state.getContext()), + val); } mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) { auto op_type = getOperand().getType(); auto ret_type = getType(); - if (op_type == ret_type && op_type != PyType::get(getContext())) + if (op_type == ret_type && op_type != PyType::getUndefined(getContext())) { return getOperand(); } @@ -120,46 +130,52 @@ void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func kw_names.push_back(mlir::StringAttr::get(a.first, ctx)); all_args.push_back(a.second); } - PyCallOp::build(builder, state, PyType::get(state.getContext()), func, - all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); + PyCallOp::build(builder, state, PyType::getUndefined(state.getContext()), + func, all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); } void BuildTupleOp::build(OpBuilder &builder, OperationState &state, ::mlir::ValueRange args) { - BuildTupleOp::build(builder, state, PyType::get(state.getContext()), args); + BuildTupleOp::build(builder, state, + PyType::getUndefined(state.getContext()), args); } void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value, ::mlir::Value index_var, unsigned int index) { - StaticGetItemOp::build(builder, state, PyType::get(state.getContext()), + StaticGetItemOp::build(builder, state, + PyType::getUndefined(state.getContext()), value, index_var, llvm::APInt(32, index)); } void GetiterOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { - GetiterOp::build(builder, state, PyType::get(state.getContext()), value); + GetiterOp::build(builder, state, PyType::getUndefined(state.getContext()), + value); } void IternextOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { - IternextOp::build(builder, state, PyType::get(state.getContext()), value); + IternextOp::build(builder, state, PyType::getUndefined(state.getContext()), + value); } void PairfirstOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { - PairfirstOp::build(builder, state, PyType::get(state.getContext()), value); + PairfirstOp::build(builder, state, PyType::getUndefined(state.getContext()), + value); } void PairsecondOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { - PairsecondOp::build(builder, state, PyType::get(state.getContext()), value); + PairsecondOp::build(builder, state, + PyType::getUndefined(state.getContext()), value); } diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 1f91ce7c6ff..d749ad305ae 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -468,7 +468,7 @@ struct plier_lowerer : public lowerer_base auto new_ret_type = var.getType(); if (ret_type != new_ret_type) { - auto def_type = plier::PyType::get(&ctx); + auto def_type = plier::PyType::getUndefined(&ctx); if (ret_type != def_type) { report_error(llvm::Twine("Conflicting return types: ") + to_str(ret_type) + " and " + to_str(new_ret_type)); @@ -506,11 +506,11 @@ struct plier_lowerer : public lowerer_base { auto get_type = [&](const auto& h) { // return parse_type(py::str(h).cast()); - return plier::PyType::get(&ctx); + return plier::PyType::getUndefined(&ctx); }; // auto p_func = typedesc(); // auto ret = get_type(p_func.attr("restype")); - auto ret = plier::PyType::get(&ctx); + auto ret = plier::PyType::getUndefined(&ctx); llvm::SmallVector args; // for (auto arg : p_func.attr("argtypes")) // { From ac64f460a3846c23d479c29a1232387ad7bb6da8 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 038/259] refac --- mlir-compiler/include/plier/dialect.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index b649b6ffcc7..8963740386c 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -1,10 +1,10 @@ #pragma once -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Function.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" +#include +#include +#include +#include +#include #include "plier/PlierOpsEnums.h.inc" From be9babc36474417aee93ea6e9a19a466062a8fde Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 039/259] pass stub --- mlir-compiler/CMakeLists.txt | 2 ++ mlir-compiler/src/passes/plier_to_std.cpp | 23 +++++++++++++++++++++++ mlir-compiler/src/passes/plier_to_std.hpp | 10 ++++++++++ 3 files changed, 35 insertions(+) create mode 100644 mlir-compiler/src/passes/plier_to_std.cpp create mode 100644 mlir-compiler/src/passes/plier_to_std.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 6c7d9dd9ab5..06252e5735a 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,12 +24,14 @@ set(SOURCES_LIST src/dialect.cpp src/lowering.cpp src/module.cpp + src/passes/plier_to_std.cpp src/type_parser.cpp src/utils.cpp ) set(HEADERS_LIST src/compiler.hpp src/lowering.hpp + src/passes/plier_to_std.hpp src/type_parser.hpp src/utils.hpp include/plier/dialect.hpp diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp new file mode 100644 index 00000000000..16eaed6fbca --- /dev/null +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -0,0 +1,23 @@ +#include "plier_to_std.hpp" + +#include +#include +#include +#include + +namespace +{ +struct PlierToStdPass : + public mlir::PassWrapper +{ + void runOnFunction() + { + + } +}; +} + +std::unique_ptr createPlierToStdPass() +{ + return std::make_unique(); +} diff --git a/mlir-compiler/src/passes/plier_to_std.hpp b/mlir-compiler/src/passes/plier_to_std.hpp new file mode 100644 index 00000000000..b43c608fb81 --- /dev/null +++ b/mlir-compiler/src/passes/plier_to_std.hpp @@ -0,0 +1,10 @@ +#pragma once + +#include + +namespace mlir +{ +class Pass; +} + +std::unique_ptr createPlierToStdPass(); From 8ca1f7c4e24d14ae1159b76679a3e38a094df1fa Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 040/259] update to latest mlir --- mlir-compiler/CMakeLists.txt | 2 +- mlir-compiler/include/plier/dialect.hpp | 18 ++++----------- mlir-compiler/src/dialect.cpp | 29 +++++++++++-------------- mlir-compiler/src/lowering.cpp | 5 +++-- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 06252e5735a..0e84c685878 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -50,7 +50,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE MLIRSupport MLIRLLVMIR - MLIRStandardOps + MLIRStandard MLIRTargetLLVMIR MLIRTransforms ) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 8963740386c..451652df992 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -11,10 +11,11 @@ namespace plier { using namespace mlir; // TODO: remove +} + #include "plier/PlierOpsDialect.h.inc" #define GET_OP_CLASSES #include "plier/PlierOps.h.inc" -} namespace plier { @@ -24,23 +25,12 @@ namespace detail struct PyTypeStorage; } -namespace types -{ -enum Kind -{ - // Dialect types. - PyType = mlir::Type::FIRST_PRIVATE_EXPERIMENTAL_3_TYPE, -}; -} - -class PyType : public mlir::Type::TypeBase +class PyType : public mlir::Type::TypeBase<::plier::PyType, mlir::Type, + ::plier::detail::PyTypeStorage> { public: using Base::Base; - static bool kindof(unsigned kind) { return kind == types::PyType; } - static PyType get(mlir::MLIRContext *context, mlir::StringRef name); static PyType getUndefined(mlir::MLIRContext *context); diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 40973e220d7..b7a8a1152b1 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -5,6 +5,7 @@ #include #include +#include namespace plier { @@ -33,8 +34,8 @@ struct PyTypeStorage : public mlir::TypeStorage }; } -PlierDialect::PlierDialect(mlir::MLIRContext *context) - : Dialect(getDialectNamespace(), context) { +void PlierDialect::initialize() +{ addOperations< #define GET_OP_LIST #include "plier/PlierOps.cpp.inc" @@ -48,24 +49,20 @@ mlir::Type PlierDialect::parseType(mlir::DialectAsmParser &parser) const { } void PlierDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - switch (type.getKind()) { - case plier::types::PyType: - os << "PyType<" << type.cast().getName() << ">"; - return; - default: - llvm_unreachable("unexpected type kind"); - } + llvm::TypeSwitch(type) + .Case([&](auto t){ os << "PyType<" << t.getName() << ">"; }) + .Default([](auto){ llvm_unreachable("unexpected type"); }); } -PyType PyType::get(MLIRContext* context, llvm::StringRef name) +PyType PyType::get(mlir::MLIRContext* context, llvm::StringRef name) { assert(!name.empty()); - return Base::get(context, types::PyType, name); + return Base::get(context, name); } PyType PyType::getUndefined(MLIRContext* context) { - return Base::get(context, types::PyType, ""); + return Base::get(context, ""); } llvm::StringRef PyType::getName() const @@ -76,7 +73,7 @@ llvm::StringRef PyType::getName() const void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, unsigned index, mlir::StringRef name) { ArgOp::build(builder, state, PyType::getUndefined(state.getContext()), - llvm::APInt(32, index), name); + index, name); } void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, @@ -122,7 +119,7 @@ void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func mlir::SmallVector all_args; all_args.reserve(args.size() + kwargs.size()); std::copy(args.begin(), args.end(), std::back_inserter(all_args)); - auto kw_start = llvm::APInt(32, all_args.size()); + auto kw_start = static_cast(all_args.size()); mlir::SmallVector kw_names; kw_names.reserve(kwargs.size()); for (auto& a : kwargs) @@ -147,7 +144,7 @@ void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, { StaticGetItemOp::build(builder, state, PyType::getUndefined(state.getContext()), - value, index_var, llvm::APInt(32, index)); + value, index_var, index); } void GetiterOp::build(OpBuilder &builder, OperationState &state, @@ -178,9 +175,9 @@ void PairsecondOp::build(OpBuilder &builder, OperationState &state, PyType::getUndefined(state.getContext()), value); } +} #define GET_OP_CLASSES #include "plier/PlierOps.cpp.inc" -} #include "plier/PlierOpsEnums.cpp.inc" diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index d749ad305ae..b4028d09d64 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -45,7 +45,7 @@ std::string to_str(T& obj) template T& get_dialect(mlir::MLIRContext& ctx) { - auto dialect = ctx.getRegisteredDialect(); + auto dialect = ctx.getOrLoadDialect(); assert(nullptr != dialect); return *dialect; } @@ -169,7 +169,8 @@ struct plier_lowerer : public lowerer_base lowerer_base(context), dialect(get_dialect(ctx)) { - + ctx.loadDialect(); + ctx.loadDialect(); } mlir::ModuleOp lower(const py::object& compilation_context, const py::object& func_ir) From 4ed069ac9fcb8db54369ce4a2692aedbb9b07b59 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 041/259] refac --- mlir-compiler/src/lowering.cpp | 34 ++++++++-------------------------- 1 file changed, 8 insertions(+), 26 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index b4028d09d64..a525c50181c 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -42,14 +42,6 @@ std::string to_str(T& obj) return ret; } -template -T& get_dialect(mlir::MLIRContext& ctx) -{ - auto dialect = ctx.getOrLoadDialect(); - assert(nullptr != dialect); - return *dialect; -} - std::vector> get_blocks(const py::object& func) { std::vector> ret; @@ -150,24 +142,11 @@ struct inst_handles // std::unordered_map typemap; //}; -struct lowerer_base -{ - lowerer_base(mlir::MLIRContext& context): ctx(context), builder(&ctx) {} - -protected: - mlir::MLIRContext& ctx; - mlir::OpBuilder builder; - mlir::Block::BlockArgListType fnargs; - std::vector blocks; - std::unordered_map blocks_map; - inst_handles insts; -}; - -struct plier_lowerer : public lowerer_base +struct plier_lowerer { plier_lowerer(mlir::MLIRContext& context): - lowerer_base(context), - dialect(get_dialect(ctx)) + ctx(context), + builder(&ctx) { ctx.loadDialect(); ctx.loadDialect(); @@ -185,7 +164,11 @@ struct plier_lowerer : public lowerer_base return mod; } private: - plier::PlierDialect& dialect; + mlir::MLIRContext& ctx; + mlir::OpBuilder builder; + std::vector blocks; + std::unordered_map blocks_map; + inst_handles insts; mlir::FuncOp func; std::unordered_map vars_map; struct BlockInfo @@ -220,7 +203,6 @@ struct plier_lowerer : public lowerer_base blocks.push_back(block); blocks_map[ir_blocks[i].first] = block; } - fnargs = func.getArguments(); for (std::size_t i = 0; i < ir_blocks.size(); ++i) { From 0429bd3fcf698bcf62cb7ff3b72c7071535abcc6 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 042/259] func args --- mlir-compiler/src/lowering.cpp | 25 ++++++++++++------------- numba/core/typed_passes.py | 2 +- 2 files changed, 13 insertions(+), 14 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index a525c50181c..b940b76475f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -157,7 +157,7 @@ struct plier_lowerer auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); typemap = compilation_context["typemap"]; auto name = compilation_context["fnname"]().cast(); - auto typ = get_func_type(/*compilation_context["fndesc"]*/nullptr); + auto typ = get_func_type(compilation_context["fnargs"]); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); @@ -186,10 +186,15 @@ struct plier_lowerer std::unordered_map block_infos; + plier::PyType get_obj_type(const py::handle& obj) const + { + return plier::PyType::get(&ctx, py::str(obj).cast()); + } + plier::PyType get_type(const py::handle& inst) const { auto type = typemap(inst); - return plier::PyType::get(&ctx, py::str(type).cast()); + return get_obj_type(type); } void lower_func_body(const py::object& func_ir) @@ -485,20 +490,14 @@ struct plier_lowerer report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); } - mlir::FunctionType get_func_type(const py::handle& typedesc) + mlir::FunctionType get_func_type(const py::handle& fnargs) { - auto get_type = [&](const auto& h) { -// return parse_type(py::str(h).cast()); - return plier::PyType::getUndefined(&ctx); - }; -// auto p_func = typedesc(); -// auto ret = get_type(p_func.attr("restype")); auto ret = plier::PyType::getUndefined(&ctx); llvm::SmallVector args; -// for (auto arg : p_func.attr("argtypes")) -// { -// args.push_back(get_type(arg)); -// } + for (auto arg : fnargs()) + { + args.push_back(get_obj_type(arg)); + } return mlir::FunctionType::get(args, {ret}, &ctx); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index a694da53f84..053488bd3f2 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -476,7 +476,7 @@ def run_pass(self, state): ctx = {} ctx['typemap'] = lambda op: state.typemap[op.name] # ctx['fndesc'] = lambda: fndesc - # ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) + ctx['fnargs'] = lambda: state.args ctx['fnname'] = lambda: state.func_ir.func_id.func_qualname # ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) import mlir_compiler From 8588e7e52be2c0ab679eeead29414672e962e876 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 043/259] func args folding --- mlir-compiler/include/plier/PlierOps.td | 1 + mlir-compiler/src/dialect.cpp | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index d201e70805f..38ba3351fda 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -23,6 +23,7 @@ def ArgOp : Plier_Op<"arg", []> { StrAttr:$name); let results = (outs Plier_PyType); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, unsigned index, StringRef name"> diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index b7a8a1152b1..08adc2501c5 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -76,6 +76,19 @@ void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, index, name); } +mlir::OpFoldResult ArgOp::fold(llvm::ArrayRef /*operands*/) +{ + auto func = getParentOfType(); + auto ind = index(); + if (ind >= func.getNumArguments() || + func.getArgument(ind).getType() != getType()) + { + emitError("Invalid function args"); + return nullptr; + } + return func.getArgument(ind); +} + void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Attribute val) { From d0a5f4c1699b2497419699474defd89bfd552dc4 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 5 Oct 2020 22:44:07 +0300 Subject: [PATCH 044/259] some rewrites --- mlir-compiler/include/plier/PlierOps.td | 4 +- mlir-compiler/src/compiler.cpp | 11 +- mlir-compiler/src/passes/plier_to_std.cpp | 200 +++++++++++++++++++++- 3 files changed, 206 insertions(+), 9 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 38ba3351fda..d617339812b 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -54,8 +54,8 @@ def GlobalOp : Plier_Op<"global", []> { def BinOp : Plier_Op<"binop", []> { let arguments = (ins - Plier_PyType:$rhs, - Plier_PyType:$lhs, + AnyType:$rhs, + AnyType:$lhs, StrAttr:$op); let results = (outs Plier_PyType); diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 6c6e7fe6a52..6e4195791fe 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -8,15 +8,18 @@ #include "utils.hpp" +#include "passes/plier_to_std.hpp" + class CompilerContext::CompilerContextImpl { public: CompilerContextImpl(mlir::MLIRContext& ctx): - pm(&ctx) + pm(&ctx, /*verify*/false) { - pm.addNestedPass(mlir::createCanonicalizerPass()); - auto& funcPm = pm.nest(); - // TODO + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(createPlierToStdPass()); + pm.enableStatistics(); + pm.enableTiming(); } void run(mlir::ModuleOp& module) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 16eaed6fbca..10abe42f66d 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -5,16 +5,210 @@ #include #include +#include "plier/dialect.hpp" + namespace { -struct PlierToStdPass : - public mlir::PassWrapper +mlir::Type map_int_type(plier::PyType type) +{ + auto name = type.getName(); + unsigned num_bits = 0; + if (name.consume_front("int") && + !name.consumeInteger(10, num_bits) && name.empty()) + { + return mlir::IntegerType::get(num_bits, type.getContext()); + } + return {}; +} + +mlir::Type map_plier_type(mlir::Type type) +{ + if (!type.isa()) + { + return {}; + } + auto ptype = type.cast(); + using func_t = mlir::Type(*)(plier::PyType); + const func_t handlers[] = { + &map_int_type + }; + for (auto h : handlers) + { + auto t = h(ptype); + if (t != mlir::Type()) + { + return t; + } + } + return {}; +} + +bool is_supported_type(mlir::Type type) +{ + return type.isIntOrFloat(); +} + +mlir::Type map_type(mlir::Type type) +{ + auto new_type = is_supported_type(type) ? type : map_plier_type(type); + return mlir::Type() == new_type ? type : new_type; +}; + +void convertFuncArgs(mlir::FuncOp func) +{ + llvm::SmallVector new_arg_types; + new_arg_types.reserve(func.getNumArguments()); + for (auto arg_type : func.getArgumentTypes()) + { + new_arg_types.push_back(map_type(arg_type)); + } + auto res_type = map_type(func.getType().getResult(0)); + auto func_type = mlir::FunctionType::get(new_arg_types, res_type, func.getContext()); + func.setType(func_type); + for (unsigned i = 0; i < func.getNumArguments(); ++i) + { + func.front().getArgument(i).setType(new_arg_types[i]); + } +} + +template +void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type) +{ + assert(nullptr != op); + rewriter.replaceOpWithNewOp(op, new_type, op->getOperands()); +} + +bool is_int(mlir::Type type) +{ + return type.isa(); +} + +bool is_float(mlir::Type type) { - void runOnFunction() + return type.isa(); +} + +struct ConstOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + mlir::LogicalResult matchAndRewrite(plier::ConstOp op, + mlir::PatternRewriter& rewriter) const + { + auto value = op.val(); + if (!is_supported_type(value.getType())) + { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, value.getType(), value); + return mlir::success(); + } +}; + +struct BinOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(plier::BinOp op, + mlir::PatternRewriter& rewriter) const { + assert(op.getNumOperands() == 2); + auto type0 = op.getOperand(0).getType(); + auto type1 = op.getOperand(1).getType(); + if (type0 != type1 || !is_supported_type(type0) || !is_supported_type(type1)) + { + return mlir::failure(); + } + using func_t = void(*)(mlir::Operation*, mlir::PatternRewriter&, mlir::Type); + struct OpDesc + { + llvm::StringRef type; + func_t iop; + func_t fop; + }; + + const OpDesc handlers[] = { + {"+", &replace_op, &replace_op} + }; + + auto find_handler = [&]()->const OpDesc& + { + for (auto& h : handlers) + { + if (h.type == op.op()) + { + return h; + } + } + llvm_unreachable("Unhandled op type"); + }; + + if (is_int(type0)) + { + find_handler().iop(op, rewriter, type0); + } + else if (is_float(type0)) + { + find_handler().fop(op, rewriter, type0); + } + else + { + llvm_unreachable("Unhandled arg type"); + } + + return mlir::success(); } }; + +using namespace mlir; +struct FuncOpSignatureConversion : public OpConversionPattern { + FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) + : OpConversionPattern(converter, ctx) {} + + /// Hook for derived classes to implement combined matching and rewriting. + LogicalResult + matchAndRewrite(FuncOp funcOp, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override + { + convertFuncArgs(funcOp); + rewriter.updateRootInPlace(funcOp, [&] {}); // HACK + return success(); + } +}; + +struct PlierToStdPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnOperation() override; +}; + +void PlierToStdPass::runOnOperation() +{ + mlir::TypeConverter type_converter; + type_converter.addConversion([](plier::Type type)->llvm::Optional + { + return map_plier_type(type); + }); + + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext()); + patterns.insert(&getContext(), type_converter); + + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) + { + signalPassFailure(); + } +} + } std::unique_ptr createPlierToStdPass() From 185a5fad14a4dfd3f9432c9e728ad97478662d0d Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 7 Oct 2020 23:20:17 +0300 Subject: [PATCH 045/259] refac --- mlir-compiler/src/passes/plier_to_std.cpp | 123 +++++++++++++++++----- 1 file changed, 96 insertions(+), 27 deletions(-) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 10abe42f66d..13e52026939 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -21,6 +21,19 @@ mlir::Type map_int_type(plier::PyType type) return {}; } +mlir::Type map_int_literal_type(plier::PyType type) +{ + auto name = type.getName(); + unsigned dummy = 0; + if (name.consume_front("Literal[int](") && + !name.consumeInteger(10, dummy) && name.consume_front(")") + && name.empty()) + { + return mlir::IntegerType::get(64, type.getContext()); // TODO + } + return {}; +} + mlir::Type map_plier_type(mlir::Type type) { if (!type.isa()) @@ -30,7 +43,8 @@ mlir::Type map_plier_type(mlir::Type type) auto ptype = type.cast(); using func_t = mlir::Type(*)(plier::PyType); const func_t handlers[] = { - &map_int_type + &map_int_type, + &map_int_literal_type, }; for (auto h : handlers) { @@ -131,49 +145,90 @@ struct BinOpLowering : public mlir::OpRewritePattern {"+", &replace_op, &replace_op} }; - auto find_handler = [&]()->const OpDesc& + using membptr_t = func_t OpDesc::*; + auto call_handler = [&](membptr_t mem) { for (auto& h : handlers) { if (h.type == op.op()) { - return h; + (h.*mem)(op, rewriter, type0); + return mlir::success(); } } - llvm_unreachable("Unhandled op type"); + return mlir::failure(); }; + if (is_int(type0)) { - find_handler().iop(op, rewriter, type0); + return call_handler(&OpDesc::iop); } else if (is_float(type0)) { - find_handler().fop(op, rewriter, type0); - } - else - { - llvm_unreachable("Unhandled arg type"); + return call_handler(&OpDesc::fop); } + return mlir::failure(); + } +}; +struct FuncOpSignatureConversion : + public mlir::OpConversionPattern +{ + FuncOpSignatureConversion(mlir::MLIRContext* ctx, + mlir::TypeConverter& converter) + : OpConversionPattern(converter, ctx) {} + + /// Hook for derived classes to implement combined matching and rewriting. + mlir::LogicalResult + matchAndRewrite(mlir::FuncOp funcOp, mlir::ArrayRef /*operands*/, + mlir::ConversionPatternRewriter &rewriter) const override + { + convertFuncArgs(funcOp); + rewriter.updateRootInPlace(funcOp, [&] {}); // HACK return mlir::success(); } }; -using namespace mlir; -struct FuncOpSignatureConversion : public OpConversionPattern { - FuncOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter) - : OpConversionPattern(converter, ctx) {} - - /// Hook for derived classes to implement combined matching and rewriting. - LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override - { - convertFuncArgs(funcOp); - rewriter.updateRootInPlace(funcOp, [&] {}); // HACK - return success(); - } +struct OpTypeConversion : public mlir::ConversionPattern +{ + OpTypeConversion(mlir::MLIRContext* /*ctx*/, + mlir::TypeConverter& converter) + : ConversionPattern(0, converter, mlir::Pattern::MatchAnyOpTypeTag()) {} + + /// Hook for derived classes to implement combined matching and rewriting. + mlir::LogicalResult + matchAndRewrite(mlir::Operation* op, mlir::ArrayRef /*operands*/, + mlir::ConversionPatternRewriter &rewriter) const override + { + bool changed = false; + llvm::SmallVector new_types; + for (auto type : op->getResultTypes()) + { + if (auto new_type = map_plier_type(type)) + { + new_types.push_back(new_type); + changed = true; + } + else + { + new_types.push_back(type); + } + } + + if (changed) + { + rewriter.updateRootInPlace(op, [&] + { + for (unsigned i = 0; i < static_cast(new_types.size()); ++i) + { + op->getResult(i).setType(new_types[i]); + } + }); + return mlir::success(); + } + return mlir::failure(); + } }; struct PlierToStdPass : @@ -200,12 +255,26 @@ void PlierToStdPass::runOnOperation() target.addLegalDialect(); mlir::OwningRewritePatternList patterns; - patterns.insert(&getContext()); - patterns.insert(&getContext(), type_converter); + patterns.insert(&getContext(), type_converter); + + auto apply_conv = [&]() + { + return mlir::applyPartialConversion(getOperation(), target, patterns); + }; - if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) + if (mlir::failed(apply_conv())) + { + signalPassFailure(); + return; + } + + patterns.clear(); + patterns.insert(&getContext()); + if (mlir::failed(apply_conv())) { signalPassFailure(); + return; } } From 95440cdd5d4e81f00cf8f3957732278bebb92e5a Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 03:05:44 +0300 Subject: [PATCH 046/259] work on lowerer to llvm --- mlir-compiler/CMakeLists.txt | 13 ++++--- mlir-compiler/src/compiler.cpp | 5 +++ mlir-compiler/src/passes/lower_to_llvm.cpp | 42 ++++++++++++++++++++++ mlir-compiler/src/passes/lower_to_llvm.hpp | 8 +++++ mlir-compiler/src/passes/plier_to_std.cpp | 2 +- 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 mlir-compiler/src/passes/lower_to_llvm.cpp create mode 100644 mlir-compiler/src/passes/lower_to_llvm.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 0e84c685878..9a88e163166 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,6 +24,7 @@ set(SOURCES_LIST src/dialect.cpp src/lowering.cpp src/module.cpp + src/passes/lower_to_llvm.cpp src/passes/plier_to_std.cpp src/type_parser.cpp src/utils.cpp @@ -31,6 +32,7 @@ set(SOURCES_LIST set(HEADERS_LIST src/compiler.hpp src/lowering.hpp + src/passes/lower_to_llvm.hpp src/passes/plier_to_std.hpp src/type_parser.hpp src/utils.hpp @@ -46,17 +48,20 @@ endif () target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) -target_link_libraries(${PROJECT_NAME} - PRIVATE +target_link_libraries(${PROJECT_NAME} PRIVATE + LLVM${LLVM_NATIVE_ARCH}CodeGen + LLVM${LLVM_NATIVE_ARCH}Desc + LLVMTarget MLIRSupport MLIRLLVMIR MLIRStandard MLIRTargetLLVMIR MLIRTransforms + MLIRStandardToLLVM ) -target_include_directories(${PROJECT_NAME} - PRIVATE +target_include_directories(${PROJECT_NAME} PRIVATE + ./src ./include ${LLVM_INCLUDE_DIRS} ${MLIR_INCLUDE_DIRS} diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 6e4195791fe..f8a5cf1a287 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -5,10 +5,12 @@ #include #include #include +#include #include "utils.hpp" #include "passes/plier_to_std.hpp" +#include "passes/lower_to_llvm.hpp" class CompilerContext::CompilerContextImpl { @@ -18,6 +20,9 @@ class CompilerContext::CompilerContextImpl { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createPlierToStdPass()); + + populate_lower_to_llvm_pipeline(pm); + pm.enableStatistics(); pm.enableTiming(); } diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp new file mode 100644 index 00000000000..2eee9086423 --- /dev/null +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -0,0 +1,42 @@ +#include "passes/lower_to_llvm.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "utils.hpp" + +namespace +{ +const mlir::LowerToLLVMOptions &getLLVMOptions() +{ + static mlir::LowerToLLVMOptions options = []() + { + llvm::InitializeNativeTarget(); + auto triple = llvm::sys::getProcessTriple(); + std::string err_str; + auto target = llvm::TargetRegistry::lookupTarget(triple, err_str); + if (nullptr == target) + { + report_error(llvm::Twine("Unable to get target: ") + err_str); + } + llvm::TargetOptions target_opts; + std::unique_ptr machine(target->createTargetMachine(triple, llvm::sys::getHostCPUName(), "", target_opts, llvm::None)); + mlir::LowerToLLVMOptions opts; + opts.dataLayout = machine->createDataLayout(); + return opts; + }(); + return options; +} +} + +void populate_lower_to_llvm_pipeline(mlir::PassManager& pm) +{ + pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); +} diff --git a/mlir-compiler/src/passes/lower_to_llvm.hpp b/mlir-compiler/src/passes/lower_to_llvm.hpp new file mode 100644 index 00000000000..15423505a84 --- /dev/null +++ b/mlir-compiler/src/passes/lower_to_llvm.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace mlir +{ +class PassManager; +} + +void populate_lower_to_llvm_pipeline(mlir::PassManager& pm); diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 13e52026939..07411708736 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -1,4 +1,4 @@ -#include "plier_to_std.hpp" +#include "passes/plier_to_std.hpp" #include #include From aa10a12c3fa6566403994be147c63cd65b3e1edf Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 15:53:48 +0300 Subject: [PATCH 047/259] work on llvm lowering --- mlir-compiler/src/lowering.cpp | 18 --- mlir-compiler/src/passes/lower_to_llvm.cpp | 158 +++++++++++++++++++++ 2 files changed, 158 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index b940b76475f..4330631e5c2 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -59,24 +59,6 @@ py::list get_body(const py::handle& block) return block.attr("body").cast(); } -//struct scoped_goto_block -//{ -// scoped_goto_block(mlir::OpBuilder& b, mlir::Block* new_block): -// builder(b), -// old_block(b.getBlock()) -// { -// builder.setInsertionPointToEnd(new_block); -// } - -// ~scoped_goto_block() -// { -// builder.setInsertionPointToEnd(old_block); -// } - -// mlir::OpBuilder& builder; -// mlir::Block* old_block = nullptr; -//}; - struct inst_handles { inst_handles() diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 2eee9086423..2c7964f7aeb 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -1,6 +1,9 @@ #include "passes/lower_to_llvm.hpp" +#include #include +#include +#include #include #include @@ -12,6 +15,8 @@ #include "utils.hpp" +#include + namespace { const mlir::LowerToLLVMOptions &getLLVMOptions() @@ -30,13 +35,166 @@ const mlir::LowerToLLVMOptions &getLLVMOptions() std::unique_ptr machine(target->createTargetMachine(triple, llvm::sys::getHostCPUName(), "", target_opts, llvm::None)); mlir::LowerToLLVMOptions opts; opts.dataLayout = machine->createDataLayout(); + opts.useBarePtrCallConv = true; return opts; }(); return options; } + +struct LLVMTypeHelper +{ + LLVMTypeHelper(mlir::MLIRContext& ctx): + type_converter(&ctx) {} + + mlir::LLVM::LLVMType i(unsigned bits) + { + return mlir::LLVM::LLVMIntegerType::get(&type_converter.getContext(), bits); + } + + mlir::LLVM::LLVMType ptr(mlir::Type type) + { + auto ll_type = type_converter.convertType(type).cast(); + return mlir::LLVM::LLVMPointerType::get(ll_type); + } + + mlir::MLIRContext& get_context() + { + return type_converter.getContext(); + } + + mlir::LLVMTypeConverter& get_type_converter() + { + return type_converter; + } + +private: + mlir::LLVMTypeConverter type_converter; +}; + +mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) +{ + mlir::LLVM::LLVMType elems[] = { + type_helper.ptr(type_helper.i(8)), + type_helper.i(32), + type_helper.ptr(type_helper.i(8)), + }; + return mlir::LLVM::LLVMStructType::getLiteral(&type_helper.get_context(), elems); +} + +mlir::FunctionType legalize_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) +{ + auto old_type = func.getType(); + assert(old_type.getNumResults() == 1); + auto& ctx = *old_type.getContext(); + llvm::SmallVector args; + + auto ptr = [&](auto arg) + { + return type_helper.ptr(arg); + }; + + unsigned index = 0; + auto add_arg = [&](mlir::Type type) + { + args.push_back(type); + func.getBody().insertArgument(index, type); + ++index; + }; + + add_arg(ptr(old_type.getResult(0))); + add_arg(ptr(ptr(getExceptInfoType(type_helper)))); + + auto old_args = old_type.getResults(); + std::copy(old_args.begin(), old_args.end(), std::back_inserter(args)); + auto ret_type = mlir::IntegerType::get(32, &ctx); + return mlir::FunctionType::get(args, ret_type, &ctx); +} + +struct ReturnOpLowering : public mlir::OpRewritePattern +{ + ReturnOpLowering(mlir::MLIRContext* ctx, mlir::TypeConverter& converter): + OpRewritePattern(ctx), type_converter(converter) {} + + mlir::LogicalResult matchAndRewrite(mlir::ReturnOp op, + mlir::PatternRewriter& rewriter) const + { + auto insert_ret = [&]() + { + auto ctx = op.getContext(); + auto ret_type = mlir::IntegerType::get(32, ctx); + auto ll_ret_type = mlir::LLVM::LLVMIntegerType::get(ctx, 32); + mlir::Value ret = rewriter.create(op.getLoc(), ll_ret_type, mlir::IntegerAttr::get(ret_type, 0)); + rewriter.replaceOpWithNewOp(op, ret); + }; + + if (op.getNumOperands() == 0) + { + rewriter.setInsertionPoint(op); + insert_ret(); + return mlir::success(); + } + else if (op.getNumOperands() == 1) + { + rewriter.setInsertionPoint(op); + auto addr = op.getParentRegion()->front().getArgument(0); + auto val = op.getOperand(0); + auto ll_ret_type = type_converter.convertType(val.getType()); + auto ll_val = rewriter.create(op.getLoc(), ll_ret_type, val); // TODO: hack to make verifier happy + rewriter.create(op.getLoc(), ll_val, addr); + insert_ret(); + return mlir::success(); + } + else + { + return mlir::failure(); + } + } + +private: + mlir::TypeConverter& type_converter; +}; + +struct LegalizeForNative : + public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + void runOnFunction() override; +}; + +void LegalizeForNative::runOnFunction() +{ + LLVMTypeHelper type_helper(getContext()); + auto func = getFunction(); + func.setType(legalize_func_sig(type_helper, func)); + + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext(), + type_helper.get_type_converter()); + + auto apply_conv = [&]() + { + return mlir::applyPartialConversion(getOperation(), target, patterns); + }; + + if (mlir::failed(apply_conv())) + { + signalPassFailure(); + return; + } +} } void populate_lower_to_llvm_pipeline(mlir::PassManager& pm) { + pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); } From d7f9d70f8bc20704cf60a4396b898bcb4c1932e1 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 20:03:53 +0300 Subject: [PATCH 048/259] work on llvm lowering --- mlir-compiler/src/compiler.cpp | 2 +- mlir-compiler/src/passes/lower_to_llvm.cpp | 124 +++++++++++++++++---- 2 files changed, 104 insertions(+), 22 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index f8a5cf1a287..250ea530588 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -16,7 +16,7 @@ class CompilerContext::CompilerContextImpl { public: CompilerContextImpl(mlir::MLIRContext& ctx): - pm(&ctx, /*verify*/false) + pm(&ctx, /*verify*/true) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createPlierToStdPass()); diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 2c7964f7aeb..02d7024163f 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -1,5 +1,6 @@ #include "passes/lower_to_llvm.hpp" +#include #include #include #include @@ -81,7 +82,7 @@ mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) return mlir::LLVM::LLVMStructType::getLiteral(&type_helper.get_context(), elems); } -mlir::FunctionType legalize_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) +mlir::FunctionType fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { auto old_type = func.getType(); assert(old_type.getNumResults() == 1); @@ -154,8 +155,24 @@ struct ReturnOpLowering : public mlir::OpRewritePattern mlir::TypeConverter& type_converter; }; -struct LegalizeForNative : - public mlir::PassWrapper +struct RemoveBitcasts : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite(mlir::LLVM::BitcastOp op, + mlir::PatternRewriter& rewriter) const + { + if (op.getType() == op.getOperand().getType()) + { + rewriter.replaceOp(op, op.getOperand()); + return mlir::success(); + } + return mlir::failure(); + } +}; + +template +struct LLVMLowererBase : public mlir::PassWrapper { virtual void getDependentDialects( mlir::DialectRegistry ®istry) const override @@ -164,37 +181,102 @@ struct LegalizeForNative : registry.insert(); } - void runOnFunction() override; + void runOnFunction() override final + { + LLVMTypeHelper type_helper(getContext()); + + mlir::OwningRewritePatternList patterns; + auto apply_conv = [&]() + { + return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + }; + static_cast(this)->run(type_helper, patterns, apply_conv); + } }; -void LegalizeForNative::runOnFunction() +class LLVMFunctionPass : public mlir::OperationPass { - LLVMTypeHelper type_helper(getContext()); - auto func = getFunction(); - func.setType(legalize_func_sig(type_helper, func)); +public: + using OperationPass::OperationPass; - mlir::ConversionTarget target(getContext()); - target.addLegalDialect(); + /// The polymorphic API that runs the pass over the currently held function. + virtual void runOnFunction() = 0; - mlir::OwningRewritePatternList patterns; - patterns.insert(&getContext(), - type_helper.get_type_converter()); + /// The polymorphic API that runs the pass over the currently held operation. + void runOnOperation() final { + if (!getFunction().isExternal()) + runOnFunction(); + } + + /// Return the current function being transformed. + mlir::LLVM::LLVMFuncOp getFunction() { return this->getOperation(); } +}; - auto apply_conv = [&]() +struct PreLLVMLowering : public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override { - return mlir::applyPartialConversion(getOperation(), target, patterns); - }; + registry.insert(); + registry.insert(); + } - if (mlir::failed(apply_conv())) + void runOnFunction() override final { - signalPassFailure(); - return; + LLVMTypeHelper type_helper(getContext()); + + mlir::OwningRewritePatternList patterns; + auto apply_conv = [&]() + { + return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + }; + + auto func = getFunction(); + func.setType(fix_func_sig(type_helper, func)); + + patterns.insert(&getContext(), + type_helper.get_type_converter()); + + if (mlir::failed(apply_conv())) + { + signalPassFailure(); + return; + } } -} +}; + +struct PostLLVMLowering : + public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnFunction() override final + { + mlir::OwningRewritePatternList patterns; + auto apply_conv = [&]() + { + return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + }; + + // Remove redundant bitcasts we have created on PreLowering + patterns.insert(&getContext()); + + if (mlir::failed(apply_conv())) + { + signalPassFailure(); + return; + } + } +}; } void populate_lower_to_llvm_pipeline(mlir::PassManager& pm) { - pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); + pm.addPass(std::make_unique()); } From b21bd13cfef5cd736518dd0de45b5bd4a94aebcc Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 20:29:17 +0300 Subject: [PATCH 049/259] fix --- mlir-compiler/src/passes/lower_to_llvm.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 02d7024163f..7f77f376f15 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -16,8 +16,6 @@ #include "utils.hpp" -#include - namespace { const mlir::LowerToLLVMOptions &getLLVMOptions() @@ -82,7 +80,7 @@ mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) return mlir::LLVM::LLVMStructType::getLiteral(&type_helper.get_context(), elems); } -mlir::FunctionType fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) +void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { auto old_type = func.getType(); assert(old_type.getNumResults() == 1); @@ -105,10 +103,10 @@ mlir::FunctionType fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) add_arg(ptr(old_type.getResult(0))); add_arg(ptr(ptr(getExceptInfoType(type_helper)))); - auto old_args = old_type.getResults(); + auto old_args = old_type.getInputs(); std::copy(old_args.begin(), old_args.end(), std::back_inserter(args)); auto ret_type = mlir::IntegerType::get(32, &ctx); - return mlir::FunctionType::get(args, ret_type, &ctx); + func.setType(mlir::FunctionType::get(args, ret_type, &ctx)); } struct ReturnOpLowering : public mlir::OpRewritePattern @@ -232,7 +230,7 @@ struct PreLLVMLowering : public mlir::PassWrapper(&getContext(), type_helper.get_type_converter()); From d2cb76776440cdfcb66b7cc67570bfe2bfaef3ac Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 21:11:13 +0300 Subject: [PATCH 050/259] gen llvm module --- mlir-compiler/src/lowering.cpp | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 4330631e5c2..0d54b5a129f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -10,9 +10,12 @@ #include #include #include - #include +#include + +#include + #include "plier/dialect.hpp" #include "compiler.hpp" @@ -22,15 +25,15 @@ namespace py = pybind11; namespace { -//std::string serialize_mod(const llvm::Module& mod) -//{ -// std::string ret; -// llvm::raw_string_ostream stream(ret); -//// mod.print(stream, nullptr); -// llvm::WriteBitcodeToFile(mod, stream); -// stream.flush(); -// return ret; -//} +std::string serialize_mod(const llvm::Module& mod) +{ + std::string ret; + llvm::raw_string_ostream stream(ret); +// mod.print(stream, nullptr); + llvm::WriteBitcodeToFile(mod, stream); + stream.flush(); + return ret; +} template std::string to_str(T& obj) @@ -558,5 +561,9 @@ py::bytes lower_function(const py::object& compilation_context, const py::object CompilerContext compiler(context); compiler.run(mod); mod.dump(); - return {}; + + llvm::LLVMContext ll_ctx; + auto ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); + ll_mod->dump(); + return serialize_mod(*ll_mod); } From 95e2032dab7cf024522dc3ea23122115f842d69c Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 21:41:03 +0300 Subject: [PATCH 051/259] execute code from mlir --- numba/core/lowering.py | 10 ++-------- numba/core/typed_passes.py | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 5cbb2a1583d..3bfaa383eb5 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -11,7 +11,7 @@ from numba.core.funcdesc import default_mangler from numba.core.environment import Environment -_use_mlir = False +_use_mlir = True _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) @@ -186,13 +186,7 @@ def lower_normal_function(self, fndesc): Lower non-generator *fndesc*. """ if _use_mlir: - ctx = {} - ctx['fndesc'] = lambda: fndesc - ctx['fntype'] = lambda: self.context.call_conv.get_function_type(fndesc.restype, fndesc.argtypes) - ctx['fnname'] = lambda: fndesc.mangled_name - ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) - import mlir_compiler - mod_ir = mlir_compiler.lower_normal_function(ctx, self.func_ir) + mod_ir = self.mlir_blob import llvmlite.binding as llvm mod = llvm.parse_bitcode(mod_ir) diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 053488bd3f2..f4ea2e275b8 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -367,6 +367,7 @@ def run_pass(self, state): with targetctx.push_code_library(library): lower = lowering.Lower(targetctx, library, fndesc, interp, metadata=metadata) + setattr(lower, 'mlir_blob', state.mlir_blob) lower.lower() if not flags.no_cpython_wrapper: lower.create_cpython_wrapper(flags.release_gil) @@ -473,14 +474,33 @@ def __init__(self): pass def run_pass(self, state): + targetctx = state.targetctx + library = state.library + interp = state.func_ir # why is it called this?! + typemap = state.typemap + restype = state.return_type + calltypes = state.calltypes + flags = state.flags + metadata = state.metadata + + msg = ("Function %s failed at nopython " + "mode lowering" % (state.func_id.func_name,)) + with fallback_context(state, msg): + # Lowering + fndesc = \ + funcdesc.PythonFunctionDescriptor.from_specialized_function( + interp, typemap, restype, calltypes, + mangler=targetctx.mangler, inline=flags.forceinline, + noalias=flags.noalias) + fn_name = fndesc.mangled_name + ctx = {} ctx['typemap'] = lambda op: state.typemap[op.name] - # ctx['fndesc'] = lambda: fndesc ctx['fnargs'] = lambda: state.args - ctx['fnname'] = lambda: state.func_ir.func_id.func_qualname - # ctx['get_var_type'] = lambda name: self.context.get_value_type(self.typeof(name)) + ctx['fnname'] = lambda: fn_name import mlir_compiler - mlir_compiler.lower_normal_function(ctx, state.func_ir) + mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) + setattr(state, 'mlir_blob', mod) return True From cc2c92f3477be94ec6c450a4989f23be00e7feda Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 8 Oct 2020 21:44:27 +0300 Subject: [PATCH 052/259] remove unused --- mlir-compiler/CMakeLists.txt | 2 -- mlir-compiler/src/type_parser.cpp | 24 ------------------------ mlir-compiler/src/type_parser.hpp | 5 ----- 3 files changed, 31 deletions(-) delete mode 100644 mlir-compiler/src/type_parser.cpp delete mode 100644 mlir-compiler/src/type_parser.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 9a88e163166..50e69387c84 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -26,7 +26,6 @@ set(SOURCES_LIST src/module.cpp src/passes/lower_to_llvm.cpp src/passes/plier_to_std.cpp - src/type_parser.cpp src/utils.cpp ) set(HEADERS_LIST @@ -34,7 +33,6 @@ set(HEADERS_LIST src/lowering.hpp src/passes/lower_to_llvm.hpp src/passes/plier_to_std.hpp - src/type_parser.hpp src/utils.hpp include/plier/dialect.hpp include/plier/PlierOps.td diff --git a/mlir-compiler/src/type_parser.cpp b/mlir-compiler/src/type_parser.cpp deleted file mode 100644 index d10d58692b6..00000000000 --- a/mlir-compiler/src/type_parser.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "type_parser.hpp" - -#include - -namespace -{ -[[noreturn]] void report_error(const llvm::Twine& msg) -{ - auto str = msg.str(); - throw std::exception(str.c_str()); -} -} - -mlir::LLVM::LLVMType parse_type(mlir::MLIRContext& context, llvm::StringRef str) -{ - assert(!str.empty()); - auto mlir_type = (std::string("!llvm<\"") + str + "\">").str(); - auto res = mlir::parseType(mlir_type, &context).dyn_cast_or_null(); - if (mlir::Type() == res) - { - report_error(llvm::Twine("cannot parse type: \"") + str + "\""); - } - return res; -} diff --git a/mlir-compiler/src/type_parser.hpp b/mlir-compiler/src/type_parser.hpp deleted file mode 100644 index 6d95aeef219..00000000000 --- a/mlir-compiler/src/type_parser.hpp +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once - -#include - -mlir::LLVM::LLVMType parse_type(mlir::MLIRContext& context, llvm::StringRef str); From a1eb71035bb4a6e76d0e28d6f5ba4347d230ce72 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 9 Oct 2020 20:19:49 +0300 Subject: [PATCH 053/259] refac --- mlir-compiler/src/passes/plier_to_std.cpp | 80 +++++++++++++---------- 1 file changed, 45 insertions(+), 35 deletions(-) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 07411708736..fd4fa216132 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -34,6 +34,16 @@ mlir::Type map_int_literal_type(plier::PyType type) return {}; } +mlir::Type map_bool_type(plier::PyType type) +{ + auto name = type.getName(); + if (name == "bool") + { + return mlir::IntegerType::get(1, type.getContext()); + } + return {}; +} + mlir::Type map_plier_type(mlir::Type type) { if (!type.isa()) @@ -45,6 +55,7 @@ mlir::Type map_plier_type(mlir::Type type) const func_t handlers[] = { &map_int_type, &map_int_literal_type, + &map_bool_type, }; for (auto h : handlers) { @@ -68,21 +79,31 @@ mlir::Type map_type(mlir::Type type) return mlir::Type() == new_type ? type : new_type; }; -void convertFuncArgs(mlir::FuncOp func) +bool convertFuncArgs(mlir::FuncOp func) { llvm::SmallVector new_arg_types; new_arg_types.reserve(func.getNumArguments()); + bool changed = false; for (auto arg_type : func.getArgumentTypes()) { - new_arg_types.push_back(map_type(arg_type)); + auto new_type = map_type(arg_type); + changed = changed || (new_type != arg_type); + new_arg_types.push_back(new_type); } - auto res_type = map_type(func.getType().getResult(0)); - auto func_type = mlir::FunctionType::get(new_arg_types, res_type, func.getContext()); - func.setType(func_type); - for (unsigned i = 0; i < func.getNumArguments(); ++i) + + auto res_type = func.getType().getResult(0); + auto new_res_type = map_type(res_type); + changed = changed || (res_type != new_res_type); + if (changed) { - func.front().getArgument(i).setType(new_arg_types[i]); + auto func_type = mlir::FunctionType::get(new_arg_types, new_res_type, func.getContext()); + func.setType(func_type); + for (unsigned i = 0; i < func.getNumArguments(); ++i) + { + func.front().getArgument(i).setType(new_arg_types[i]); + } } + return changed; } template @@ -172,34 +193,34 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; -struct FuncOpSignatureConversion : - public mlir::OpConversionPattern +struct FuncOpSignatureConversion : public mlir::OpRewritePattern { FuncOpSignatureConversion(mlir::MLIRContext* ctx, - mlir::TypeConverter& converter) - : OpConversionPattern(converter, ctx) {} + mlir::TypeConverter& /*converter*/) + : OpRewritePattern(ctx) {} /// Hook for derived classes to implement combined matching and rewriting. mlir::LogicalResult - matchAndRewrite(mlir::FuncOp funcOp, mlir::ArrayRef /*operands*/, - mlir::ConversionPatternRewriter &rewriter) const override + matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override { - convertFuncArgs(funcOp); - rewriter.updateRootInPlace(funcOp, [&] {}); // HACK - return mlir::success(); + bool changed = convertFuncArgs(funcOp); + if (changed) + { + rewriter.updateRootInPlace(funcOp, [&] {}); // HACK + } + return mlir::success(changed); } }; -struct OpTypeConversion : public mlir::ConversionPattern +struct OpTypeConversion : public mlir::RewritePattern { OpTypeConversion(mlir::MLIRContext* /*ctx*/, - mlir::TypeConverter& converter) - : ConversionPattern(0, converter, mlir::Pattern::MatchAnyOpTypeTag()) {} + mlir::TypeConverter& /*converter*/) + : RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()) {} /// Hook for derived classes to implement combined matching and rewriting. mlir::LogicalResult - matchAndRewrite(mlir::Operation* op, mlir::ArrayRef /*operands*/, - mlir::ConversionPatternRewriter &rewriter) const override + matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter &rewriter) const override { bool changed = false; llvm::SmallVector new_types; @@ -225,9 +246,8 @@ struct OpTypeConversion : public mlir::ConversionPattern op->getResult(i).setType(new_types[i]); } }); - return mlir::success(); } - return mlir::failure(); + return mlir::success(changed); } }; @@ -251,16 +271,14 @@ void PlierToStdPass::runOnOperation() return map_plier_type(type); }); - mlir::ConversionTarget target(getContext()); - target.addLegalDialect(); - mlir::OwningRewritePatternList patterns; patterns.insert(&getContext(), type_converter); + patterns.insert(&getContext()); auto apply_conv = [&]() { - return mlir::applyPartialConversion(getOperation(), target, patterns); + return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); }; if (mlir::failed(apply_conv())) @@ -268,14 +286,6 @@ void PlierToStdPass::runOnOperation() signalPassFailure(); return; } - - patterns.clear(); - patterns.insert(&getContext()); - if (mlir::failed(apply_conv())) - { - signalPassFailure(); - return; - } } } From 89ab7684bbd834aa8a85f0c4867985d68f0a9e69 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 00:12:43 +0300 Subject: [PATCH 054/259] bool cast lowering --- mlir-compiler/include/plier/PlierOps.td | 8 +- mlir-compiler/src/passes/plier_to_std.cpp | 129 +++++++++++++++++++++- 2 files changed, 131 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index d617339812b..570b376dbf3 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -17,7 +17,7 @@ def Plier_PyType : DialectType traits = []> : Op; -def ArgOp : Plier_Op<"arg", []> { +def ArgOp : Plier_Op<"arg", [NoSideEffect]> { let arguments = (ins UI32Attr:$index, StrAttr:$name); @@ -30,7 +30,7 @@ def ArgOp : Plier_Op<"arg", []> { ]; } -def ConstOp : Plier_Op<"const", []> { +def ConstOp : Plier_Op<"const", [NoSideEffect]> { let arguments = (ins AnyAttr:$val); @@ -41,7 +41,7 @@ def ConstOp : Plier_Op<"const", []> { ]; } -def GlobalOp : Plier_Op<"global", []> { +def GlobalOp : Plier_Op<"global", [NoSideEffect]> { let arguments = (ins StrAttr:$name); @@ -58,7 +58,7 @@ def BinOp : Plier_Op<"binop", []> { AnyType:$lhs, StrAttr:$op); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value rhs, ::mlir::Value lhs, StringRef op"> diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index fd4fa216132..42f9994fb29 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -5,6 +5,8 @@ #include #include +#include + #include "plier/dialect.hpp" namespace @@ -113,6 +115,15 @@ void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type rewriter.replaceOpWithNewOp(op, new_type, op->getOperands()); } +template +void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/) +{ + assert(nullptr != op); + auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(64, op->getContext()), Pred); + mlir::Type new_type = mlir::IntegerType::get(1, op->getContext()); + rewriter.replaceOpWithNewOp(op, new_type, pred_attr, op->getOperand(0), op->getOperand(1)); +} + bool is_int(mlir::Type type) { return type.isa(); @@ -163,7 +174,9 @@ struct BinOpLowering : public mlir::OpRewritePattern }; const OpDesc handlers[] = { - {"+", &replace_op, &replace_op} + {"+", &replace_op, &replace_op}, + {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, + &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, }; using membptr_t = func_t OpDesc::*; @@ -193,6 +206,117 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; +template +mlir::Value int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + auto src_bits = val.getType().cast().getWidth(); + auto dst_bits = dst_type.cast().getWidth(); + assert(src_bits != dst_bits); + if (dst_bits > src_bits) + { + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); + } + else + { + return rewriter.create(val.getLoc(), val, dst_type); + } +} + +mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + auto src_type = val.getType(); + if (src_type == dst_type) + { + return val; + } + + struct Handler + { + using selector_t = bool(*)(mlir::Type); + using cast_op_t = mlir::Value(*)(mlir::Type, mlir::Value, mlir::PatternRewriter&); + selector_t src; + selector_t dst; + cast_op_t cast_op; + }; + + const Handler handlers[] = { + {&is_int, &is_int, &int_cast}, + }; + + for (auto& h : handlers) + { + if (h.src(src_type) && h.dst(dst_type)) + { + return h.cast_op(dst_type, val, rewriter); + } + } + + return nullptr; +} + +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& rewriter) +{ + if (op.getNumOperands() != 2) + { + return mlir::failure(); + } + auto val = op.getOperand(1); + bool success = false; + auto replace_op = [&](mlir::Value val) + { + assert(!success); + if (val) + { + rewriter.replaceOp(op, val); + success = true; + } + }; + auto src_type = val.getType(); + auto dst_type = mlir::IntegerType::get(1, op.getContext()); + mlir::TypeSwitch(src_type) + .Case([&](auto) { replace_op(do_cast(dst_type, val, rewriter)); }); + return mlir::success(success); +} + +using call_lowerer_func_t = mlir::LogicalResult(*)(plier::PyCallOp, mlir::PatternRewriter&); +const constexpr std::pair builtin_calls[] = { + {"", &lower_bool_cast}, +}; + +struct CallOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const override + { + if (op.getNumOperands() == 0) + { + return mlir::failure(); + } + auto func_type = op.getOperand(0).getType(); + if (!func_type.isa()) + { + return mlir::failure(); + } + auto name = func_type.cast().getName(); + if (!name.consume_front("Function(") || !name.consume_back(")")) + { + return mlir::failure(); + } + for (auto& c : builtin_calls) + { + if (c.first == name) + { + return c.second(op, rewriter); + } + } + return mlir::failure(); + } +}; + + struct FuncOpSignatureConversion : public mlir::OpRewritePattern { FuncOpSignatureConversion(mlir::MLIRContext* ctx, @@ -274,7 +398,8 @@ void PlierToStdPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert(&getContext(), type_converter); - patterns.insert(&getContext()); + patterns.insert(&getContext()); auto apply_conv = [&]() { From 9eb99554fe6377fe0e3aed58262659c30da9912f Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 00:27:54 +0300 Subject: [PATCH 055/259] convert block args types --- mlir-compiler/src/passes/plier_to_std.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 42f9994fb29..1c9a5179d4f 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -81,7 +81,7 @@ mlir::Type map_type(mlir::Type type) return mlir::Type() == new_type ? type : new_type; }; -bool convertFuncArgs(mlir::FuncOp func) +bool convert_func_sig(mlir::FuncOp func) { llvm::SmallVector new_arg_types; new_arg_types.reserve(func.getNumArguments()); @@ -105,6 +105,20 @@ bool convertFuncArgs(mlir::FuncOp func) func.front().getArgument(i).setType(new_arg_types[i]); } } + for (auto& bb : llvm::make_range(++func.getBody().begin(), + func.getBody().end())) + { + for (auto arg : bb.getArguments()) + { + auto arg_type = arg.getType(); + auto new_type = map_type(arg_type); + if (new_type != arg_type) + { + arg.setType(new_type); + changed = true; + } + } + } return changed; } @@ -327,7 +341,7 @@ struct FuncOpSignatureConversion : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override { - bool changed = convertFuncArgs(funcOp); + bool changed = convert_func_sig(funcOp); if (changed) { rewriter.updateRootInPlace(funcOp, [&] {}); // HACK From eeef3a5a1e265b3a559089b770895ba1833d8edc Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 00:53:23 +0300 Subject: [PATCH 056/259] refac --- mlir-compiler/src/compiler.cpp | 24 ++++++++++++++++++------ mlir-compiler/src/compiler.hpp | 10 +++++++++- mlir-compiler/src/lowering.cpp | 11 +++++++---- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 250ea530588..af04454630e 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -15,16 +15,28 @@ class CompilerContext::CompilerContextImpl { public: - CompilerContextImpl(mlir::MLIRContext& ctx): - pm(&ctx, /*verify*/true) + CompilerContextImpl(mlir::MLIRContext& ctx, + const CompilerContext::Settings& settings): + pm(&ctx, settings.verify) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(createPlierToStdPass()); populate_lower_to_llvm_pipeline(pm); - pm.enableStatistics(); - pm.enableTiming(); + if (settings.pass_statistics) + { + pm.enableStatistics(); + } + if (settings.pass_timings) + { + pm.enableTiming(); + } + if (settings.ir_printing) + { + ctx.enableMultithreading(false); + pm.enableIRPrinting(); + } } void run(mlir::ModuleOp& module) @@ -38,8 +50,8 @@ class CompilerContext::CompilerContextImpl mlir::PassManager pm; }; -CompilerContext::CompilerContext(mlir::MLIRContext& ctx): - impl(std::make_unique(ctx)) +CompilerContext::CompilerContext(mlir::MLIRContext& ctx, const Settings& settings): + impl(std::make_unique(ctx, settings)) { } diff --git a/mlir-compiler/src/compiler.hpp b/mlir-compiler/src/compiler.hpp index 67f6a9591a6..d168abc8dc0 100644 --- a/mlir-compiler/src/compiler.hpp +++ b/mlir-compiler/src/compiler.hpp @@ -11,9 +11,17 @@ class ModuleOp; class CompilerContext { public: + struct Settings + { + bool verify = false; + bool pass_statistics = false; + bool pass_timings = false; + bool ir_printing = false; + }; + class CompilerContextImpl; - CompilerContext(mlir::MLIRContext& ctx); + CompilerContext(mlir::MLIRContext& ctx, const Settings& settings); ~CompilerContext(); CompilerContext(CompilerContext&&) = default; diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 0d54b5a129f..be5a2209ee2 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -557,13 +557,16 @@ py::bytes lower_function(const py::object& compilation_context, const py::object mlir::registerDialect(); mlir::MLIRContext context; auto mod = plier_lowerer(context).lower(compilation_context, func_ir); - mod.dump(); - CompilerContext compiler(context); + CompilerContext::Settings settings; + settings.verify = true; + settings.pass_statistics = false; + settings.pass_timings = false; + settings.ir_printing = false; + CompilerContext compiler(context, settings); compiler.run(mod); - mod.dump(); llvm::LLVMContext ll_ctx; auto ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); - ll_mod->dump(); +// ll_mod->dump(); return serialize_mod(*ll_mod); } From 0e85e18099064579f809be2aee91ef4eca6b6429 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:30:34 +0300 Subject: [PATCH 057/259] tests --- numba/mlir/__init__.py | 4 ++++ numba/mlir/tests/__init__.py | 10 ++++++++++ numba/mlir/tests/test_basic.py | 33 +++++++++++++++++++++++++++++++++ numba/tests/__init__.py | 3 +++ 4 files changed, 50 insertions(+) create mode 100644 numba/mlir/__init__.py create mode 100644 numba/mlir/tests/__init__.py create mode 100644 numba/mlir/tests/test_basic.py diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py new file mode 100644 index 00000000000..d4ef12a664e --- /dev/null +++ b/numba/mlir/__init__.py @@ -0,0 +1,4 @@ +from numba import runtests + +def test(*args, **kwargs): + return runtests.main("numba.mlir.tests", *args, **kwargs) diff --git a/numba/mlir/tests/__init__.py b/numba/mlir/tests/__init__.py new file mode 100644 index 00000000000..b2cb40fb5ce --- /dev/null +++ b/numba/mlir/tests/__init__.py @@ -0,0 +1,10 @@ +from numba.testing import unittest +from numba.testing import load_testsuite +from os.path import dirname + +def load_tests(loader, tests, pattern): + suite = unittest.TestSuite() + this_dir = dirname(__file__) + suite.addTests(load_testsuite(loader, this_dir)) + + return suite diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py new file mode 100644 index 00000000000..d452852a6cf --- /dev/null +++ b/numba/mlir/tests/test_basic.py @@ -0,0 +1,33 @@ +import numba +from numba import njit + +from numba.tests.support import TestCase +import unittest + +import itertools + +_test_values = [-3,-2,-1,0,1,2,3] +class TestMlirBasic(TestCase): + + def test_ret(self): + def py_func(a): + return a + + jit_func = njit(py_func) + for val in _test_values: + self.assertEqual(py_func(val), jit_func(val)) + + def test_ops(self): + py_funcs = [ + lambda a, b: a + b, + #lambda a, b: a - b, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + self.assertEqual(py_func(a, b), jit_func(a, b)) + + +if __name__ == '__main__': + unittest.main() diff --git a/numba/tests/__init__.py b/numba/tests/__init__.py index f04b1007dab..555bab11663 100644 --- a/numba/tests/__init__.py +++ b/numba/tests/__init__.py @@ -33,5 +33,8 @@ def load_tests(loader, tests, pattern): roc_dir = join(dirname(dirname(__file__)), 'roc/tests') suite.addTests(loader.discover(roc_dir)) + mlir_dir = join(dirname(dirname(__file__)), 'mlir/tests') + suite.addTests(loader.discover(mlir_dir)) + return suite From ca5d7f31eb03e5406118bec8695c05e4822aac07 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:30:42 +0300 Subject: [PATCH 058/259] refac --- mlir-compiler/src/lowering.cpp | 69 ++++++++++++---------------------- 1 file changed, 25 insertions(+), 44 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index be5a2209ee2..ca1c17a0bfb 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -81,10 +81,11 @@ struct inst_handles auto ops = py::module::import("operator"); - add = ops.attr("add"); - - eq = ops.attr("eq"); - gt = ops.attr("gt"); + for (auto elem : llvm::zip(ops_names, ops_handles)) + { + auto name = std::get<0>(elem).name; + std::get<1>(elem) = ops.attr(name.data()); + } } py::handle Assign; @@ -99,33 +100,22 @@ struct inst_handles py::handle Const; py::handle Global; - py::handle add; + struct OpId + { + llvm::StringRef op; + llvm::StringRef name; + }; - py::handle eq; - py::handle gt; -}; + static const constexpr OpId ops_names[] = { + {"+", "add"}, + {"-", "sub"}, -//struct type_cache -//{ -// using Type = mllvm::LLVMType; - -// Type get_type(mlir::MLIRContext& context, llvm::StringRef str) -// { -// assert(!str.empty()); -// auto s = str.str(); -// auto it = typemap.find(s); -// if (typemap.end() != it) -// { -// return it->second; -// } -// auto type = parse_type(context, str); -// typemap[s] = type; -// return type; -// } - -//private: -// std::unordered_map typemap; -//}; + {"==", "eq"}, + {">", "gt"}, + }; + + std::array ops_handles; +}; struct plier_lowerer { @@ -392,22 +382,13 @@ struct plier_lowerer mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) { - // TODO unhardcode - if (op.is(insts.add)) - { - return builder.create(builder.getUnknownLoc(), lhs, rhs, "+"); - } -// if (op.is(insts.eq)) -// { -// assert(lhs.getType() == rhs.getType()); -// if (lhs.getType().cast().isIntegerTy()) -// { -// return builder.create(builder.getUnknownLoc(), mllvm::ICmpPredicate::eq, lhs, rhs); -// } -// } - if (op.is(insts.gt)) + for (auto elem : llvm::zip(insts.ops_names, insts.ops_handles)) { - return builder.create(builder.getUnknownLoc(), lhs, rhs, ">"); + if (op.is(std::get<1>(elem))) + { + auto op_name = std::get<0>(elem).op; + return builder.create(builder.getUnknownLoc(), lhs, rhs, op_name); + } } report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); From 2f22ee2c4abebad05dcf18f4e627a2d5688215eb Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:40:23 +0300 Subject: [PATCH 059/259] ops tests --- mlir-compiler/src/lowering.cpp | 4 +--- mlir-compiler/src/passes/plier_to_std.cpp | 3 +++ numba/mlir/tests/test_basic.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index ca1c17a0bfb..630b23ffa3f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -109,6 +109,7 @@ struct inst_handles static const constexpr OpId ops_names[] = { {"+", "add"}, {"-", "sub"}, + {"*", "mul"}, {"==", "eq"}, {">", "gt"}, @@ -348,9 +349,6 @@ struct plier_lowerer auto args = expr.attr("args").cast(); auto kws = expr.attr("kws").cast(); auto vararg = expr.attr("vararg"); -// std::cout << py::str(args).cast() << std::endl; -// std::cout << py::str(kws).cast() << std::endl; -// std::cout << py::str(vararg).cast() << std::endl; mlir::SmallVector args_list; mlir::SmallVector, 8> kwargs_list; diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 1c9a5179d4f..1944483c317 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -189,6 +189,9 @@ struct BinOpLowering : public mlir::OpRewritePattern const OpDesc handlers[] = { {"+", &replace_op, &replace_op}, + {"-", &replace_op, &replace_op}, + {"*", &replace_op, &replace_op}, + {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, }; diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index d452852a6cf..4bf8dd423a7 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -20,7 +20,8 @@ def py_func(a): def test_ops(self): py_funcs = [ lambda a, b: a + b, - #lambda a, b: a - b, + lambda a, b: a - b, + lambda a, b: a * b, ] for py_func in py_funcs: From ce72419edaaeaf75676678ed2708754dfaa9c20a Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:47:03 +0300 Subject: [PATCH 060/259] cmp ops --- mlir-compiler/src/lowering.cpp | 6 +++++- mlir-compiler/src/passes/plier_to_std.cpp | 10 ++++++++++ numba/mlir/tests/test_basic.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 630b23ffa3f..54cdf0aa6ee 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -111,8 +111,12 @@ struct inst_handles {"-", "sub"}, {"*", "mul"}, - {"==", "eq"}, {">", "gt"}, + {">=", "ge"}, + {"<", "lt"}, + {"<=", "le"}, + {"!=", "ne"}, + {"==", "eq"}, }; std::array ops_handles; diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 1944483c317..ca9e1cdf9ec 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -194,6 +194,16 @@ struct BinOpLowering : public mlir::OpRewritePattern {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, + {">=", &replace_cmp_op(mlir::CmpIPredicate::sge)>, + &replace_cmp_op(mlir::CmpFPredicate::OGE)>}, + {"<", &replace_cmp_op(mlir::CmpIPredicate::slt)>, + &replace_cmp_op(mlir::CmpFPredicate::OLT)>}, + {"<=", &replace_cmp_op(mlir::CmpIPredicate::sle)>, + &replace_cmp_op(mlir::CmpFPredicate::OLE)>}, + {"!=", &replace_cmp_op(mlir::CmpIPredicate::ne)>, + &replace_cmp_op(mlir::CmpFPredicate::ONE)>}, + {"==", &replace_cmp_op(mlir::CmpIPredicate::eq)>, + &replace_cmp_op(mlir::CmpFPredicate::OEQ)>}, }; using membptr_t = func_t OpDesc::*; diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 4bf8dd423a7..2cb2cafa50f 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -22,6 +22,22 @@ def test_ops(self): lambda a, b: a + b, lambda a, b: a - b, lambda a, b: a * b, + # TODO: div + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + self.assertEqual(py_func(a, b), jit_func(a, b)) + + def test_cmp_ops(self): + py_funcs = [ + lambda a, b: a if a > b else b, + lambda a, b: a if a < b else b, + lambda a, b: a if a >= b else b, + lambda a, b: a if a <= b else b, + lambda a, b: a if a == b else b, + lambda a, b: a if a != b else b, ] for py_func in py_funcs: From c256448b847c50ef2dafcc38ac753979878362c7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:49:14 +0300 Subject: [PATCH 061/259] const ops tests --- numba/mlir/tests/test_basic.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 2cb2cafa50f..a705bab721e 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -45,6 +45,17 @@ def test_cmp_ops(self): for a, b in itertools.product(_test_values, _test_values): self.assertEqual(py_func(a, b), jit_func(a, b)) + def test_const_ops(self): + py_funcs = [ + lambda a: a + 42, + lambda a: 43 + a, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for val in _test_values: + self.assertEqual(py_func(val), jit_func(val)) + if __name__ == '__main__': unittest.main() From 7e07104e34a4abc2bdf42bb3e110ebe42974989e Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 19:53:50 +0300 Subject: [PATCH 062/259] more tests --- numba/mlir/tests/test_basic.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index a705bab721e..3d22c4c2c2d 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -56,6 +56,28 @@ def test_const_ops(self): for val in _test_values: self.assertEqual(py_func(val), jit_func(val)) + def test_var(self): + def py_func(a): + c = 1 + c = c + a + return c + + jit_func = njit(py_func) + for val in _test_values: + self.assertEqual(py_func(val), jit_func(val)) + + def test_jump(self): + def py_func(a, b): + c = 3 + if a > 5: + c = c + a + c = c + b + return c + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + self.assertEqual(py_func(a, b), jit_func(a, b)) + if __name__ == '__main__': unittest.main() From 8196a1052469fb1125aafe0814c627e8d8668eb0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 10 Oct 2020 21:07:06 +0300 Subject: [PATCH 063/259] some error handling --- mlir-compiler/src/compiler.cpp | 25 ++++++++++++++++++++++--- mlir-compiler/src/lowering.cpp | 33 ++++++++++++++++++++++++++++----- mlir-compiler/src/utils.hpp | 16 ++++++++++++++++ 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index af04454630e..efa7209f4d1 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -7,6 +7,11 @@ #include #include +#include + +#include +#include + #include "utils.hpp" #include "passes/plier_to_std.hpp" @@ -41,10 +46,24 @@ class CompilerContext::CompilerContextImpl void run(mlir::ModuleOp& module) { - if (mlir::failed(pm.run(module))) + std::string err; + llvm::raw_string_ostream err_stream(err); + auto diag_handler = [&](mlir::Diagnostic& diag) { - report_error("Compiler pipeline failed"); - } + if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) + { + err_stream << diag; + } + }; + + scoped_diag_handler(*pm.getContext(), diag_handler, [&]() + { + if (mlir::failed(pm.run(module))) + { + err_stream.flush(); + report_error(llvm::Twine("MLIR pipeline failed\n") + err); + } + }); } private: mlir::PassManager pm; diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 54cdf0aa6ee..bae7b0f63e8 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -532,6 +532,33 @@ struct plier_lowerer } }; + +py::bytes gen_ll_module(mlir::ModuleOp mod) +{ + std::string err; + llvm::raw_string_ostream err_stream(err); + auto diag_handler = [&](mlir::Diagnostic& diag) + { + if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) + { + err_stream << diag; + } + }; + llvm::LLVMContext ll_ctx; + std::unique_ptr ll_mod; + scoped_diag_handler(*mod.getContext(), diag_handler, [&]() + { + ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); + if (nullptr == ll_mod) + { + err_stream.flush(); + report_error(llvm::Twine("Cannot generate LLVM module\n") + err); + } + }); + assert(nullptr != ll_mod); +// ll_mod->dump(); + return serialize_mod(*ll_mod); +} } py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) @@ -547,9 +574,5 @@ py::bytes lower_function(const py::object& compilation_context, const py::object settings.ir_printing = false; CompilerContext compiler(context, settings); compiler.run(mod); - - llvm::LLVMContext ll_ctx; - auto ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); -// ll_mod->dump(); - return serialize_mod(*ll_mod); + return gen_ll_module(mod); } diff --git a/mlir-compiler/src/utils.hpp b/mlir-compiler/src/utils.hpp index adb98e2457a..1e026c45137 100644 --- a/mlir-compiler/src/utils.hpp +++ b/mlir-compiler/src/utils.hpp @@ -1,8 +1,24 @@ #pragma once +#include + +#include + namespace llvm { class Twine; } [[noreturn]] void report_error(const llvm::Twine& msg); + +template +void scoped_diag_handler(T& ctx, H&& diag_handler, F&& func) +{ + auto& diag_engine = ctx.getDiagEngine(); + auto diag_id = diag_engine.registerHandler(std::forward(diag_handler)); + auto diag_guard = llvm::make_scope_exit([&]() + { + diag_engine.eraseHandler(diag_id); + }); + func(); +} From 6221c08987458e98d3c43b8bf6f277eae623acec Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 01:17:53 +0300 Subject: [PATCH 064/259] more checks --- mlir-compiler/src/passes/lower_to_llvm.cpp | 40 +++++++++++++--------- mlir-compiler/src/passes/plier_to_std.cpp | 7 ++-- 2 files changed, 27 insertions(+), 20 deletions(-) diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 7f77f376f15..cf0a5f985b9 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -14,6 +14,8 @@ #include #include +#include "plier/dialect.hpp" + #include "utils.hpp" namespace @@ -52,7 +54,9 @@ struct LLVMTypeHelper mlir::LLVM::LLVMType ptr(mlir::Type type) { + assert(static_cast(type)); auto ll_type = type_converter.convertType(type).cast(); + assert(static_cast(ll_type)); return mlir::LLVM::LLVMPointerType::get(ll_type); } @@ -138,6 +142,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern auto addr = op.getParentRegion()->front().getArgument(0); auto val = op.getOperand(0); auto ll_ret_type = type_converter.convertType(val.getType()); + assert(static_cast(ll_ret_type)); auto ll_val = rewriter.create(op.getLoc(), ll_ret_type, val); // TODO: hack to make verifier happy rewriter.create(op.getLoc(), ll_val, addr); insert_ret(); @@ -169,26 +174,26 @@ struct RemoveBitcasts : public mlir::OpRewritePattern } }; -template -struct LLVMLowererBase : public mlir::PassWrapper +class CheckForPlierTypes : + public mlir::PassWrapper> { - virtual void getDependentDialects( - mlir::DialectRegistry ®istry) const override + void runOnOperation() override { - registry.insert(); - registry.insert(); - } - - void runOnFunction() override final - { - LLVMTypeHelper type_helper(getContext()); - - mlir::OwningRewritePatternList patterns; - auto apply_conv = [&]() + markAllAnalysesPreserved(); + getOperation()->walk([&](mlir::Operation* op) { - return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - }; - static_cast(this)->run(type_helper, patterns, apply_conv); + auto check_type = [](mlir::Type type) + { + return type.isa(); + }; + + if (llvm::any_of(op->getResultTypes(), check_type) || + llvm::any_of(op->getOperandTypes(), check_type)) + { + op->emitOpError(": not all plier types were translated\n"); + signalPassFailure(); + } + }); } }; @@ -274,6 +279,7 @@ struct PostLLVMLowering : void populate_lower_to_llvm_pipeline(mlir::PassManager& pm) { + pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); pm.addPass(std::make_unique()); diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index ca9e1cdf9ec..80b86a69745 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -48,9 +48,10 @@ mlir::Type map_bool_type(plier::PyType type) mlir::Type map_plier_type(mlir::Type type) { + assert(static_cast(type)); if (!type.isa()) { - return {}; + return nullptr; } auto ptype = type.cast(); using func_t = mlir::Type(*)(plier::PyType); @@ -67,7 +68,7 @@ mlir::Type map_plier_type(mlir::Type type) return t; } } - return {}; + return nullptr; } bool is_supported_type(mlir::Type type) @@ -78,7 +79,7 @@ bool is_supported_type(mlir::Type type) mlir::Type map_type(mlir::Type type) { auto new_type = is_supported_type(type) ? type : map_plier_type(type); - return mlir::Type() == new_type ? type : new_type; + return static_cast(new_type) ? new_type : type; }; bool convert_func_sig(mlir::FuncOp func) From 3c48f831e056357d66d0b949e545b66623bb1f84 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 01:23:50 +0300 Subject: [PATCH 065/259] read settings from python --- mlir-compiler/src/lowering.cpp | 17 +++++++++++------ numba/core/typed_passes.py | 1 + 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index bae7b0f63e8..4bb88c25a4d 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -533,6 +533,16 @@ struct plier_lowerer }; +CompilerContext::Settings get_settings(const py::handle& settings) +{ + CompilerContext::Settings ret; + ret.verify = settings["verify"].cast(); + ret.pass_statistics = settings["pass_statistics"].cast(); + ret.pass_timings = settings["pass_timings"].cast(); + ret.ir_printing = settings["ir_printing"].cast(); + return ret; +} + py::bytes gen_ll_module(mlir::ModuleOp mod) { std::string err; @@ -567,12 +577,7 @@ py::bytes lower_function(const py::object& compilation_context, const py::object mlir::registerDialect(); mlir::MLIRContext context; auto mod = plier_lowerer(context).lower(compilation_context, func_ir); - CompilerContext::Settings settings; - settings.verify = true; - settings.pass_statistics = false; - settings.pass_timings = false; - settings.ir_printing = false; - CompilerContext compiler(context, settings); + CompilerContext compiler(context, get_settings(compilation_context["compiler_settings"])); compiler.run(mod); return gen_ll_module(mod); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index f4ea2e275b8..6414f8d1449 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -495,6 +495,7 @@ def run_pass(self, state): fn_name = fndesc.mangled_name ctx = {} + ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': False} ctx['typemap'] = lambda op: state.typemap[op.name] ctx['fnargs'] = lambda: state.args ctx['fnname'] = lambda: fn_name From 3f3be019b1f6bd49505ece770e7ba0b2840e0ff9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 01:39:19 +0300 Subject: [PATCH 066/259] more checks --- mlir-compiler/src/passes/lower_to_llvm.cpp | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index cf0a5f985b9..62d8c259283 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -182,6 +182,13 @@ class CheckForPlierTypes : markAllAnalysesPreserved(); getOperation()->walk([&](mlir::Operation* op) { + if (op->getName().getDialect() == plier::PlierDialect::getDialectNamespace()) + { + op->emitOpError(": not all plier ops were translated\n"); + signalPassFailure(); + return; + } + auto check_type = [](mlir::Type type) { return type.isa(); From 08043cfae6609a021bd4284274a21424cd6c7aa3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 02:05:32 +0300 Subject: [PATCH 067/259] some float support --- mlir-compiler/src/passes/plier_to_std.cpp | 44 +++++++++++++++++++---- 1 file changed, 37 insertions(+), 7 deletions(-) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 80b86a69745..300deb927ec 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -20,7 +20,7 @@ mlir::Type map_int_type(plier::PyType type) { return mlir::IntegerType::get(num_bits, type.getContext()); } - return {}; + return nullptr; } mlir::Type map_int_literal_type(plier::PyType type) @@ -33,7 +33,7 @@ mlir::Type map_int_literal_type(plier::PyType type) { return mlir::IntegerType::get(64, type.getContext()); // TODO } - return {}; + return nullptr; } mlir::Type map_bool_type(plier::PyType type) @@ -43,7 +43,25 @@ mlir::Type map_bool_type(plier::PyType type) { return mlir::IntegerType::get(1, type.getContext()); } - return {}; + return nullptr; +} + +mlir::Type map_float_type(plier::PyType type) +{ + auto name = type.getName(); + unsigned num_bits = 0; + if (name.consume_front("float") && + !name.consumeInteger(10, num_bits) && name.empty()) + { + auto ctx = type.getContext(); + switch(num_bits) + { + case 64: return mlir::Float64Type::get(ctx); + case 32: return mlir::Float32Type::get(ctx); + case 16: return mlir::Float16Type::get(ctx); + } + } + return nullptr; } mlir::Type map_plier_type(mlir::Type type) @@ -59,6 +77,7 @@ mlir::Type map_plier_type(mlir::Type type) &map_int_type, &map_int_literal_type, &map_bool_type, + &map_float_type }; for (auto h : handlers) { @@ -175,10 +194,21 @@ struct BinOpLowering : public mlir::OpRewritePattern assert(op.getNumOperands() == 2); auto type0 = op.getOperand(0).getType(); auto type1 = op.getOperand(1).getType(); - if (type0 != type1 || !is_supported_type(type0) || !is_supported_type(type1)) + if (!is_supported_type(type0) || !is_supported_type(type1)) { return mlir::failure(); } + mlir::Type final_type; + if (type0 != type1) + { + // TODO: coerce + return mlir::failure(); + } + else + { + final_type = type0; + } + assert(static_cast(final_type)); using func_t = void(*)(mlir::Operation*, mlir::PatternRewriter&, mlir::Type); struct OpDesc @@ -214,7 +244,7 @@ struct BinOpLowering : public mlir::OpRewritePattern { if (h.type == op.op()) { - (h.*mem)(op, rewriter, type0); + (h.*mem)(op, rewriter, final_type); return mlir::success(); } } @@ -222,11 +252,11 @@ struct BinOpLowering : public mlir::OpRewritePattern }; - if (is_int(type0)) + if (is_int(final_type)) { return call_handler(&OpDesc::iop); } - else if (is_float(type0)) + else if (is_float(final_type)) { return call_handler(&OpDesc::fop); } From c7034658a177a3339641469830cb19d105122c41 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 04:07:32 +0300 Subject: [PATCH 068/259] float handling --- mlir-compiler/include/plier/PlierOps.td | 6 +- mlir-compiler/src/dialect.cpp | 6 - mlir-compiler/src/lowering.cpp | 29 ++-- mlir-compiler/src/passes/plier_to_std.cpp | 190 +++++++++++++++------- numba/core/typed_passes.py | 1 + numba/mlir/tests/test_basic.py | 16 +- 6 files changed, 158 insertions(+), 90 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 570b376dbf3..594bcef90df 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -67,14 +67,10 @@ def BinOp : Plier_Op<"binop", []> { def CastOp : Plier_Op<"cast", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); let results = (outs AnyType); let hasFolder = 1; - - let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> - ]; } def PyCallOp : Plier_Op<"call", []> { diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 08adc2501c5..faf9352385c 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -108,12 +108,6 @@ void BinOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, rhs, op); } -void CastOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value val) { - CastOp::build(builder, state, PyType::getUndefined(state.getContext()), - val); -} - mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) { auto op_type = getOperand().getType(); diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 4bb88c25a4d..8ff6bf0fe8c 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -137,7 +137,7 @@ struct plier_lowerer auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); typemap = compilation_context["typemap"]; auto name = compilation_context["fnname"]().cast(); - auto typ = get_func_type(compilation_context["fnargs"]); + auto typ = get_func_type(compilation_context["fnargs"], compilation_context["restype"]); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); lower_func_body(func_ir); mod.push_back(func); @@ -276,7 +276,7 @@ struct plier_lowerer using func_t = mlir::Value (plier_lowerer::*)(const py::handle&); const std::pair handlers[] = { {"binop", &plier_lowerer::lower_binop}, - {"cast", &plier_lowerer::lower_simple}, + {"cast", &plier_lowerer::lower_cast}, {"call", &plier_lowerer::lower_call}, {"phi", &plier_lowerer::lower_phi}, {"build_tuple", &plier_lowerer::lower_build_tuple}, @@ -303,6 +303,13 @@ struct plier_lowerer return builder.create(builder.getUnknownLoc(), value); } + mlir::Value lower_cast(const py::handle& inst) + { + auto value = loadvar(inst.attr("value")); + auto res_type = get_type(current_instr.attr("target")); + return builder.create(builder.getUnknownLoc(), res_type, value); + } + mlir::Value lower_static_getitem(const py::handle& inst) { auto value = loadvar(inst.attr("value")); @@ -418,20 +425,14 @@ struct plier_lowerer void retvar(const py::handle& inst) { auto var = loadvar(inst); - builder.create(builder.getUnknownLoc(), var); auto func_type = func.getType(); auto ret_type = func_type.getResult(0); - auto new_ret_type = var.getType(); - if (ret_type != new_ret_type) + auto var_type = var.getType(); + if (ret_type != var_type) { - auto def_type = plier::PyType::getUndefined(&ctx); - if (ret_type != def_type) - { - report_error(llvm::Twine("Conflicting return types: ") + to_str(ret_type) + " and " + to_str(new_ret_type)); - } - auto new_func_type = mlir::FunctionType::get(func_type.getInputs(), new_ret_type, &ctx); - func.setType(new_func_type); + var = builder.create(builder.getUnknownLoc(), ret_type, var); } + builder.create(builder.getUnknownLoc(), var); } void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) @@ -458,9 +459,9 @@ struct plier_lowerer report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); } - mlir::FunctionType get_func_type(const py::handle& fnargs) + mlir::FunctionType get_func_type(const py::handle& fnargs, const py::handle& restype) { - auto ret = plier::PyType::getUndefined(&ctx); + auto ret = get_obj_type(restype()); llvm::SmallVector args; for (auto arg : fnargs()) { diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 300deb927ec..47830378883 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -143,19 +143,19 @@ bool convert_func_sig(mlir::FuncOp func) } template -void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type) +void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) { assert(nullptr != op); - rewriter.replaceOpWithNewOp(op, new_type, op->getOperands()); + rewriter.replaceOpWithNewOp(op, new_type, operands); } template -void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/) +void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) { assert(nullptr != op); auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(64, op->getContext()), Pred); mlir::Type new_type = mlir::IntegerType::get(1, op->getContext()); - rewriter.replaceOpWithNewOp(op, new_type, pred_attr, op->getOperand(0), op->getOperand(1)); + rewriter.replaceOpWithNewOp(op, new_type, pred_attr, operands[0], operands[1]); } bool is_int(mlir::Type type) @@ -184,6 +184,108 @@ struct ConstOpLowering : public mlir::OpRewritePattern } }; +mlir::Type coerce(mlir::Type type0, mlir::Type type1) +{ + // TODO: proper rules + assert(type0 != type1); + auto get_bits_count = [](mlir::Type type)->unsigned + { + if (type.isa()) + { + return type.cast().getWidth(); + } + if (type.isa()) + { + return 11; + } + if (type.isa()) + { + return 24; + } + if (type.isa()) + { + return 53; + } + llvm_unreachable("Unhandled type"); + }; + auto f0 = is_float(type0); + auto f1 = is_float(type1); + if (f0 && !f1) + { + return type0; + } + if (!f0 && f1) + { + return type1; + } + return get_bits_count(type0) < get_bits_count(type1) ? type1 : type0; +} + +template +mlir::Value int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + auto src_bits = val.getType().cast().getWidth(); + auto dst_bits = dst_type.cast().getWidth(); + assert(src_bits != dst_bits); + if (dst_bits > src_bits) + { + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); + } + else + { + return rewriter.create(val.getLoc(), val, dst_type); + } +} + +template +mlir::Value int_float_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); +} + +template +mlir::Value float_int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + using T = std::conditional_t; + return rewriter.create(val.getLoc(), val, dst_type); +} + +mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + auto src_type = val.getType(); + if (src_type == dst_type) + { + return val; + } + + struct Handler + { + using selector_t = bool(*)(mlir::Type); + using cast_op_t = mlir::Value(*)(mlir::Type, mlir::Value, mlir::PatternRewriter&); + selector_t src; + selector_t dst; + cast_op_t cast_op; + }; + + const Handler handlers[] = { + {&is_int, &is_int, &int_cast}, + {&is_int, &is_float, &int_float_cast}, + {&is_float, &is_int, &float_int_cast}, + }; + + for (auto& h : handlers) + { + if (h.src(src_type) && h.dst(dst_type)) + { + return h.cast_op(dst_type, val, rewriter); + } + } + + llvm_unreachable("Unhandled cast"); +} + struct BinOpLowering : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -199,18 +301,21 @@ struct BinOpLowering : public mlir::OpRewritePattern return mlir::failure(); } mlir::Type final_type; + std::array operands; if (type0 != type1) { - // TODO: coerce - return mlir::failure(); + final_type = coerce(type0, type1); + operands = {do_cast(final_type, op.getOperand(0), rewriter), + do_cast(final_type, op.getOperand(1), rewriter)}; } else { final_type = type0; + operands = {op.getOperand(0), op.getOperand(1)}; } assert(static_cast(final_type)); - using func_t = void(*)(mlir::Operation*, mlir::PatternRewriter&, mlir::Type); + using func_t = void(*)(mlir::Operation*, mlir::PatternRewriter&, mlir::Type, mlir::ValueRange); struct OpDesc { llvm::StringRef type; @@ -244,7 +349,7 @@ struct BinOpLowering : public mlir::OpRewritePattern { if (h.type == op.op()) { - (h.*mem)(op, rewriter, final_type); + (h.*mem)(op, rewriter, final_type, operands); return mlir::success(); } } @@ -264,55 +369,6 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; -template -mlir::Value int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) -{ - auto src_bits = val.getType().cast().getWidth(); - auto dst_bits = dst_type.cast().getWidth(); - assert(src_bits != dst_bits); - if (dst_bits > src_bits) - { - using T = std::conditional_t; - return rewriter.create(val.getLoc(), val, dst_type); - } - else - { - return rewriter.create(val.getLoc(), val, dst_type); - } -} - -mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) -{ - auto src_type = val.getType(); - if (src_type == dst_type) - { - return val; - } - - struct Handler - { - using selector_t = bool(*)(mlir::Type); - using cast_op_t = mlir::Value(*)(mlir::Type, mlir::Value, mlir::PatternRewriter&); - selector_t src; - selector_t dst; - cast_op_t cast_op; - }; - - const Handler handlers[] = { - {&is_int, &is_int, &int_cast}, - }; - - for (auto& h : handlers) - { - if (h.src(src_type) && h.dst(dst_type)) - { - return h.cast_op(dst_type, val, rewriter); - } - } - - return nullptr; -} - mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& rewriter) { if (op.getNumOperands() != 2) @@ -374,6 +430,24 @@ struct CallOpLowering : public mlir::OpRewritePattern } }; +struct CastOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(plier::CastOp op, mlir::PatternRewriter& rewriter) const override + { + auto src_type = op.getOperand().getType(); + auto dst_type = op.getType(); + if (is_supported_type(src_type) && is_supported_type(dst_type)) + { + auto new_op = do_cast(dst_type, op.getOperand(), rewriter); + rewriter.replaceOp(op, new_op); + return mlir::success(); + } + return mlir::failure(); + } +}; struct FuncOpSignatureConversion : public mlir::OpRewritePattern { @@ -457,7 +531,7 @@ void PlierToStdPass::runOnOperation() patterns.insert(&getContext(), type_converter); patterns.insert(&getContext()); + CallOpLowering, CastOpLowering>(&getContext()); auto apply_conv = [&]() { diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 6414f8d1449..7edddf33ca5 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -498,6 +498,7 @@ def run_pass(self, state): ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': False} ctx['typemap'] = lambda op: state.typemap[op.name] ctx['fnargs'] = lambda: state.args + ctx['restype'] = lambda: state.return_type ctx['fnname'] = lambda: fn_name import mlir_compiler mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 3d22c4c2c2d..e3a3e5f5833 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -1,12 +1,14 @@ import numba from numba import njit +from math import nan, inf +from numpy.testing import assert_equal # for nans comparison from numba.tests.support import TestCase import unittest import itertools -_test_values = [-3,-2,-1,0,1,2,3] +_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5 -0.0, 0.0, 0.5, 1.0, 2.5, -inf, inf] # nans class TestMlirBasic(TestCase): def test_ret(self): @@ -15,7 +17,7 @@ def py_func(a): jit_func = njit(py_func) for val in _test_values: - self.assertEqual(py_func(val), jit_func(val)) + assert_equal(py_func(val), jit_func(val)) def test_ops(self): py_funcs = [ @@ -28,7 +30,7 @@ def test_ops(self): for py_func in py_funcs: jit_func = njit(py_func) for a, b in itertools.product(_test_values, _test_values): - self.assertEqual(py_func(a, b), jit_func(a, b)) + assert_equal(py_func(a, b), jit_func(a, b)) def test_cmp_ops(self): py_funcs = [ @@ -43,7 +45,7 @@ def test_cmp_ops(self): for py_func in py_funcs: jit_func = njit(py_func) for a, b in itertools.product(_test_values, _test_values): - self.assertEqual(py_func(a, b), jit_func(a, b)) + assert_equal(py_func(a, b), jit_func(a, b)) def test_const_ops(self): py_funcs = [ @@ -54,7 +56,7 @@ def test_const_ops(self): for py_func in py_funcs: jit_func = njit(py_func) for val in _test_values: - self.assertEqual(py_func(val), jit_func(val)) + assert_equal(py_func(val), jit_func(val)) def test_var(self): def py_func(a): @@ -64,7 +66,7 @@ def py_func(a): jit_func = njit(py_func) for val in _test_values: - self.assertEqual(py_func(val), jit_func(val)) + assert_equal(py_func(val), jit_func(val)) def test_jump(self): def py_func(a, b): @@ -76,7 +78,7 @@ def py_func(a, b): jit_func = njit(py_func) for a, b in itertools.product(_test_values, _test_values): - self.assertEqual(py_func(a, b), jit_func(a, b)) + assert_equal(py_func(a, b), jit_func(a, b)) if __name__ == '__main__': From 4253134f999c0fb9e30790dfd949903d35bdeff9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 13:19:23 +0300 Subject: [PATCH 069/259] refac --- mlir-compiler/src/passes/lower_to_llvm.cpp | 3 ++- mlir-compiler/src/passes/lower_to_llvm.hpp | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 62d8c259283..b0d01223a54 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -282,9 +282,10 @@ struct PostLLVMLowering : } } }; + } -void populate_lower_to_llvm_pipeline(mlir::PassManager& pm) +void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); diff --git a/mlir-compiler/src/passes/lower_to_llvm.hpp b/mlir-compiler/src/passes/lower_to_llvm.hpp index 15423505a84..3a2c9aec992 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.hpp +++ b/mlir-compiler/src/passes/lower_to_llvm.hpp @@ -2,7 +2,7 @@ namespace mlir { -class PassManager; +class OpPassManager; } -void populate_lower_to_llvm_pipeline(mlir::PassManager& pm); +void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm); From 7b4d4245aa628f90c22d4239273dc6f8d80cfa8d Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 13:40:10 +0300 Subject: [PATCH 070/259] pass registry stub --- mlir-compiler/CMakeLists.txt | 2 ++ mlir-compiler/src/passes/pass_registry.cpp | 26 ++++++++++++++++ mlir-compiler/src/passes/pass_registry.hpp | 36 ++++++++++++++++++++++ 3 files changed, 64 insertions(+) create mode 100644 mlir-compiler/src/passes/pass_registry.cpp create mode 100644 mlir-compiler/src/passes/pass_registry.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 50e69387c84..3170a86636c 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -25,6 +25,7 @@ set(SOURCES_LIST src/lowering.cpp src/module.cpp src/passes/lower_to_llvm.cpp + src/passes/pass_registry.cpp src/passes/plier_to_std.cpp src/utils.cpp ) @@ -32,6 +33,7 @@ set(HEADERS_LIST src/compiler.hpp src/lowering.hpp src/passes/lower_to_llvm.hpp + src/passes/pass_registry.hpp src/passes/plier_to_std.hpp src/utils.hpp include/plier/dialect.hpp diff --git a/mlir-compiler/src/passes/pass_registry.cpp b/mlir-compiler/src/passes/pass_registry.cpp new file mode 100644 index 00000000000..2ba1debcdc6 --- /dev/null +++ b/mlir-compiler/src/passes/pass_registry.cpp @@ -0,0 +1,26 @@ +#include "pass_registry.hpp" + + +void pass_registry::register_pipeline(pass_registry::registry_entry_t func) +{ + assert(nullptr != func); + pipelines.push_back(std::move(func)); +} + +void pass_registry::populate_pass_manager(mlir::OpPassManager& pm) const +{ + // TODO: build proper dep graph + auto sink = [&](llvm::StringRef /*pipeline_name*/, + llvm::ArrayRef /*prev_pipelines*/, + llvm::ArrayRef /*next_pipelines*/, + pipeline_funt_t func) + { + assert(nullptr != func); + func(pm); + }; + for (auto& p : pipelines) + { + assert(nullptr != p); + p(sink); + } +} diff --git a/mlir-compiler/src/passes/pass_registry.hpp b/mlir-compiler/src/passes/pass_registry.hpp new file mode 100644 index 00000000000..99f35d2549c --- /dev/null +++ b/mlir-compiler/src/passes/pass_registry.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace mlir +{ +class OpPassManager; +} + + +class pass_registry +{ +public: + pass_registry() = default; + pass_registry(const pass_registry&) = delete; + + using pipeline_funt_t = void(*)(mlir::OpPassManager&); + using registry_entry_sink_t = void( + llvm::StringRef pipeline_name, + llvm::ArrayRef prev_pipelines, + llvm::ArrayRef next_pipelines, + pipeline_funt_t func); + using registry_entry_t = std::function)>; + + void register_pipeline(registry_entry_t func); + + void populate_pass_manager(mlir::OpPassManager& pm) const; + +private: + std::vector pipelines; +}; From bf46a772724513ca0a12c8ed1e7393293ff12b79 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 13:50:12 +0300 Subject: [PATCH 071/259] refac --- mlir-compiler/src/compiler.cpp | 7 +------ mlir-compiler/src/passes/plier_to_std.cpp | 7 +++++-- mlir-compiler/src/passes/plier_to_std.hpp | 6 ++---- numba/mlir/tests/test_basic.py | 2 +- 4 files changed, 9 insertions(+), 13 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index efa7209f4d1..b6b983f0777 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -4,12 +4,9 @@ #include #include #include -#include -#include #include -#include #include #include "utils.hpp" @@ -24,9 +21,7 @@ class CompilerContext::CompilerContextImpl const CompilerContext::Settings& settings): pm(&ctx, settings.verify) { - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(createPlierToStdPass()); - + populate_plier_to_std_pipeline(pm); populate_lower_to_llvm_pipeline(pm); if (settings.pass_statistics) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 47830378883..6fe9c69ddc9 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -3,7 +3,9 @@ #include #include #include +#include #include +#include #include @@ -547,7 +549,8 @@ void PlierToStdPass::runOnOperation() } -std::unique_ptr createPlierToStdPass() +void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) { - return std::make_unique(); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(std::make_unique()); } diff --git a/mlir-compiler/src/passes/plier_to_std.hpp b/mlir-compiler/src/passes/plier_to_std.hpp index b43c608fb81..5adcb2ebbe2 100644 --- a/mlir-compiler/src/passes/plier_to_std.hpp +++ b/mlir-compiler/src/passes/plier_to_std.hpp @@ -1,10 +1,8 @@ #pragma once -#include - namespace mlir { -class Pass; +class OpPassManager; } -std::unique_ptr createPlierToStdPass(); +void populate_plier_to_std_pipeline(mlir::OpPassManager& pm); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index e3a3e5f5833..b9f80e4d194 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -8,7 +8,7 @@ import itertools -_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5 -0.0, 0.0, 0.5, 1.0, 2.5, -inf, inf] # nans +_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5 -0.0, 0.0, 0.5, 1.0, 2.5, -inf, inf] # TODO: nans class TestMlirBasic(TestCase): def test_ret(self): From 9dbaef99b9d3ee27b9bcca1992e5ce2813d2b570 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 15:18:00 +0300 Subject: [PATCH 072/259] work on pipeline deps --- mlir-compiler/src/passes/pass_registry.cpp | 139 ++++++++++++++++++++- 1 file changed, 134 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/passes/pass_registry.cpp b/mlir-compiler/src/passes/pass_registry.cpp index 2ba1debcdc6..4a6e2896542 100644 --- a/mlir-compiler/src/passes/pass_registry.cpp +++ b/mlir-compiler/src/passes/pass_registry.cpp @@ -1,5 +1,14 @@ #include "pass_registry.hpp" +#include +#include +#include + +#include "utils.hpp" + +#include +#include +#include void pass_registry::register_pipeline(pass_registry::registry_entry_t func) { @@ -7,20 +16,140 @@ void pass_registry::register_pipeline(pass_registry::registry_entry_t func) pipelines.push_back(std::move(func)); } +namespace +{ +template +void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) +{ + if (elem.visited) + { + return; + } + elem.visited = true; + iter_func(elem, [&](T& next) + { + topo_visit(next, std::forward(iter_func), std::forward(func)); + }); + func(elem); +} +} + void pass_registry::populate_pass_manager(mlir::OpPassManager& pm) const { - // TODO: build proper dep graph - auto sink = [&](llvm::StringRef /*pipeline_name*/, - llvm::ArrayRef /*prev_pipelines*/, - llvm::ArrayRef /*next_pipelines*/, + llvm::BumpPtrAllocator allocator; + llvm::UniqueStringSaver string_set(allocator); + + using name_id = const void*; + auto get_id = [](llvm::StringRef name)->name_id + { + assert(!name.empty()); + return name.data(); + }; + std::set pipelines_ordered; // sorted map to make order consistent + + auto get_pipeline = [&](llvm::StringRef name)->llvm::StringRef + { + if (name.empty()) + { + report_error("Empty pipeline name"); + } + auto str = string_set.save(name); + pipelines_ordered.insert(str); + return str; + }; + + struct IdSet : protected llvm::SmallVector + { + using Base = llvm::SmallVector; + using Base::begin; + using Base::end; + void push_back(name_id id) + { + auto it = std::equal_range(begin(), end(), id); + if (it.first == it.second) + { + insert(it.first, id); + } + } + }; + + struct PipelineInfo + { + llvm::StringRef name; + llvm::SmallVector prev_pipelines; + llvm::SmallVector next_pipelines; + pipeline_funt_t func = nullptr; + bool visited = false; + }; + + std::unordered_map pipelines_map; + + auto sink = [&](llvm::StringRef pipeline_name, + llvm::ArrayRef prev_pipelines, + llvm::ArrayRef next_pipelines, pipeline_funt_t func) { assert(nullptr != func); - func(pm); + auto i = get_pipeline(pipeline_name); + auto it = pipelines_map.insert({get_id(i), {}}); + if (!it.second) + { + report_error("Duplicated pipeline name"); + } + auto& info = it.first->second; + info.name = i; + info.func = func; + llvm::transform(prev_pipelines, std::back_inserter(info.prev_pipelines), get_pipeline); + llvm::transform(next_pipelines, std::back_inserter(info.next_pipelines), get_pipeline); }; + for (auto& p : pipelines) { assert(nullptr != p); p(sink); } + + auto get_pipeline_info = [&](llvm::StringRef name)->PipelineInfo& + { + auto id = get_id(name); + auto it = pipelines_map.find(id); + if (it == pipelines_map.end()) + { + report_error(llvm::Twine("Pipeline not found") + name); + } + return it->second; + }; + + // Make all deps bidirectional + for (auto name : pipelines_ordered) + { + auto& info = get_pipeline_info(name); + for (auto prev : info.prev_pipelines) + { + auto& prev_info = get_pipeline_info(prev); + prev_info.next_pipelines.push_back(name); + } + for (auto next : info.next_pipelines) + { + auto& next_info = get_pipeline_info(next); + next_info.prev_pipelines.push_back(name); + } + } + + // toposort + for (auto name : pipelines_ordered) + { + auto iter_func = [&](const PipelineInfo& elem, auto func) + { + for (auto prev : elem.prev_pipelines) + { + func(get_pipeline_info(prev)); + } + }; + auto visit_func = [&](const PipelineInfo& elem) + { + elem.func(pm); + }; + topo_visit(get_pipeline_info(name), iter_func, visit_func); + } } From a6448802945e345dabbb2191cf876ae93a0d4085 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 15:31:14 +0300 Subject: [PATCH 073/259] check for graph cycles --- mlir-compiler/src/passes/pass_registry.cpp | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/passes/pass_registry.cpp b/mlir-compiler/src/passes/pass_registry.cpp index 4a6e2896542..5a3c3621ae4 100644 --- a/mlir-compiler/src/passes/pass_registry.cpp +++ b/mlir-compiler/src/passes/pass_registry.cpp @@ -79,6 +79,7 @@ void pass_registry::populate_pass_manager(mlir::OpPassManager& pm) const llvm::SmallVector prev_pipelines; llvm::SmallVector next_pipelines; pipeline_funt_t func = nullptr; + PipelineInfo* next = nullptr; bool visited = false; }; @@ -137,19 +138,33 @@ void pass_registry::populate_pass_manager(mlir::OpPassManager& pm) const } // toposort + PipelineInfo* first_pipeline = nullptr; for (auto name : pipelines_ordered) { auto iter_func = [&](const PipelineInfo& elem, auto func) { for (auto prev : elem.prev_pipelines) { + if (get_id(prev) == get_id(name)) + { + report_error(llvm::Twine("Pipeline depends on itself: ") + name); + } func(get_pipeline_info(prev)); } }; - auto visit_func = [&](const PipelineInfo& elem) + auto visit_func = [&](PipelineInfo& elem) { - elem.func(pm); + assert(nullptr == elem.next); + elem.next = first_pipeline; + first_pipeline = &elem; }; topo_visit(get_pipeline_info(name), iter_func, visit_func); } + + for (auto current = first_pipeline; nullptr != first_pipeline; + first_pipeline = first_pipeline->next) + { + assert(nullptr != current); + current->func(pm); + } } From 7490713273fbcc4d9e4b7a8173d288824becddb3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 11 Oct 2020 15:34:55 +0300 Subject: [PATCH 074/259] refac --- mlir-compiler/CMakeLists.txt | 4 ++-- .../src/{passes/pass_registry.cpp => pipeline_registry.cpp} | 6 +++--- .../src/{passes/pass_registry.hpp => pipeline_registry.hpp} | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) rename mlir-compiler/src/{passes/pass_registry.cpp => pipeline_registry.cpp} (96%) rename mlir-compiler/src/{passes/pass_registry.hpp => pipeline_registry.hpp} (86%) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 3170a86636c..e090b9b70e9 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -25,16 +25,16 @@ set(SOURCES_LIST src/lowering.cpp src/module.cpp src/passes/lower_to_llvm.cpp - src/passes/pass_registry.cpp src/passes/plier_to_std.cpp + src/pipeline_registry.cpp src/utils.cpp ) set(HEADERS_LIST src/compiler.hpp src/lowering.hpp src/passes/lower_to_llvm.hpp - src/passes/pass_registry.hpp src/passes/plier_to_std.hpp + src/pipeline_registry.hpp src/utils.hpp include/plier/dialect.hpp include/plier/PlierOps.td diff --git a/mlir-compiler/src/passes/pass_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp similarity index 96% rename from mlir-compiler/src/passes/pass_registry.cpp rename to mlir-compiler/src/pipeline_registry.cpp index 5a3c3621ae4..88b83d31e00 100644 --- a/mlir-compiler/src/passes/pass_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -1,4 +1,4 @@ -#include "pass_registry.hpp" +#include "pipeline_registry.hpp" #include #include @@ -10,7 +10,7 @@ #include #include -void pass_registry::register_pipeline(pass_registry::registry_entry_t func) +void PipelineRegistry::register_pipeline(PipelineRegistry::registry_entry_t func) { assert(nullptr != func); pipelines.push_back(std::move(func)); @@ -34,7 +34,7 @@ void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) } } -void pass_registry::populate_pass_manager(mlir::OpPassManager& pm) const +void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const { llvm::BumpPtrAllocator allocator; llvm::UniqueStringSaver string_set(allocator); diff --git a/mlir-compiler/src/passes/pass_registry.hpp b/mlir-compiler/src/pipeline_registry.hpp similarity index 86% rename from mlir-compiler/src/passes/pass_registry.hpp rename to mlir-compiler/src/pipeline_registry.hpp index 99f35d2549c..3b348be8533 100644 --- a/mlir-compiler/src/passes/pass_registry.hpp +++ b/mlir-compiler/src/pipeline_registry.hpp @@ -13,11 +13,11 @@ class OpPassManager; } -class pass_registry +class PipelineRegistry { public: - pass_registry() = default; - pass_registry(const pass_registry&) = delete; + PipelineRegistry() = default; + PipelineRegistry(const PipelineRegistry&) = delete; using pipeline_funt_t = void(*)(mlir::OpPassManager&); using registry_entry_sink_t = void( From d4cd61752a3376f7421423100e3eddd5925205a5 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 12 Oct 2020 14:28:52 +0300 Subject: [PATCH 075/259] fix --- mlir-compiler/src/pipeline_registry.cpp | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index 88b83d31e00..a5bc42504c3 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -26,10 +26,12 @@ void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) return; } elem.visited = true; + elem.iterating = true; iter_func(elem, [&](T& next) { topo_visit(next, std::forward(iter_func), std::forward(func)); }); + elem.iterating = false; func(elem); } } @@ -81,6 +83,7 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const pipeline_funt_t func = nullptr; PipelineInfo* next = nullptr; bool visited = false; + bool iterating = false; }; std::unordered_map pipelines_map; @@ -143,12 +146,12 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const { auto iter_func = [&](const PipelineInfo& elem, auto func) { + if (elem.iterating) + { + report_error(llvm::Twine("Pipeline depends on itself: ") + elem.name); + } for (auto prev : elem.prev_pipelines) { - if (get_id(prev) == get_id(name)) - { - report_error(llvm::Twine("Pipeline depends on itself: ") + name); - } func(get_pipeline_info(prev)); } }; From a20dafcc4a99b942eb59ea63d5e233bdb019dd97 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 12 Oct 2020 14:28:52 +0300 Subject: [PATCH 076/259] refac --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/compiler.cpp | 15 ++++--- mlir-compiler/src/compiler.hpp | 4 +- mlir-compiler/src/lowering.cpp | 17 +++++++- mlir-compiler/src/passes/base_pipeline.cpp | 20 +++++++++ mlir-compiler/src/passes/base_pipeline.hpp | 5 +++ mlir-compiler/src/passes/lower_to_llvm.cpp | 14 +++++- mlir-compiler/src/passes/lower_to_llvm.hpp | 7 +-- mlir-compiler/src/passes/plier_to_std.cpp | 14 +++++- mlir-compiler/src/passes/plier_to_std.hpp | 7 +-- mlir-compiler/src/pipeline_registry.cpp | 51 +++++++++++++--------- 11 files changed, 113 insertions(+), 43 deletions(-) create mode 100644 mlir-compiler/src/passes/base_pipeline.cpp create mode 100644 mlir-compiler/src/passes/base_pipeline.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index e090b9b70e9..a527e30f14a 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,6 +24,7 @@ set(SOURCES_LIST src/dialect.cpp src/lowering.cpp src/module.cpp + src/passes/base_pipeline.cpp src/passes/lower_to_llvm.cpp src/passes/plier_to_std.cpp src/pipeline_registry.cpp @@ -32,6 +33,7 @@ set(SOURCES_LIST set(HEADERS_LIST src/compiler.hpp src/lowering.hpp + src/passes/base_pipeline.hpp src/passes/lower_to_llvm.hpp src/passes/plier_to_std.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index b6b983f0777..1aad5a5d9ee 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -11,18 +11,17 @@ #include "utils.hpp" -#include "passes/plier_to_std.hpp" -#include "passes/lower_to_llvm.hpp" +#include "pipeline_registry.hpp" class CompilerContext::CompilerContextImpl { public: CompilerContextImpl(mlir::MLIRContext& ctx, - const CompilerContext::Settings& settings): + const CompilerContext::Settings& settings, + const PipelineRegistry& registry): pm(&ctx, settings.verify) { - populate_plier_to_std_pipeline(pm); - populate_lower_to_llvm_pipeline(pm); + registry.populate_pass_manager(pm); if (settings.pass_statistics) { @@ -64,8 +63,10 @@ class CompilerContext::CompilerContextImpl mlir::PassManager pm; }; -CompilerContext::CompilerContext(mlir::MLIRContext& ctx, const Settings& settings): - impl(std::make_unique(ctx, settings)) +CompilerContext::CompilerContext(mlir::MLIRContext& ctx, + const Settings& settings, + const PipelineRegistry& registry): + impl(std::make_unique(ctx, settings, registry)) { } diff --git a/mlir-compiler/src/compiler.hpp b/mlir-compiler/src/compiler.hpp index d168abc8dc0..f08538c77be 100644 --- a/mlir-compiler/src/compiler.hpp +++ b/mlir-compiler/src/compiler.hpp @@ -8,6 +8,7 @@ class MLIRContext; class ModuleOp; } +class PipelineRegistry; class CompilerContext { public: @@ -21,7 +22,8 @@ class CompilerContext class CompilerContextImpl; - CompilerContext(mlir::MLIRContext& ctx, const Settings& settings); + CompilerContext(mlir::MLIRContext& ctx, const Settings& settings, + const PipelineRegistry& registry); ~CompilerContext(); CompilerContext(CompilerContext&&) = default; diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 8ff6bf0fe8c..826b57c3301 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -19,8 +19,13 @@ #include "plier/dialect.hpp" #include "compiler.hpp" +#include "pipeline_registry.hpp" #include "utils.hpp" +#include "passes/base_pipeline.hpp" +#include "passes/plier_to_std.hpp" +#include "passes/lower_to_llvm.hpp" + namespace py = pybind11; namespace { @@ -570,6 +575,13 @@ py::bytes gen_ll_module(mlir::ModuleOp mod) // ll_mod->dump(); return serialize_mod(*ll_mod); } + +void create_pipeline(PipelineRegistry& registry) +{ + register_base_pipeline(registry); + register_lower_to_llvm_pipeline(registry); + register_plier_to_std_pipeline(registry); +} } py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) @@ -578,7 +590,10 @@ py::bytes lower_function(const py::object& compilation_context, const py::object mlir::registerDialect(); mlir::MLIRContext context; auto mod = plier_lowerer(context).lower(compilation_context, func_ir); - CompilerContext compiler(context, get_settings(compilation_context["compiler_settings"])); + PipelineRegistry registry; + create_pipeline(registry); + auto settings = get_settings(compilation_context["compiler_settings"]); + CompilerContext compiler(context, settings, registry); compiler.run(mod); return gen_ll_module(mod); } diff --git a/mlir-compiler/src/passes/base_pipeline.cpp b/mlir-compiler/src/passes/base_pipeline.cpp new file mode 100644 index 00000000000..bef2f592859 --- /dev/null +++ b/mlir-compiler/src/passes/base_pipeline.cpp @@ -0,0 +1,20 @@ +#include "passes/base_pipeline.hpp" + +#include "pipeline_registry.hpp" + +void register_base_pipeline(PipelineRegistry& registry) +{ + auto dummu_func = [](mlir::OpPassManager&){}; + registry.register_pipeline([&](auto sink) + { + sink("init", {}, {}, dummu_func); + }); + registry.register_pipeline([&](auto sink) + { + sink("lowering", {"init"}, {}, dummu_func); + }); + registry.register_pipeline([&](auto sink) + { + sink("terminate", {"lowering"}, {}, dummu_func); + }); +} diff --git a/mlir-compiler/src/passes/base_pipeline.hpp b/mlir-compiler/src/passes/base_pipeline.hpp new file mode 100644 index 00000000000..cc4e9fdc8fa --- /dev/null +++ b/mlir-compiler/src/passes/base_pipeline.hpp @@ -0,0 +1,5 @@ +#pragma once + +class PipelineRegistry; + +void register_base_pipeline(PipelineRegistry& registry);; diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index b0d01223a54..880db4c7769 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -16,6 +16,8 @@ #include "plier/dialect.hpp" +#include "pipeline_registry.hpp" + #include "utils.hpp" namespace @@ -283,8 +285,6 @@ struct PostLLVMLowering : } }; -} - void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); @@ -292,3 +292,13 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); pm.addPass(std::make_unique()); } +} + + +void register_lower_to_llvm_pipeline(PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + sink("lower_to_llvm", {"lowering"}, {"terminate"}, &populate_lower_to_llvm_pipeline); + }); +} diff --git a/mlir-compiler/src/passes/lower_to_llvm.hpp b/mlir-compiler/src/passes/lower_to_llvm.hpp index 3a2c9aec992..61c6947c37b 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.hpp +++ b/mlir-compiler/src/passes/lower_to_llvm.hpp @@ -1,8 +1,5 @@ #pragma once -namespace mlir -{ -class OpPassManager; -} +class PipelineRegistry; -void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm); +void register_lower_to_llvm_pipeline(PipelineRegistry& registry); diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 6fe9c69ddc9..1dde3492e75 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -11,6 +11,8 @@ #include "plier/dialect.hpp" +#include "pipeline_registry.hpp" + namespace { mlir::Type map_int_type(plier::PyType type) @@ -547,10 +549,18 @@ void PlierToStdPass::runOnOperation() } } -} - void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(std::make_unique()); } +} + + +void register_plier_to_std_pipeline(PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + sink("plier_to_std", {"init"}, {"lowering"}, &populate_plier_to_std_pipeline); + }); +} diff --git a/mlir-compiler/src/passes/plier_to_std.hpp b/mlir-compiler/src/passes/plier_to_std.hpp index 5adcb2ebbe2..b3fd97784b0 100644 --- a/mlir-compiler/src/passes/plier_to_std.hpp +++ b/mlir-compiler/src/passes/plier_to_std.hpp @@ -1,8 +1,5 @@ #pragma once -namespace mlir -{ -class OpPassManager; -} +class PipelineRegistry; -void populate_plier_to_std_pipeline(mlir::OpPassManager& pm); +void register_plier_to_std_pipeline(PipelineRegistry& registry); diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index a5bc42504c3..9717c0f5ea5 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -26,12 +26,10 @@ void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) return; } elem.visited = true; - elem.iterating = true; iter_func(elem, [&](T& next) { topo_visit(next, std::forward(iter_func), std::forward(func)); }); - elem.iterating = false; func(elem); } } @@ -47,7 +45,7 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const assert(!name.empty()); return name.data(); }; - std::set pipelines_ordered; // sorted map to make order consistent + std::set pipelines_ordered; // sorted set to make order consistent auto get_pipeline = [&](llvm::StringRef name)->llvm::StringRef { @@ -60,12 +58,13 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const return str; }; - struct IdSet : protected llvm::SmallVector + struct PipelineSet : protected llvm::SmallVector { - using Base = llvm::SmallVector; + using Base = llvm::SmallVector; using Base::begin; using Base::end; - void push_back(name_id id) + using Base::value_type; + void push_back(llvm::StringRef id) { auto it = std::equal_range(begin(), end(), id); if (it.first == it.second) @@ -78,8 +77,8 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const struct PipelineInfo { llvm::StringRef name; - llvm::SmallVector prev_pipelines; - llvm::SmallVector next_pipelines; + PipelineSet prev_pipelines; + PipelineSet next_pipelines; pipeline_funt_t func = nullptr; PipelineInfo* next = nullptr; bool visited = false; @@ -142,32 +141,44 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const // toposort PipelineInfo* first_pipeline = nullptr; + PipelineInfo* current_pipeline = nullptr; for (auto name : pipelines_ordered) { - auto iter_func = [&](const PipelineInfo& elem, auto func) + auto iter_func = [&](PipelineInfo& elem, auto func) { - if (elem.iterating) + elem.iterating = true; + for (auto it : elem.prev_pipelines) { - report_error(llvm::Twine("Pipeline depends on itself: ") + elem.name); - } - for (auto prev : elem.prev_pipelines) - { - func(get_pipeline_info(prev)); + auto& info = get_pipeline_info(it); + if (info.iterating) + { + report_error(llvm::Twine("Pipeline depends on itself: ") + elem.name); + } + func(info); } + elem.iterating = false; }; auto visit_func = [&](PipelineInfo& elem) { assert(nullptr == elem.next); - elem.next = first_pipeline; - first_pipeline = &elem; + auto current = &elem; + if (nullptr == first_pipeline) + { + first_pipeline = current; + } + else + { + assert(nullptr != current_pipeline); + current_pipeline->next = current; + } + current_pipeline = current; }; topo_visit(get_pipeline_info(name), iter_func, visit_func); } - for (auto current = first_pipeline; nullptr != first_pipeline; - first_pipeline = first_pipeline->next) + for (auto current = first_pipeline; nullptr != current; + current = current->next) { - assert(nullptr != current); current->func(pm); } } From ab798bdcff66deb59ddc1bd6717a45cbcec7f04c Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 12 Oct 2020 14:28:52 +0300 Subject: [PATCH 077/259] refac --- mlir-compiler/src/passes/base_pipeline.cpp | 46 ++++++++++++++++------ mlir-compiler/src/passes/base_pipeline.hpp | 13 +++++- mlir-compiler/src/passes/lower_to_llvm.cpp | 4 +- mlir-compiler/src/passes/plier_to_std.cpp | 4 +- mlir-compiler/src/pipeline_registry.cpp | 1 + 5 files changed, 53 insertions(+), 15 deletions(-) diff --git a/mlir-compiler/src/passes/base_pipeline.cpp b/mlir-compiler/src/passes/base_pipeline.cpp index bef2f592859..ee9746df39a 100644 --- a/mlir-compiler/src/passes/base_pipeline.cpp +++ b/mlir-compiler/src/passes/base_pipeline.cpp @@ -2,19 +2,41 @@ #include "pipeline_registry.hpp" +namespace +{ +const constexpr llvm::StringRef passes[] ={ + "init", + "lowering", + "terminate", +}; + +void dummy_pass_func(mlir::OpPassManager&) {} +} + void register_base_pipeline(PipelineRegistry& registry) { - auto dummu_func = [](mlir::OpPassManager&){}; - registry.register_pipeline([&](auto sink) - { - sink("init", {}, {}, dummu_func); - }); - registry.register_pipeline([&](auto sink) - { - sink("lowering", {"init"}, {}, dummu_func); - }); - registry.register_pipeline([&](auto sink) + for (std::size_t i = 0; i < llvm::array_lengthof(passes); ++i) { - sink("terminate", {"lowering"}, {}, dummu_func); - }); + registry.register_pipeline([i](auto sink) + { + if (0 == i) + { + sink(passes[i], {}, {}, dummy_pass_func); + } + else + { + sink(passes[i], {passes[i - 1]}, {}, dummy_pass_func); + } + }); + } +} + +PipelineStage get_high_lowering_stage() +{ + return {passes[0], passes[1]}; +} + +PipelineStage get_lower_lowering_stage() +{ + return {passes[1], passes[2]}; } diff --git a/mlir-compiler/src/passes/base_pipeline.hpp b/mlir-compiler/src/passes/base_pipeline.hpp index cc4e9fdc8fa..65f9747802c 100644 --- a/mlir-compiler/src/passes/base_pipeline.hpp +++ b/mlir-compiler/src/passes/base_pipeline.hpp @@ -1,5 +1,16 @@ #pragma once +#include + class PipelineRegistry; -void register_base_pipeline(PipelineRegistry& registry);; +void register_base_pipeline(PipelineRegistry& registry); + +struct PipelineStage +{ + llvm::StringRef begin; + llvm::StringRef end; +}; + +PipelineStage get_high_lowering_stage(); // TODO: better name +PipelineStage get_lower_lowering_stage(); // TODO: better name diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index 880db4c7769..d8bdfc27190 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -16,6 +16,7 @@ #include "plier/dialect.hpp" +#include "base_pipeline.hpp" #include "pipeline_registry.hpp" #include "utils.hpp" @@ -299,6 +300,7 @@ void register_lower_to_llvm_pipeline(PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { - sink("lower_to_llvm", {"lowering"}, {"terminate"}, &populate_lower_to_llvm_pipeline); + auto stage = get_lower_lowering_stage(); + sink("lower_to_llvm", {stage.begin}, {stage.end}, &populate_lower_to_llvm_pipeline); }); } diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 1dde3492e75..c1748363fa2 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -11,6 +11,7 @@ #include "plier/dialect.hpp" +#include "base_pipeline.hpp" #include "pipeline_registry.hpp" namespace @@ -561,6 +562,7 @@ void register_plier_to_std_pipeline(PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { - sink("plier_to_std", {"init"}, {"lowering"}, &populate_plier_to_std_pipeline); + auto stage = get_high_lowering_stage(); + sink("plier_to_std", {stage.begin}, {stage.end}, &populate_plier_to_std_pipeline); }); } diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index 9717c0f5ea5..023a1ef7078 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -92,6 +92,7 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const llvm::ArrayRef next_pipelines, pipeline_funt_t func) { + assert(!pipeline_name.empty()); assert(nullptr != func); auto i = get_pipeline(pipeline_name); auto it = pipelines_map.insert({get_id(i), {}}); From 02350c3fe45a2d60e318f39b1cac934c9387d745 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 14 Oct 2020 15:14:36 +0300 Subject: [PATCH 078/259] some refac --- mlir-compiler/include/plier/PlierOps.td | 30 +++++----- mlir-compiler/src/passes/lower_to_llvm.cpp | 16 ++--- mlir-compiler/src/passes/plier_to_std.cpp | 68 ++++++++++------------ 3 files changed, 51 insertions(+), 63 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 594bcef90df..7460a174df7 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -75,12 +75,12 @@ def CastOp : Plier_Op<"cast", []> { def PyCallOp : Plier_Op<"call", []> { let arguments = (ins - Plier_PyType:$func, + AnyType:$func, Variadic:$args, UI32Attr:$kw_start, ArrayAttr:$kw_names); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value func, " @@ -93,7 +93,7 @@ def BuildTupleOp : Plier_Op<"build_tuple", []> { let arguments = (ins Variadic:$args); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::ValueRange args"> @@ -102,11 +102,11 @@ def BuildTupleOp : Plier_Op<"build_tuple", []> { def StaticGetItemOp : Plier_Op<"static_getitem", []> { let arguments = (ins - Plier_PyType:$value, - Plier_PyType:$index_var, + AnyType:$value, + AnyType:$index_var, UI32Attr:$index); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, " @@ -116,9 +116,9 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { def GetiterOp : Plier_Op<"getiter", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -127,9 +127,9 @@ def GetiterOp : Plier_Op<"getiter", []> { def IternextOp : Plier_Op<"iternext", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -138,9 +138,9 @@ def IternextOp : Plier_Op<"iternext", []> { def PairfirstOp : Plier_Op<"pair_first", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -149,9 +149,9 @@ def PairfirstOp : Plier_Op<"pair_first", []> { def PairsecondOp : Plier_Op<"pair_second", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -160,7 +160,7 @@ def PairsecondOp : Plier_Op<"pair_second", []> { def DelOp : Plier_Op<"del", []> { let arguments = (ins - Plier_PyType:$value); + AnyType:$value); // let builders = [ // OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index d8bdfc27190..e2c693193ce 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -241,7 +241,7 @@ struct PreLLVMLowering : public mlir::PassWrapper(&getContext(), type_helper.get_type_converter()); - if (mlir::failed(apply_conv())) - { - signalPassFailure(); - return; - } + apply_conv(); } }; @@ -272,17 +268,13 @@ struct PostLLVMLowering : mlir::OwningRewritePatternList patterns; auto apply_conv = [&]() { - return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); }; // Remove redundant bitcasts we have created on PreLowering patterns.insert(&getContext()); - if (mlir::failed(apply_conv())) - { - signalPassFailure(); - return; - } + apply_conv(); } }; diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index c1748363fa2..85fcb8d1314 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -16,68 +16,56 @@ namespace { -mlir::Type map_int_type(plier::PyType type) +mlir::Type map_int_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { - auto name = type.getName(); unsigned num_bits = 0; if (name.consume_front("int") && - !name.consumeInteger(10, num_bits) && name.empty()) + !name.consumeInteger(10, num_bits)) { - return mlir::IntegerType::get(num_bits, type.getContext()); + return mlir::IntegerType::get(num_bits, &ctx); } return nullptr; } -mlir::Type map_int_literal_type(plier::PyType type) +mlir::Type map_int_literal_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { - auto name = type.getName(); unsigned dummy = 0; if (name.consume_front("Literal[int](") && - !name.consumeInteger(10, dummy) && name.consume_front(")") - && name.empty()) + !name.consumeInteger(10, dummy) && name.consume_front(")")) { - return mlir::IntegerType::get(64, type.getContext()); // TODO + return mlir::IntegerType::get(64, &ctx); // TODO } return nullptr; } -mlir::Type map_bool_type(plier::PyType type) +mlir::Type map_bool_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { - auto name = type.getName(); - if (name == "bool") + if (name.consume_front("bool")) { - return mlir::IntegerType::get(1, type.getContext()); + return mlir::IntegerType::get(1, &ctx); } return nullptr; } -mlir::Type map_float_type(plier::PyType type) +mlir::Type map_float_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { - auto name = type.getName(); unsigned num_bits = 0; if (name.consume_front("float") && - !name.consumeInteger(10, num_bits) && name.empty()) + !name.consumeInteger(10, num_bits)) { - auto ctx = type.getContext(); switch(num_bits) { - case 64: return mlir::Float64Type::get(ctx); - case 32: return mlir::Float32Type::get(ctx); - case 16: return mlir::Float16Type::get(ctx); + case 64: return mlir::Float64Type::get(&ctx); + case 32: return mlir::Float32Type::get(&ctx); + case 16: return mlir::Float16Type::get(&ctx); } } return nullptr; } -mlir::Type map_plier_type(mlir::Type type) +mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) { - assert(static_cast(type)); - if (!type.isa()) - { - return nullptr; - } - auto ptype = type.cast(); - using func_t = mlir::Type(*)(plier::PyType); + using func_t = mlir::Type(*)(mlir::MLIRContext& ctx, llvm::StringRef& name); const func_t handlers[] = { &map_int_type, &map_int_literal_type, @@ -86,15 +74,27 @@ mlir::Type map_plier_type(mlir::Type type) }; for (auto h : handlers) { - auto t = h(ptype); - if (t != mlir::Type()) + auto temp_name = name; + auto t = h(ctx, temp_name); + if (static_cast(t)) { + name = temp_name; return t; } } return nullptr; } +mlir::Type map_plier_type(mlir::Type type) +{ + if (!type.isa()) + { + return type; + } + auto name = type.cast().getName(); + return map_plier_type_name(*type.getContext(), name); +} + bool is_supported_type(mlir::Type type) { return type.isIntOrFloat(); @@ -540,14 +540,10 @@ void PlierToStdPass::runOnOperation() auto apply_conv = [&]() { - return mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); }; - if (mlir::failed(apply_conv())) - { - signalPassFailure(); - return; - } + apply_conv(); } void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) From a2c88acb239ce7fd8fdabea8a9443a52dde3ba09 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 14 Oct 2020 15:27:15 +0300 Subject: [PATCH 079/259] pair --- mlir-compiler/src/passes/plier_to_std.cpp | 32 ++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index 85fcb8d1314..b1e559b7470 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -63,6 +63,35 @@ mlir::Type map_float_type(mlir::MLIRContext& ctx, llvm::StringRef& name) return nullptr; } +mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name); + +mlir::Type map_pair_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (!name.consume_front("pair<")) + { + return nullptr; + } + auto first = map_plier_type_name(ctx, name); + if (!static_cast(first)) + { + return nullptr; + } + if (!name.consume_front(", ")) + { + return nullptr; + } + auto second = map_plier_type_name(ctx, name); + if (!static_cast(second)) + { + return nullptr; + } + if (!name.consume_front(">")) + { + return nullptr; + } + return mlir::TupleType::get({first, second}, &ctx); +} + mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) { using func_t = mlir::Type(*)(mlir::MLIRContext& ctx, llvm::StringRef& name); @@ -70,7 +99,8 @@ mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) &map_int_type, &map_int_literal_type, &map_bool_type, - &map_float_type + &map_float_type, + &map_pair_type, }; for (auto h : handlers) { From f63ebee12bd4c2806e792152ecf88833bfa4638a Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 14 Oct 2020 20:02:41 +0300 Subject: [PATCH 080/259] translate tuples to multiple results --- mlir-compiler/include/plier/PlierOps.td | 6 +- mlir-compiler/src/passes/plier_to_std.cpp | 70 ++++++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 7460a174df7..8148ecaed7a 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -22,7 +22,7 @@ def ArgOp : Plier_Op<"arg", [NoSideEffect]> { UI32Attr:$index, StrAttr:$name); - let results = (outs Plier_PyType); + let results = (outs AnyType); let hasFolder = 1; let builders = [ @@ -34,7 +34,7 @@ def ConstOp : Plier_Op<"const", [NoSideEffect]> { let arguments = (ins AnyAttr:$val); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Attribute val"> @@ -45,7 +45,7 @@ def GlobalOp : Plier_Op<"global", [NoSideEffect]> { let arguments = (ins StrAttr:$name); - let results = (outs Plier_PyType); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, StringRef name"> diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index b1e559b7470..ba011e0ee92 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -484,6 +484,70 @@ struct CastOpLowering : public mlir::OpRewritePattern } }; +mlir::Operation* change_op_ret_type(mlir::Operation* op, + mlir::PatternRewriter& rewriter, + llvm::ArrayRef types) +{ + assert(nullptr != op); + mlir::OperationState state(op->getLoc(), op->getName().getStringRef(), + op->getOperands(), types, op->getAttrs()); + return rewriter.createOperation(state); +} + +struct ExpandTuples : public mlir::RewritePattern +{ + ExpandTuples(mlir::MLIRContext* ctx): + RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), + dialect(ctx->getLoadedDialect()) + { + assert(nullptr != dialect); + } + + mlir::LogicalResult + matchAndRewrite(plier::Operation* op, mlir::PatternRewriter& rewriter) const override + { + if (op->getResultTypes().size() != 1 || + !op->getResultTypes()[0].isa() || + (op->getDialect() != dialect)) + { + return mlir::failure(); + } + auto types = op->getResultTypes()[0].cast().getTypes(); + + auto new_op = change_op_ret_type(op, rewriter, types); + auto new_op_results = new_op->getResults(); + + llvm::SmallVector users(op->getUsers()); + llvm::SmallVector new_operands; + op->dump(); + for (auto user_op : users) + { + new_operands.clear(); + for (auto arg : user_op->getOperands()) + { + if (arg.getDefiningOp() == op) + { + std::copy(new_op_results.begin(), new_op_results.end(), std::back_inserter(new_operands)); + } + else + { + new_operands.push_back(arg); + } + } + rewriter.updateRootInPlace(user_op, [&]() + { + user_op->setOperands(new_operands); + }); + } + rewriter.eraseOp(op); + return mlir::success(); + } + +private: + mlir::Dialect* dialect = nullptr; +}; + + struct FuncOpSignatureConversion : public mlir::OpRewritePattern { FuncOpSignatureConversion(mlir::MLIRContext* ctx, @@ -519,8 +583,8 @@ struct OpTypeConversion : public mlir::RewritePattern { if (auto new_type = map_plier_type(type)) { + changed = changed || (new_type != type); new_types.push_back(new_type); - changed = true; } else { @@ -548,6 +612,7 @@ struct PlierToStdPass : virtual void getDependentDialects( mlir::DialectRegistry ®istry) const override { + registry.insert(); registry.insert(); } @@ -566,7 +631,8 @@ void PlierToStdPass::runOnOperation() patterns.insert(&getContext(), type_converter); patterns.insert(&getContext()); + CallOpLowering, CastOpLowering, + ExpandTuples>(&getContext()); auto apply_conv = [&]() { From 94d307dfcd8527c773426e08a3bc2d6052f13495 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 14 Oct 2020 20:18:44 +0300 Subject: [PATCH 081/259] pair ops folding --- mlir-compiler/include/plier/PlierOps.td | 16 ++++++++-------- mlir-compiler/src/dialect.cpp | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 8148ecaed7a..8965ce05ec2 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -136,22 +136,26 @@ def IternextOp : Plier_Op<"iternext", []> { ]; } -def PairfirstOp : Plier_Op<"pair_first", []> { +def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { let arguments = (ins - AnyType:$value); + AnyType:$value, + Optional:$second_val); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> ]; } -def PairsecondOp : Plier_Op<"pair_second", []> { +def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { let arguments = (ins - AnyType:$value); + AnyType:$value, + Optional:$second_val); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -161,10 +165,6 @@ def PairsecondOp : Plier_Op<"pair_second", []> { def DelOp : Plier_Op<"del", []> { let arguments = (ins AnyType:$value); - -// let builders = [ -// OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value val"> -// ]; } #endif // PLIER_OPS diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index faf9352385c..16642a172a4 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -175,6 +175,15 @@ void PairfirstOp::build(OpBuilder &builder, OperationState &state, value); } +mlir::OpFoldResult PairfirstOp::fold(llvm::ArrayRef /*operands*/) +{ + if (getNumOperands() == 2) + { + return getOperand(0); + } + return nullptr; +} + void PairsecondOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { @@ -182,6 +191,15 @@ void PairsecondOp::build(OpBuilder &builder, OperationState &state, PyType::getUndefined(state.getContext()), value); } +mlir::OpFoldResult PairsecondOp::fold(llvm::ArrayRef /*operands*/) +{ + if (getNumOperands() == 2) + { + return getOperand(1); + } + return nullptr; +} + } #define GET_OP_CLASSES From 1f71e94a8cf9c9373c12338e6d69fd25896bcde1 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 14 Oct 2020 21:49:38 +0300 Subject: [PATCH 082/259] tuples --- mlir-compiler/include/plier/PlierOps.td | 6 +- mlir-compiler/src/dialect.cpp | 26 +++++++++ mlir-compiler/src/lowering.cpp | 2 +- mlir-compiler/src/passes/plier_to_std.cpp | 69 ++++++++++++++++++----- mlir-compiler/test.py | 1 + numba/mlir/tests/test_basic.py | 9 +++ 6 files changed, 95 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 8965ce05ec2..18dd6f941bc 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -89,11 +89,12 @@ def PyCallOp : Plier_Op<"call", []> { ]; } -def BuildTupleOp : Plier_Op<"build_tuple", []> { +def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { let arguments = (ins Variadic:$args); - let results = (outs AnyType); + let results = (outs Variadic); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::ValueRange args"> @@ -107,6 +108,7 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { UI32Attr:$index); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, " diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 16642a172a4..62c8ac905c4 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -145,6 +145,20 @@ void BuildTupleOp::build(OpBuilder &builder, OperationState &state, PyType::getUndefined(state.getContext()), args); } +mlir::LogicalResult BuildTupleOp::fold( + llvm::ArrayRef /*operands*/, + llvm::SmallVectorImpl &results) +{ + auto res_types = getResultTypes(); + auto args = getOperands(); + if (res_types.size() == args.size()) + { + std::copy(args.begin(), args.end(), std::back_inserter(results)); + return mlir::success(); + } + return mlir::failure(); +} + void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value, ::mlir::Value index_var, unsigned int index) @@ -154,6 +168,18 @@ void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, value, index_var, index); } +mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef /*operands*/) +{ + auto index = this->index(); + auto args = getOperands(); + if ((index + 1) < args.size() && // skip last arg + args[index].getType() == getResult().getType()) + { + return args[index]; + } + return nullptr; +} + void GetiterOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) { diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 826b57c3301..df81c9dff6a 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -332,7 +332,7 @@ struct plier_lowerer { args.push_back(loadvar(item)); } - return builder.create(builder.getUnknownLoc(), args); + return builder.create(builder.getUnknownLoc(), args).getResult(0); } mlir::Value lower_phi(const py::handle& expr) diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index ba011e0ee92..a5fd69d620d 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -64,32 +64,70 @@ mlir::Type map_float_type(mlir::MLIRContext& ctx, llvm::StringRef& name) } mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name); - -mlir::Type map_pair_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +bool map_type_helper(mlir::MLIRContext& ctx, llvm::StringRef& name, mlir::Type& ret) { - if (!name.consume_front("pair<")) + auto type = map_plier_type_name(ctx, name); + if (static_cast(type)) { - return nullptr; + ret = type; + return true; } - auto first = map_plier_type_name(ctx, name); - if (!static_cast(first)) + return false; +} + +mlir::Type map_pair_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + mlir::Type first; + mlir::Type second; + if (name.consume_front("pair<") && + map_type_helper(ctx, name, first) && + name.consume_front(", ") && + map_type_helper(ctx, name, second) && + name.consume_front(">")) { - return nullptr; + return mlir::TupleType::get({first, second}, &ctx); } - if (!name.consume_front(", ")) + return nullptr; +} + +mlir::Type map_unituple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + mlir::Type type; + unsigned count = 0; + if (name.consume_front("UniTuple(") && + map_type_helper(ctx, name, type) && + name.consume_front(" x ") && + !name.consumeInteger(10, count) && + name.consume_front(")")) { - return nullptr; + llvm::SmallVector types(count, type); + return mlir::TupleType::get(types, &ctx); } - auto second = map_plier_type_name(ctx, name); - if (!static_cast(second)) + return nullptr; +} + +mlir::Type map_tuple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (!name.consume_front("Tuple(")) { return nullptr; } - if (!name.consume_front(">")) + llvm::SmallVector types; + while (true) { - return nullptr; + if (name.consume_front(")")) + { + break; + } + auto type = map_plier_type_name(ctx, name); + if (!static_cast(type)) + { + return nullptr; + } + types.push_back(type); + (void)name.consume_front(", "); } - return mlir::TupleType::get({first, second}, &ctx); + return mlir::TupleType::get(types, &ctx); } mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) @@ -101,6 +139,8 @@ mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) &map_bool_type, &map_float_type, &map_pair_type, + &map_unituple_type, + &map_tuple_type, }; for (auto h : handlers) { @@ -519,7 +559,6 @@ struct ExpandTuples : public mlir::RewritePattern llvm::SmallVector users(op->getUsers()); llvm::SmallVector new_operands; - op->dump(); for (auto user_op : users) { new_operands.clear(); diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 2bb322eacc2..d761cdbbdd8 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -78,6 +78,7 @@ def test(func, params): test(jump, (7,8)) test(call, (1,2,3)) test(tuple, (1,2,3)) +test(tuple, (1,2.0,3)) test(loop, (8,)) print(f'Tests passed: {_tests_passes}/{_tests_total}') diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index b9f80e4d194..330c4a9093f 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -80,6 +80,15 @@ def py_func(a, b): for a, b in itertools.product(_test_values, _test_values): assert_equal(py_func(a, b), jit_func(a, b)) + def test_tuple(self): + def py_func(a, b, c): + t = (a,b,c) + return t[0] + t[1] + t[2] + + jit_func = njit(py_func) + for a, b, c in itertools.product(_test_values, _test_values, _test_values): + assert_equal(py_func(a, b, c), jit_func(a, b, c)) + if __name__ == '__main__': unittest.main() From aad613bb9544fae30ff1cb7a6c1ca3d57fa919d2 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 15 Oct 2020 16:09:06 +0300 Subject: [PATCH 083/259] fix --- mlir-compiler/include/plier/PlierOps.td | 2 +- mlir-compiler/src/lowering.cpp | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 18dd6f941bc..85bea44153e 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -131,7 +131,7 @@ def IternextOp : Plier_Op<"iternext", []> { let arguments = (ins AnyType:$value); - let results = (outs AnyType); + let results = (outs Variadic); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index df81c9dff6a..fa532c7ca71 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -287,7 +287,7 @@ struct plier_lowerer {"build_tuple", &plier_lowerer::lower_build_tuple}, {"static_getitem", &plier_lowerer::lower_static_getitem}, {"getiter", &plier_lowerer::lower_simple}, - {"iternext", &plier_lowerer::lower_simple}, + {"iternext", &plier_lowerer::lower_simple_multiresult}, {"pair_first", &plier_lowerer::lower_simple}, {"pair_second", &plier_lowerer::lower_simple}, }; @@ -308,6 +308,15 @@ struct plier_lowerer return builder.create(builder.getUnknownLoc(), value); } + template + mlir::Value lower_simple_multiresult(const py::handle& inst) + { + auto value = loadvar(inst.attr("value")); + auto res = builder.create(builder.getUnknownLoc(), value); + assert(res.getNumResults() == 1); + return res.getResult(0); + } + mlir::Value lower_cast(const py::handle& inst) { auto value = loadvar(inst.attr("value")); @@ -332,7 +341,9 @@ struct plier_lowerer { args.push_back(loadvar(item)); } - return builder.create(builder.getUnknownLoc(), args).getResult(0); + auto res = builder.create(builder.getUnknownLoc(), args); + assert(res.getNumResults() == 1); + return res.getResult(0); } mlir::Value lower_phi(const py::handle& expr) From c3fb6aa8dcd43a8e958d2668176036614186e39d Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 15 Oct 2020 16:13:51 +0300 Subject: [PATCH 084/259] refac --- mlir-compiler/src/lowering.cpp | 37 +++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index fa532c7ca71..4f9d55b29a2 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -249,7 +249,7 @@ struct plier_lowerer if (py::isinstance(value, insts.Arg)) { auto index = value.attr("index").cast(); - return builder.create(builder.getUnknownLoc(), index, + return builder.create(get_current_loc(), index, target.attr("name").cast()); } if(py::isinstance(value, insts.Expr)) @@ -263,12 +263,12 @@ struct plier_lowerer if (py::isinstance(value, insts.Const)) { auto val = get_const_val(value.attr("value")); - return builder.create(builder.getUnknownLoc(), val); + return builder.create(get_current_loc(), val); } if (py::isinstance(value, insts.Global)) { auto name = value.attr("name").cast(); - return builder.create(builder.getUnknownLoc(), + return builder.create(get_current_loc(), name); } @@ -305,14 +305,14 @@ struct plier_lowerer mlir::Value lower_simple(const py::handle& inst) { auto value = loadvar(inst.attr("value")); - return builder.create(builder.getUnknownLoc(), value); + return builder.create(get_current_loc(), value); } template mlir::Value lower_simple_multiresult(const py::handle& inst) { auto value = loadvar(inst.attr("value")); - auto res = builder.create(builder.getUnknownLoc(), value); + auto res = builder.create(get_current_loc(), value); assert(res.getNumResults() == 1); return res.getResult(0); } @@ -321,7 +321,7 @@ struct plier_lowerer { auto value = loadvar(inst.attr("value")); auto res_type = get_type(current_instr.attr("target")); - return builder.create(builder.getUnknownLoc(), res_type, value); + return builder.create(get_current_loc(), res_type, value); } mlir::Value lower_static_getitem(const py::handle& inst) @@ -329,7 +329,7 @@ struct plier_lowerer auto value = loadvar(inst.attr("value")); auto index_var = loadvar(inst.attr("index_var")); auto index = inst.attr("index").cast(); - return builder.create(builder.getUnknownLoc(), + return builder.create(get_current_loc(), value, index_var, index); } @@ -341,7 +341,7 @@ struct plier_lowerer { args.push_back(loadvar(item)); } - auto res = builder.create(builder.getUnknownLoc(), args); + auto res = builder.create(get_current_loc(), args); assert(res.getNumResults() == 1); return res.getResult(0); } @@ -391,7 +391,7 @@ struct plier_lowerer kwargs_list.push_back({name.cast(), loadvar(val_name)}); } - return builder.create(builder.getUnknownLoc(), func, + return builder.create(get_current_loc(), func, args_list, kwargs_list); } @@ -412,7 +412,7 @@ struct plier_lowerer if (op.is(std::get<1>(elem))) { auto op_name = std::get<0>(elem).op; - return builder.create(builder.getUnknownLoc(), lhs, rhs, op_name); + return builder.create(get_current_loc(), lhs, rhs, op_name); } } @@ -435,7 +435,7 @@ struct plier_lowerer void delvar(const py::handle& inst) { auto var = loadvar(inst); - builder.create(builder.getUnknownLoc(), var); + builder.create(get_current_loc(), var); } void retvar(const py::handle& inst) @@ -446,9 +446,9 @@ struct plier_lowerer auto var_type = var.getType(); if (ret_type != var_type) { - var = builder.create(builder.getUnknownLoc(), ret_type, var); + var = builder.create(get_current_loc(), ret_type, var); } - builder.create(builder.getUnknownLoc(), var); + builder.create(get_current_loc(), var); } void branch(const py::handle& cond, const py::handle& tr, const py::handle& fl) @@ -456,14 +456,14 @@ struct plier_lowerer auto c = loadvar(cond); auto tr_block = blocks_map.find(tr.cast())->second; auto fl_block = blocks_map.find(fl.cast())->second; - auto cond_val = builder.create(builder.getUnknownLoc(), mlir::IntegerType::get(1, &ctx), c); - builder.create(builder.getUnknownLoc(), cond_val, tr_block, fl_block); + auto cond_val = builder.create(get_current_loc(), mlir::IntegerType::get(1, &ctx), c); + builder.create(get_current_loc(), cond_val, tr_block, fl_block); } void jump(const py::handle& target) { auto block = blocks_map.find(target.cast())->second; - builder.create(builder.getUnknownLoc(), mlir::None, block); + builder.create(get_current_loc(), mlir::None, block); } mlir::Attribute get_const_val(const py::handle& val) @@ -486,6 +486,11 @@ struct plier_lowerer return mlir::FunctionType::get(args, {ret}, &ctx); } + mlir::Location get_current_loc() + { + return builder.getUnknownLoc(); // TODO + } + void fixup_phis() { auto build_arg_list = [&](mlir::Block* block, auto& outgoing_phi_nodes, auto& list) From 3314620bb965172961ed5532736d16da7096e4c4 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 15 Oct 2020 19:25:57 +0300 Subject: [PATCH 085/259] refac --- mlir-compiler/src/passes/lower_to_llvm.cpp | 7 ++++++- mlir-compiler/src/passes/lower_to_llvm.hpp | 7 +++++++ mlir-compiler/src/passes/plier_to_std.cpp | 7 ++++++- mlir-compiler/src/passes/plier_to_std.hpp | 7 +++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/passes/lower_to_llvm.cpp index e2c693193ce..d97ba493adb 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/passes/lower_to_llvm.cpp @@ -293,6 +293,11 @@ void register_lower_to_llvm_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_lower_lowering_stage(); - sink("lower_to_llvm", {stage.begin}, {stage.end}, &populate_lower_to_llvm_pipeline); + sink(lower_to_llvm_pipeline_name(), {stage.begin}, {stage.end}, &populate_lower_to_llvm_pipeline); }); } + +llvm::StringRef lower_to_llvm_pipeline_name() +{ + return "lower_to_llvm"; +} diff --git a/mlir-compiler/src/passes/lower_to_llvm.hpp b/mlir-compiler/src/passes/lower_to_llvm.hpp index 61c6947c37b..43ec2466314 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.hpp +++ b/mlir-compiler/src/passes/lower_to_llvm.hpp @@ -2,4 +2,11 @@ class PipelineRegistry; +namespace llvm +{ +class StringRef; +} + void register_lower_to_llvm_pipeline(PipelineRegistry& registry); + +llvm::StringRef lower_to_llvm_pipeline_name(); diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/passes/plier_to_std.cpp index a5fd69d620d..3420530b108 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/passes/plier_to_std.cpp @@ -694,6 +694,11 @@ void register_plier_to_std_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_high_lowering_stage(); - sink("plier_to_std", {stage.begin}, {stage.end}, &populate_plier_to_std_pipeline); + sink(plier_to_std_pipeline_name(), {stage.begin}, {stage.end}, &populate_plier_to_std_pipeline); }); } + +llvm::StringRef plier_to_std_pipeline_name() +{ + return "plier_to_std"; +} diff --git a/mlir-compiler/src/passes/plier_to_std.hpp b/mlir-compiler/src/passes/plier_to_std.hpp index b3fd97784b0..1768d294a51 100644 --- a/mlir-compiler/src/passes/plier_to_std.hpp +++ b/mlir-compiler/src/passes/plier_to_std.hpp @@ -2,4 +2,11 @@ class PipelineRegistry; +namespace llvm +{ +class StringRef; +} + void register_plier_to_std_pipeline(PipelineRegistry& registry); + +llvm::StringRef plier_to_std_pipeline_name(); From 53cfca14031da026e0c5af54105626e76d75934e Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 18:33:02 +0300 Subject: [PATCH 086/259] plier to linalg pipeline stub --- mlir-compiler/CMakeLists.txt | 16 +++++----- mlir-compiler/src/lowering.cpp | 2 ++ mlir-compiler/src/passes/plier_to_linalg.cpp | 31 ++++++++++++++++++++ mlir-compiler/src/passes/plier_to_linalg.hpp | 12 ++++++++ 4 files changed, 54 insertions(+), 7 deletions(-) create mode 100644 mlir-compiler/src/passes/plier_to_linalg.cpp create mode 100644 mlir-compiler/src/passes/plier_to_linalg.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index a527e30f14a..d2d785c105b 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -20,26 +20,28 @@ include(HandleLLVMOptions) add_subdirectory(include/plier) set(SOURCES_LIST + src/passes/base_pipeline.cpp + src/passes/lower_to_llvm.cpp + src/passes/plier_to_linalg.cpp + src/passes/plier_to_std.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp src/module.cpp - src/passes/base_pipeline.cpp - src/passes/lower_to_llvm.cpp - src/passes/plier_to_std.cpp src/pipeline_registry.cpp src/utils.cpp ) set(HEADERS_LIST - src/compiler.hpp - src/lowering.hpp + include/plier/dialect.hpp + include/plier/PlierOps.td src/passes/base_pipeline.hpp src/passes/lower_to_llvm.hpp + src/passes/plier_to_linalg.hpp src/passes/plier_to_std.hpp + src/compiler.hpp + src/lowering.hpp src/pipeline_registry.hpp src/utils.hpp - include/plier/dialect.hpp - include/plier/PlierOps.td ) pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST}) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 4f9d55b29a2..408d4c270f6 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -24,6 +24,7 @@ #include "passes/base_pipeline.hpp" #include "passes/plier_to_std.hpp" +#include "passes/plier_to_linalg.hpp" #include "passes/lower_to_llvm.hpp" namespace py = pybind11; @@ -597,6 +598,7 @@ void create_pipeline(PipelineRegistry& registry) register_base_pipeline(registry); register_lower_to_llvm_pipeline(registry); register_plier_to_std_pipeline(registry); + register_plier_to_linalg_pipeline(registry); } } diff --git a/mlir-compiler/src/passes/plier_to_linalg.cpp b/mlir-compiler/src/passes/plier_to_linalg.cpp new file mode 100644 index 00000000000..332ebc50e02 --- /dev/null +++ b/mlir-compiler/src/passes/plier_to_linalg.cpp @@ -0,0 +1,31 @@ +#include "passes/plier_to_linalg.hpp" + +#include "plier/dialect.hpp" + +#include "passes/plier_to_std.hpp" + +#include "base_pipeline.hpp" +#include "pipeline_registry.hpp" + +namespace +{ + +void populate_plier_to_linalg_pipeline(mlir::OpPassManager& /*pm*/) +{ + +} +} + +void register_plier_to_linalg_pipeline(PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_high_lowering_stage(); + sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, &populate_plier_to_linalg_pipeline); + }); +} + +llvm::StringRef plier_to_linalg_pipeline_name() +{ + return "plier_to_linalg"; +} diff --git a/mlir-compiler/src/passes/plier_to_linalg.hpp b/mlir-compiler/src/passes/plier_to_linalg.hpp new file mode 100644 index 00000000000..a25cd7352ae --- /dev/null +++ b/mlir-compiler/src/passes/plier_to_linalg.hpp @@ -0,0 +1,12 @@ +#pragma once + +class PipelineRegistry; + +namespace llvm +{ +class StringRef; +} + +void register_plier_to_linalg_pipeline(PipelineRegistry& registry); + +llvm::StringRef plier_to_linalg_pipeline_name(); From 3b19516ebf636e3961aa5ace96a290281da49cc1 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 18:38:13 +0300 Subject: [PATCH 087/259] more stubs --- mlir-compiler/src/passes/plier_to_linalg.cpp | 45 +++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/passes/plier_to_linalg.cpp b/mlir-compiler/src/passes/plier_to_linalg.cpp index 332ebc50e02..1f868ff9f7d 100644 --- a/mlir-compiler/src/passes/plier_to_linalg.cpp +++ b/mlir-compiler/src/passes/plier_to_linalg.cpp @@ -1,5 +1,12 @@ #include "passes/plier_to_linalg.hpp" +#include +#include +#include +#include +#include +#include + #include "plier/dialect.hpp" #include "passes/plier_to_std.hpp" @@ -10,9 +17,45 @@ namespace { -void populate_plier_to_linalg_pipeline(mlir::OpPassManager& /*pm*/) +struct PlierToLinalgPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void PlierToLinalgPass::runOnOperation() { +// mlir::TypeConverter type_converter; +// type_converter.addConversion([](plier::Type type)->llvm::Optional +// { +// return map_plier_type(type); +// }); + + mlir::OwningRewritePatternList patterns; +// patterns.insert(&getContext(), type_converter); +// patterns.insert(&getContext()); + auto apply_conv = [&]() + { + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + }; + + apply_conv(); +} + +void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) +{ + pm.addPass(std::make_unique()); } } From 60ec6c7ce6124a14e9857607e24a91bac339f9de Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 18:42:17 +0300 Subject: [PATCH 088/259] move files --- mlir-compiler/CMakeLists.txt | 16 ++++++++-------- mlir-compiler/src/lowering.cpp | 8 ++++---- .../src/{passes => pipelines}/base_pipeline.cpp | 2 +- .../src/{passes => pipelines}/base_pipeline.hpp | 0 .../src/{passes => pipelines}/lower_to_llvm.cpp | 2 +- .../src/{passes => pipelines}/lower_to_llvm.hpp | 0 .../{passes => pipelines}/plier_to_linalg.cpp | 4 ++-- .../{passes => pipelines}/plier_to_linalg.hpp | 0 .../src/{passes => pipelines}/plier_to_std.cpp | 2 +- .../src/{passes => pipelines}/plier_to_std.hpp | 0 10 files changed, 17 insertions(+), 17 deletions(-) rename mlir-compiler/src/{passes => pipelines}/base_pipeline.cpp (95%) rename mlir-compiler/src/{passes => pipelines}/base_pipeline.hpp (100%) rename mlir-compiler/src/{passes => pipelines}/lower_to_llvm.cpp (99%) rename mlir-compiler/src/{passes => pipelines}/lower_to_llvm.hpp (100%) rename mlir-compiler/src/{passes => pipelines}/plier_to_linalg.cpp (96%) rename mlir-compiler/src/{passes => pipelines}/plier_to_linalg.hpp (100%) rename mlir-compiler/src/{passes => pipelines}/plier_to_std.cpp (99%) rename mlir-compiler/src/{passes => pipelines}/plier_to_std.hpp (100%) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index d2d785c105b..42c8b6ec64b 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -20,10 +20,10 @@ include(HandleLLVMOptions) add_subdirectory(include/plier) set(SOURCES_LIST - src/passes/base_pipeline.cpp - src/passes/lower_to_llvm.cpp - src/passes/plier_to_linalg.cpp - src/passes/plier_to_std.cpp + src/pipelines/base_pipeline.cpp + src/pipelines/lower_to_llvm.cpp + src/pipelines/plier_to_linalg.cpp + src/pipelines/plier_to_std.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp @@ -34,10 +34,10 @@ set(SOURCES_LIST set(HEADERS_LIST include/plier/dialect.hpp include/plier/PlierOps.td - src/passes/base_pipeline.hpp - src/passes/lower_to_llvm.hpp - src/passes/plier_to_linalg.hpp - src/passes/plier_to_std.hpp + src/pipelines/base_pipeline.hpp + src/pipelines/lower_to_llvm.hpp + src/pipelines/plier_to_linalg.hpp + src/pipelines/plier_to_std.hpp src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 408d4c270f6..3b9f4d48011 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -22,10 +22,10 @@ #include "pipeline_registry.hpp" #include "utils.hpp" -#include "passes/base_pipeline.hpp" -#include "passes/plier_to_std.hpp" -#include "passes/plier_to_linalg.hpp" -#include "passes/lower_to_llvm.hpp" +#include "pipelines/base_pipeline.hpp" +#include "pipelines/plier_to_std.hpp" +#include "pipelines/plier_to_linalg.hpp" +#include "pipelines/lower_to_llvm.hpp" namespace py = pybind11; namespace diff --git a/mlir-compiler/src/passes/base_pipeline.cpp b/mlir-compiler/src/pipelines/base_pipeline.cpp similarity index 95% rename from mlir-compiler/src/passes/base_pipeline.cpp rename to mlir-compiler/src/pipelines/base_pipeline.cpp index ee9746df39a..2d1405b81c0 100644 --- a/mlir-compiler/src/passes/base_pipeline.cpp +++ b/mlir-compiler/src/pipelines/base_pipeline.cpp @@ -1,4 +1,4 @@ -#include "passes/base_pipeline.hpp" +#include "pipelines/base_pipeline.hpp" #include "pipeline_registry.hpp" diff --git a/mlir-compiler/src/passes/base_pipeline.hpp b/mlir-compiler/src/pipelines/base_pipeline.hpp similarity index 100% rename from mlir-compiler/src/passes/base_pipeline.hpp rename to mlir-compiler/src/pipelines/base_pipeline.hpp diff --git a/mlir-compiler/src/passes/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp similarity index 99% rename from mlir-compiler/src/passes/lower_to_llvm.cpp rename to mlir-compiler/src/pipelines/lower_to_llvm.cpp index d97ba493adb..8ecbd192f6d 100644 --- a/mlir-compiler/src/passes/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -1,4 +1,4 @@ -#include "passes/lower_to_llvm.hpp" +#include "pipelines/lower_to_llvm.hpp" #include #include diff --git a/mlir-compiler/src/passes/lower_to_llvm.hpp b/mlir-compiler/src/pipelines/lower_to_llvm.hpp similarity index 100% rename from mlir-compiler/src/passes/lower_to_llvm.hpp rename to mlir-compiler/src/pipelines/lower_to_llvm.hpp diff --git a/mlir-compiler/src/passes/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp similarity index 96% rename from mlir-compiler/src/passes/plier_to_linalg.cpp rename to mlir-compiler/src/pipelines/plier_to_linalg.cpp index 1f868ff9f7d..f1e11d73fe3 100644 --- a/mlir-compiler/src/passes/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -1,4 +1,4 @@ -#include "passes/plier_to_linalg.hpp" +#include "pipelines/plier_to_linalg.hpp" #include #include @@ -9,7 +9,7 @@ #include "plier/dialect.hpp" -#include "passes/plier_to_std.hpp" +#include "pipelines/plier_to_std.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" diff --git a/mlir-compiler/src/passes/plier_to_linalg.hpp b/mlir-compiler/src/pipelines/plier_to_linalg.hpp similarity index 100% rename from mlir-compiler/src/passes/plier_to_linalg.hpp rename to mlir-compiler/src/pipelines/plier_to_linalg.hpp diff --git a/mlir-compiler/src/passes/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp similarity index 99% rename from mlir-compiler/src/passes/plier_to_std.cpp rename to mlir-compiler/src/pipelines/plier_to_std.cpp index 3420530b108..abb452164e5 100644 --- a/mlir-compiler/src/passes/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1,4 +1,4 @@ -#include "passes/plier_to_std.hpp" +#include "pipelines/plier_to_std.hpp" #include #include diff --git a/mlir-compiler/src/passes/plier_to_std.hpp b/mlir-compiler/src/pipelines/plier_to_std.hpp similarity index 100% rename from mlir-compiler/src/passes/plier_to_std.hpp rename to mlir-compiler/src/pipelines/plier_to_std.hpp From cd6d3dd3fb17397cfebd3c7a0092a9b52a1d80c5 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 20:23:21 +0300 Subject: [PATCH 089/259] refac func sig conversion --- mlir-compiler/src/pipelines/plier_to_std.cpp | 79 ++++++++++++++++++-- 1 file changed, 73 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index abb452164e5..dbe0f07e4c5 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -586,24 +586,91 @@ struct ExpandTuples : public mlir::RewritePattern mlir::Dialect* dialect = nullptr; }; +mlir::LogicalResult setBlockSig( + mlir::Block& block, const mlir::TypeConverter::SignatureConversion& conversion) +{ + if (conversion.getConvertedTypes().size() != block.getNumArguments()) + { + return mlir::failure(); + } + for (auto it : llvm::zip(block.getArguments(), conversion.getConvertedTypes())) + { + auto arg = std::get<0>(it); + auto type = std::get<1>(it); + arg.setType(type); + } + return mlir::success(); +} + +mlir::LogicalResult convertRegionTypes( + mlir::Region *region, mlir::TypeConverter &converter, bool apply) +{ + if (region->empty()) + { + return mlir::failure(); + } + + // Convert the arguments of each block within the region. + auto sig = converter.convertBlockSignature(®ion->front()); + assert(static_cast(sig)); + if (apply) + { + auto res = setBlockSig(region->front(), *sig); + assert(mlir::succeeded(res)); + } + for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) + { + sig = converter.convertBlockSignature(&block); + if (!sig) + { + return mlir::failure(); + } + if (apply) + { + if (mlir::failed(setBlockSig(block, *sig))) + { + return mlir::failure(); + } + } + } + return mlir::success(); +} + struct FuncOpSignatureConversion : public mlir::OpRewritePattern { FuncOpSignatureConversion(mlir::MLIRContext* ctx, - mlir::TypeConverter& /*converter*/) - : OpRewritePattern(ctx) {} + mlir::TypeConverter& conv) + : OpRewritePattern(ctx), converter(conv) {} /// Hook for derived classes to implement combined matching and rewriting. mlir::LogicalResult matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override { - bool changed = convert_func_sig(funcOp); - if (changed) + auto type = funcOp.getType(); + + // Convert the original function types. + mlir::TypeConverter::SignatureConversion result(type.getNumInputs()); + llvm::SmallVector newResults; + if (mlir::failed(converter.convertSignatureArgs(type.getInputs(), result)) || + mlir::failed(converter.convertTypes(type.getResults(), newResults)) || + mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) { - rewriter.updateRootInPlace(funcOp, [&] {}); // HACK + return mlir::failure(); } - return mlir::success(changed); + + // Update the function signature in-place. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(mlir::FunctionType::get(result.getConvertedTypes(), newResults, + funcOp.getContext())); + auto res = convertRegionTypes(&funcOp.getBody(), converter, true); + assert(mlir::succeeded(res)); + }); + return mlir::success(); } + +private: + mlir::TypeConverter& converter; }; struct OpTypeConversion : public mlir::RewritePattern From 1ceee94a3bb7cb34b063a7a6cef8ca3fab119683 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 20:33:28 +0300 Subject: [PATCH 090/259] move FuncOpSignatureConversion to separate file --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/pipelines/plier_to_std.cpp | 89 +------------------ .../rewrites/func_signature_conversion.cpp | 85 ++++++++++++++++++ .../rewrites/func_signature_conversion.hpp | 22 +++++ 4 files changed, 111 insertions(+), 87 deletions(-) create mode 100644 mlir-compiler/src/rewrites/func_signature_conversion.cpp create mode 100644 mlir-compiler/src/rewrites/func_signature_conversion.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 42c8b6ec64b..758357df176 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,6 +24,7 @@ set(SOURCES_LIST src/pipelines/lower_to_llvm.cpp src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp + src/rewrites/func_signature_conversion.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp @@ -38,6 +39,7 @@ set(HEADERS_LIST src/pipelines/lower_to_llvm.hpp src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp + src/rewrites/func_signature_conversion.hpp src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index dbe0f07e4c5..f86041521b9 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -11,6 +11,8 @@ #include "plier/dialect.hpp" +#include "rewrites/func_signature_conversion.hpp" + #include "base_pipeline.hpp" #include "pipeline_registry.hpp" @@ -586,93 +588,6 @@ struct ExpandTuples : public mlir::RewritePattern mlir::Dialect* dialect = nullptr; }; -mlir::LogicalResult setBlockSig( - mlir::Block& block, const mlir::TypeConverter::SignatureConversion& conversion) -{ - if (conversion.getConvertedTypes().size() != block.getNumArguments()) - { - return mlir::failure(); - } - for (auto it : llvm::zip(block.getArguments(), conversion.getConvertedTypes())) - { - auto arg = std::get<0>(it); - auto type = std::get<1>(it); - arg.setType(type); - } - return mlir::success(); -} - -mlir::LogicalResult convertRegionTypes( - mlir::Region *region, mlir::TypeConverter &converter, bool apply) -{ - if (region->empty()) - { - return mlir::failure(); - } - - // Convert the arguments of each block within the region. - auto sig = converter.convertBlockSignature(®ion->front()); - assert(static_cast(sig)); - if (apply) - { - auto res = setBlockSig(region->front(), *sig); - assert(mlir::succeeded(res)); - } - for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) - { - sig = converter.convertBlockSignature(&block); - if (!sig) - { - return mlir::failure(); - } - if (apply) - { - if (mlir::failed(setBlockSig(block, *sig))) - { - return mlir::failure(); - } - } - } - return mlir::success(); -} - - -struct FuncOpSignatureConversion : public mlir::OpRewritePattern -{ - FuncOpSignatureConversion(mlir::MLIRContext* ctx, - mlir::TypeConverter& conv) - : OpRewritePattern(ctx), converter(conv) {} - - /// Hook for derived classes to implement combined matching and rewriting. - mlir::LogicalResult - matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override - { - auto type = funcOp.getType(); - - // Convert the original function types. - mlir::TypeConverter::SignatureConversion result(type.getNumInputs()); - llvm::SmallVector newResults; - if (mlir::failed(converter.convertSignatureArgs(type.getInputs(), result)) || - mlir::failed(converter.convertTypes(type.getResults(), newResults)) || - mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) - { - return mlir::failure(); - } - - // Update the function signature in-place. - rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(mlir::FunctionType::get(result.getConvertedTypes(), newResults, - funcOp.getContext())); - auto res = convertRegionTypes(&funcOp.getBody(), converter, true); - assert(mlir::succeeded(res)); - }); - return mlir::success(); - } - -private: - mlir::TypeConverter& converter; -}; - struct OpTypeConversion : public mlir::RewritePattern { OpTypeConversion(mlir::MLIRContext* /*ctx*/, diff --git a/mlir-compiler/src/rewrites/func_signature_conversion.cpp b/mlir-compiler/src/rewrites/func_signature_conversion.cpp new file mode 100644 index 00000000000..6e171972b61 --- /dev/null +++ b/mlir-compiler/src/rewrites/func_signature_conversion.cpp @@ -0,0 +1,85 @@ +#include "rewrites/func_signature_conversion.hpp" + +#include + +namespace +{ +mlir::LogicalResult setBlockSig( + mlir::Block& block, const mlir::TypeConverter::SignatureConversion& conversion) +{ + if (conversion.getConvertedTypes().size() != block.getNumArguments()) + { + return mlir::failure(); + } + for (auto it : llvm::zip(block.getArguments(), conversion.getConvertedTypes())) + { + auto arg = std::get<0>(it); + auto type = std::get<1>(it); + arg.setType(type); + } + return mlir::success(); +} + +mlir::LogicalResult convertRegionTypes( + mlir::Region *region, mlir::TypeConverter &converter, bool apply) +{ + if (region->empty()) + { + return mlir::failure(); + } + + // Convert the arguments of each block within the region. + auto sig = converter.convertBlockSignature(®ion->front()); + assert(static_cast(sig)); + if (apply) + { + auto res = setBlockSig(region->front(), *sig); + assert(mlir::succeeded(res)); + } + for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) + { + sig = converter.convertBlockSignature(&block); + if (!sig) + { + return mlir::failure(); + } + if (apply) + { + if (mlir::failed(setBlockSig(block, *sig))) + { + return mlir::failure(); + } + } + } + return mlir::success(); +} +} + +FuncOpSignatureConversion::FuncOpSignatureConversion( + mlir::MLIRContext* ctx, mlir::TypeConverter& conv) + : OpRewritePattern(ctx), converter(conv) {} + +mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( + mlir::FuncOp funcOp, mlir::PatternRewriter& rewriter) const +{ + auto type = funcOp.getType(); + + // Convert the original function types. + mlir::TypeConverter::SignatureConversion result(type.getNumInputs()); + llvm::SmallVector newResults; + if (mlir::failed(converter.convertSignatureArgs(type.getInputs(), result)) || + mlir::failed(converter.convertTypes(type.getResults(), newResults)) || + mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) + { + return mlir::failure(); + } + + // Update the function signature in-place. + rewriter.updateRootInPlace(funcOp, [&] { + funcOp.setType(mlir::FunctionType::get(result.getConvertedTypes(), newResults, + funcOp.getContext())); + auto res = convertRegionTypes(&funcOp.getBody(), converter, true); + assert(mlir::succeeded(res)); + }); + return mlir::success(); +} diff --git a/mlir-compiler/src/rewrites/func_signature_conversion.hpp b/mlir-compiler/src/rewrites/func_signature_conversion.hpp new file mode 100644 index 00000000000..1cbce92fc85 --- /dev/null +++ b/mlir-compiler/src/rewrites/func_signature_conversion.hpp @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +namespace mlir +{ +class TypeConverter; +} + +struct FuncOpSignatureConversion : public mlir::OpRewritePattern +{ + FuncOpSignatureConversion(mlir::MLIRContext* ctx, + mlir::TypeConverter& conv); + + /// Hook for derived classes to implement combined matching and rewriting. + mlir::LogicalResult + matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override; + +private: + mlir::TypeConverter& converter; +}; From 6bf8d4c3111c251e5e12d498e6c71aa024f5938c Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 21:25:10 +0300 Subject: [PATCH 091/259] remove unused --- mlir-compiler/src/pipelines/plier_to_std.cpp | 47 -------------------- 1 file changed, 47 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index f86041521b9..0cf1888761d 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -172,53 +172,6 @@ bool is_supported_type(mlir::Type type) return type.isIntOrFloat(); } -mlir::Type map_type(mlir::Type type) -{ - auto new_type = is_supported_type(type) ? type : map_plier_type(type); - return static_cast(new_type) ? new_type : type; -}; - -bool convert_func_sig(mlir::FuncOp func) -{ - llvm::SmallVector new_arg_types; - new_arg_types.reserve(func.getNumArguments()); - bool changed = false; - for (auto arg_type : func.getArgumentTypes()) - { - auto new_type = map_type(arg_type); - changed = changed || (new_type != arg_type); - new_arg_types.push_back(new_type); - } - - auto res_type = func.getType().getResult(0); - auto new_res_type = map_type(res_type); - changed = changed || (res_type != new_res_type); - if (changed) - { - auto func_type = mlir::FunctionType::get(new_arg_types, new_res_type, func.getContext()); - func.setType(func_type); - for (unsigned i = 0; i < func.getNumArguments(); ++i) - { - func.front().getArgument(i).setType(new_arg_types[i]); - } - } - for (auto& bb : llvm::make_range(++func.getBody().begin(), - func.getBody().end())) - { - for (auto arg : bb.getArguments()) - { - auto arg_type = arg.getType(); - auto new_type = map_type(arg_type); - if (new_type != arg_type) - { - arg.setType(new_type); - changed = true; - } - } - } - return changed; -} - template void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) { From d6f3049b94170bb5993c3ab91333ba6dcf0806ee Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 21:27:01 +0300 Subject: [PATCH 092/259] refac --- mlir-compiler/src/pipelines/plier_to_std.cpp | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 0cf1888761d..68d3b76abd0 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -544,8 +544,9 @@ struct ExpandTuples : public mlir::RewritePattern struct OpTypeConversion : public mlir::RewritePattern { OpTypeConversion(mlir::MLIRContext* /*ctx*/, - mlir::TypeConverter& /*converter*/) - : RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()) {} + mlir::TypeConverter& conv): + RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), + converter(conv) {} /// Hook for derived classes to implement combined matching and rewriting. mlir::LogicalResult @@ -555,7 +556,7 @@ struct OpTypeConversion : public mlir::RewritePattern llvm::SmallVector new_types; for (auto type : op->getResultTypes()) { - if (auto new_type = map_plier_type(type)) + if (auto new_type = converter.convertType(type)) { changed = changed || (new_type != type); new_types.push_back(new_type); @@ -578,6 +579,9 @@ struct OpTypeConversion : public mlir::RewritePattern } return mlir::success(changed); } + +private: + mlir::TypeConverter& converter; }; struct PlierToStdPass : From 699d470de2a786c5769b51829a87091b32fd6852 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 21:29:21 +0300 Subject: [PATCH 093/259] rename files --- mlir-compiler/CMakeLists.txt | 4 ++-- mlir-compiler/src/pipelines/plier_to_std.cpp | 2 +- .../{func_signature_conversion.cpp => type_conversion.cpp} | 2 +- .../{func_signature_conversion.hpp => type_conversion.hpp} | 0 4 files changed, 4 insertions(+), 4 deletions(-) rename mlir-compiler/src/rewrites/{func_signature_conversion.cpp => type_conversion.cpp} (98%) rename mlir-compiler/src/rewrites/{func_signature_conversion.hpp => type_conversion.hpp} (100%) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 758357df176..b0818370b45 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,7 +24,7 @@ set(SOURCES_LIST src/pipelines/lower_to_llvm.cpp src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp - src/rewrites/func_signature_conversion.cpp + src/rewrites/type_conversion.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp @@ -39,7 +39,7 @@ set(HEADERS_LIST src/pipelines/lower_to_llvm.hpp src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp - src/rewrites/func_signature_conversion.hpp + src/rewrites/type_conversion.hpp src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 68d3b76abd0..00fcc43fa77 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -11,7 +11,7 @@ #include "plier/dialect.hpp" -#include "rewrites/func_signature_conversion.hpp" +#include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" diff --git a/mlir-compiler/src/rewrites/func_signature_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp similarity index 98% rename from mlir-compiler/src/rewrites/func_signature_conversion.cpp rename to mlir-compiler/src/rewrites/type_conversion.cpp index 6e171972b61..fda4e93589c 100644 --- a/mlir-compiler/src/rewrites/func_signature_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -1,4 +1,4 @@ -#include "rewrites/func_signature_conversion.hpp" +#include "rewrites/type_conversion.hpp" #include diff --git a/mlir-compiler/src/rewrites/func_signature_conversion.hpp b/mlir-compiler/src/rewrites/type_conversion.hpp similarity index 100% rename from mlir-compiler/src/rewrites/func_signature_conversion.hpp rename to mlir-compiler/src/rewrites/type_conversion.hpp From 1f4cb0252b0f0098456b5ef16c92666f80578e6d Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 21:31:36 +0300 Subject: [PATCH 094/259] move OpTypeConversion --- mlir-compiler/src/pipelines/plier_to_std.cpp | 43 ------------------- .../src/rewrites/type_conversion.cpp | 34 +++++++++++++++ .../src/rewrites/type_conversion.hpp | 13 ++++++ 3 files changed, 47 insertions(+), 43 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 00fcc43fa77..960f6a4211c 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -541,49 +541,6 @@ struct ExpandTuples : public mlir::RewritePattern mlir::Dialect* dialect = nullptr; }; -struct OpTypeConversion : public mlir::RewritePattern -{ - OpTypeConversion(mlir::MLIRContext* /*ctx*/, - mlir::TypeConverter& conv): - RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), - converter(conv) {} - - /// Hook for derived classes to implement combined matching and rewriting. - mlir::LogicalResult - matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter &rewriter) const override - { - bool changed = false; - llvm::SmallVector new_types; - for (auto type : op->getResultTypes()) - { - if (auto new_type = converter.convertType(type)) - { - changed = changed || (new_type != type); - new_types.push_back(new_type); - } - else - { - new_types.push_back(type); - } - } - - if (changed) - { - rewriter.updateRootInPlace(op, [&] - { - for (unsigned i = 0; i < static_cast(new_types.size()); ++i) - { - op->getResult(i).setType(new_types[i]); - } - }); - } - return mlir::success(changed); - } - -private: - mlir::TypeConverter& converter; -}; - struct PlierToStdPass : public mlir::PassWrapper> { diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index fda4e93589c..af89bda511f 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -83,3 +83,37 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( }); return mlir::success(); } + +OpTypeConversion::OpTypeConversion(mlir::MLIRContext*, mlir::TypeConverter& conv): + RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), + converter(conv) {} + +mlir::LogicalResult OpTypeConversion::matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter) const +{ + bool changed = false; + llvm::SmallVector new_types; + for (auto type : op->getResultTypes()) + { + if (auto new_type = converter.convertType(type)) + { + changed = changed || (new_type != type); + new_types.push_back(new_type); + } + else + { + new_types.push_back(type); + } + } + + if (changed) + { + rewriter.updateRootInPlace(op, [&] + { + for (unsigned i = 0; i < static_cast(new_types.size()); ++i) + { + op->getResult(i).setType(new_types[i]); + } + }); + } + return mlir::success(changed); +} diff --git a/mlir-compiler/src/rewrites/type_conversion.hpp b/mlir-compiler/src/rewrites/type_conversion.hpp index 1cbce92fc85..d0c13f22016 100644 --- a/mlir-compiler/src/rewrites/type_conversion.hpp +++ b/mlir-compiler/src/rewrites/type_conversion.hpp @@ -20,3 +20,16 @@ struct FuncOpSignatureConversion : public mlir::OpRewritePattern private: mlir::TypeConverter& converter; }; + +struct OpTypeConversion : public mlir::RewritePattern +{ + OpTypeConversion(mlir::MLIRContext* ctx, + mlir::TypeConverter& conv); + + /// Hook for derived classes to implement combined matching and rewriting. + mlir::LogicalResult + matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter &rewriter) const override; + +private: + mlir::TypeConverter& converter; +}; From 61b95e56fd328b28c5620c1a47ac392101bf5d47 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 16 Oct 2020 21:41:35 +0300 Subject: [PATCH 095/259] print module on error --- mlir-compiler/src/compiler.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 1aad5a5d9ee..640cebf0edb 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -54,6 +54,8 @@ class CompilerContext::CompilerContextImpl { if (mlir::failed(pm.run(module))) { + err_stream << "\n"; + module.print(err_stream); err_stream.flush(); report_error(llvm::Twine("MLIR pipeline failed\n") + err); } From 48beca40e96e151699e5d3877cd2d3f7bf74004d Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 13:32:33 +0300 Subject: [PATCH 096/259] remove hacky tuples folding --- mlir-compiler/include/plier/PlierOps.td | 14 ++--- mlir-compiler/src/dialect.cpp | 80 ++++++++++++------------- mlir-compiler/src/lowering.cpp | 15 +---- numba/mlir/tests/test_basic.py | 1 + 4 files changed, 47 insertions(+), 63 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 85bea44153e..ce3b2482c33 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -93,8 +93,7 @@ def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { let arguments = (ins Variadic:$args); - let results = (outs Variadic); - let hasFolder = 1; + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::ValueRange args"> @@ -108,7 +107,6 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { UI32Attr:$index); let results = (outs AnyType); - let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, " @@ -131,7 +129,7 @@ def IternextOp : Plier_Op<"iternext", []> { let arguments = (ins AnyType:$value); - let results = (outs Variadic); + let results = (outs AnyType); let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -140,11 +138,9 @@ def IternextOp : Plier_Op<"iternext", []> { def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { let arguments = (ins - AnyType:$value, - Optional:$second_val); + AnyType:$value); let results = (outs AnyType); - let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> @@ -153,11 +149,9 @@ def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { let arguments = (ins - AnyType:$value, - Optional:$second_val); + AnyType:$value); let results = (outs AnyType); - let hasFolder = 1; let builders = [ OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 62c8ac905c4..ffab9bb8dff 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -145,19 +145,19 @@ void BuildTupleOp::build(OpBuilder &builder, OperationState &state, PyType::getUndefined(state.getContext()), args); } -mlir::LogicalResult BuildTupleOp::fold( - llvm::ArrayRef /*operands*/, - llvm::SmallVectorImpl &results) -{ - auto res_types = getResultTypes(); - auto args = getOperands(); - if (res_types.size() == args.size()) - { - std::copy(args.begin(), args.end(), std::back_inserter(results)); - return mlir::success(); - } - return mlir::failure(); -} +//mlir::LogicalResult BuildTupleOp::fold( +// llvm::ArrayRef /*operands*/, +// llvm::SmallVectorImpl &results) +//{ +// auto res_types = getResultTypes(); +// auto args = getOperands(); +// if (res_types.size() == args.size()) +// { +// std::copy(args.begin(), args.end(), std::back_inserter(results)); +// return mlir::success(); +// } +// return mlir::failure(); +//} void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value, ::mlir::Value index_var, @@ -168,17 +168,17 @@ void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, value, index_var, index); } -mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef /*operands*/) -{ - auto index = this->index(); - auto args = getOperands(); - if ((index + 1) < args.size() && // skip last arg - args[index].getType() == getResult().getType()) - { - return args[index]; - } - return nullptr; -} +//mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef /*operands*/) +//{ +// auto index = this->index(); +// auto args = getOperands(); +// if ((index + 1) < args.size() && // skip last arg +// args[index].getType() == getResult().getType()) +// { +// return args[index]; +// } +// return nullptr; +//} void GetiterOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) @@ -201,14 +201,14 @@ void PairfirstOp::build(OpBuilder &builder, OperationState &state, value); } -mlir::OpFoldResult PairfirstOp::fold(llvm::ArrayRef /*operands*/) -{ - if (getNumOperands() == 2) - { - return getOperand(0); - } - return nullptr; -} +//mlir::OpFoldResult PairfirstOp::fold(llvm::ArrayRef /*operands*/) +//{ +// if (getNumOperands() == 2) +// { +// return getOperand(0); +// } +// return nullptr; +//} void PairsecondOp::build(OpBuilder &builder, OperationState &state, ::mlir::Value value) @@ -217,14 +217,14 @@ void PairsecondOp::build(OpBuilder &builder, OperationState &state, PyType::getUndefined(state.getContext()), value); } -mlir::OpFoldResult PairsecondOp::fold(llvm::ArrayRef /*operands*/) -{ - if (getNumOperands() == 2) - { - return getOperand(1); - } - return nullptr; -} +//mlir::OpFoldResult PairsecondOp::fold(llvm::ArrayRef /*operands*/) +//{ +// if (getNumOperands() == 2) +// { +// return getOperand(1); +// } +// return nullptr; +//} } diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 3b9f4d48011..5fe6bca34dd 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -288,7 +288,7 @@ struct plier_lowerer {"build_tuple", &plier_lowerer::lower_build_tuple}, {"static_getitem", &plier_lowerer::lower_static_getitem}, {"getiter", &plier_lowerer::lower_simple}, - {"iternext", &plier_lowerer::lower_simple_multiresult}, + {"iternext", &plier_lowerer::lower_simple}, {"pair_first", &plier_lowerer::lower_simple}, {"pair_second", &plier_lowerer::lower_simple}, }; @@ -309,15 +309,6 @@ struct plier_lowerer return builder.create(get_current_loc(), value); } - template - mlir::Value lower_simple_multiresult(const py::handle& inst) - { - auto value = loadvar(inst.attr("value")); - auto res = builder.create(get_current_loc(), value); - assert(res.getNumResults() == 1); - return res.getResult(0); - } - mlir::Value lower_cast(const py::handle& inst) { auto value = loadvar(inst.attr("value")); @@ -342,9 +333,7 @@ struct plier_lowerer { args.push_back(loadvar(item)); } - auto res = builder.create(get_current_loc(), args); - assert(res.getNumResults() == 1); - return res.getResult(0); + return builder.create(get_current_loc(), args); } mlir::Value lower_phi(const py::handle& expr) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 330c4a9093f..5f1488cd08b 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -80,6 +80,7 @@ def py_func(a, b): for a, b in itertools.product(_test_values, _test_values): assert_equal(py_func(a, b), jit_func(a, b)) + @unittest.skip def test_tuple(self): def py_func(a, b, c): t = (a,b,c) From bea6ffb46dba4799735a296a458c86105e824869 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 13:57:46 +0300 Subject: [PATCH 097/259] refac --- mlir-compiler/src/pipelines/plier_to_std.cpp | 63 ++++++++++++++++---- 1 file changed, 52 insertions(+), 11 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 960f6a4211c..343fac929ae 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -554,6 +554,26 @@ struct PlierToStdPass : void runOnOperation() override; }; +bool check_for_plier_types(mlir::Type type) +{ + if (type.isa()) + { + return true; + } + if (auto ftype = type.dyn_cast()) + { + return llvm::any_of(ftype.getResults(), &check_for_plier_types) || + llvm::any_of(ftype.getInputs(), &check_for_plier_types); + } + return false; +} + +bool check_op_for_plier_types(mlir::Operation* op) +{ + assert(nullptr != op); + return llvm::any_of(op->getResultTypes(), &check_for_plier_types); +} + void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; @@ -563,18 +583,39 @@ void PlierToStdPass::runOnOperation() }); mlir::OwningRewritePatternList patterns; - patterns.insert(&getContext(), type_converter); - patterns.insert(&getContext()); - - auto apply_conv = [&]() +// patterns.insert(&getContext(), type_converter); +// patterns.insert(&getContext()); + +// auto apply_conv = [&]() +// { +// (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); +// }; + +// apply_conv(); + mlir::ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addDynamicallyLegalOp( + [](mlir::Operation* op)->bool + { + return !check_for_plier_types(mlir::cast(op).getType()); + }); +// target.addDynamicallyLegalDialect( +// [](mlir::Operation* op)->bool +// { +// auto res = !check_op_for_plier_types(op); +// llvm::errs() << "Check op " << op->getName() << " " << res << "\n"; +// return res; +// }); + + mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), type_converter); + + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) { - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - }; - - apply_conv(); + signalPassFailure(); + } } void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) From b3c1a621be8b10f6a00d6383cccae978f39c796a Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 17:08:42 +0300 Subject: [PATCH 098/259] some work --- mlir-compiler/src/pipelines/plier_to_std.cpp | 150 +++++++++++++++---- 1 file changed, 121 insertions(+), 29 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 343fac929ae..e0ea6f6ddf0 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -198,22 +198,69 @@ bool is_float(mlir::Type type) return type.isa(); } -struct ConstOpLowering : public mlir::OpRewritePattern +struct ConstOpLowering : public mlir::OpConversionPattern { - using mlir::OpRewritePattern::OpRewritePattern; - mlir::LogicalResult matchAndRewrite(plier::ConstOp op, - mlir::PatternRewriter& rewriter) const + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + plier::ConstOp op, llvm::ArrayRef /*operands*/, + mlir::ConversionPatternRewriter &rewriter) const { auto value = op.val(); if (!is_supported_type(value.getType())) { return mlir::failure(); } - rewriter.replaceOpWithNewOp(op, value.getType(), value); + rewriter.replaceOpWithNewOp(op, value); return mlir::success(); } }; +struct ReturnOpLowering : public mlir::OpConversionPattern +{ + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::ReturnOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const + { + llvm::errs() << "ReturnOpLowering\n"; + auto func = mlir::cast(op.getParentOp()); + auto res_types = func.getType().getResults(); + assert(res_types.size() == operands.size()); + bool converted = false; + llvm::SmallVector new_vals; + for (auto it : llvm::zip(operands, res_types)) + { + auto src = std::get<0>(it); + auto dst = std::get<1>(it); + src.getType().dump(); + llvm::errs() << "-"; + dst.dump(); + llvm::errs() << "\n"; + if (src.getType() != dst) + { + auto new_op = rewriter.create(op.getLoc(), dst, src); + new_vals.push_back(new_op); + converted = true; + } + else + { + new_vals.push_back(src); + } + } + if (converted) + { + rewriter.create(op.getLoc(), new_vals); + rewriter.eraseOp(op); + llvm::errs() << "ReturnOpLowering 1\n"; + return mlir::success(); + } + llvm::errs() << "ReturnOpLowering 2\n"; + return mlir::failure(); + } +}; + mlir::Type coerce(mlir::Type type0, mlir::Type type1) { // TODO: proper rules @@ -460,16 +507,26 @@ struct CallOpLowering : public mlir::OpRewritePattern } }; -struct CastOpLowering : public mlir::OpRewritePattern +struct CastOpLowering : public mlir::OpConversionPattern { - using mlir::OpRewritePattern::OpRewritePattern; + using mlir::OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(plier::CastOp op, mlir::PatternRewriter& rewriter) const override + mlir::LogicalResult matchAndRewrite( + plier::CastOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const { - auto src_type = op.getOperand().getType(); - auto dst_type = op.getType(); - if (is_supported_type(src_type) && is_supported_type(dst_type)) + assert(1 == operands.size()); + auto converter = getTypeConverter(); + assert(nullptr != converter); + auto src_type = operands[0].getType(); + auto dst_type = converter->convertType(op.getType()); +// auto dst_type = op.getType(); + llvm::errs() << "CastOpLowering "; + src_type.dump(); + llvm::errs() << " "; + dst_type.dump(); + + if (dst_type && is_supported_type(src_type) && is_supported_type(dst_type)) { auto new_op = do_cast(dst_type, op.getOperand(), rewriter); rewriter.replaceOp(op, new_op); @@ -568,19 +625,37 @@ bool check_for_plier_types(mlir::Type type) return false; } -bool check_op_for_plier_types(mlir::Operation* op) -{ - assert(nullptr != op); - return llvm::any_of(op->getResultTypes(), &check_for_plier_types); -} - void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; + type_converter.addConversion([](plier::Type type) { return type; }); type_converter.addConversion([](plier::Type type)->llvm::Optional { return map_plier_type(type); }); + type_converter.addSourceMaterialization( + [](mlir::OpBuilder& builder, plier::Type type, mlir::ValueRange inputs, mlir::Location loc)->mlir::Value + { + llvm::errs() << "SourceMaterialization "; + type.dump(); + llvm::errs() << "\n"; + assert(inputs.size() == 1); + inputs[0].dump(); + llvm::errs() << "\n"; + return builder.create(loc, type, inputs[0]); + }); + type_converter.addTargetMaterialization( + [](mlir::OpBuilder& builder, plier::Type type, mlir::ValueRange inputs, mlir::Location loc)->mlir::Value + { + llvm::errs() << "TargetMaterialization\n"; + type.dump(); + llvm::errs() << "\n"; + assert(inputs.size() == 1); + inputs[0].dump(); + llvm::errs() << "\n"; +// return builder.create(loc, type, inputs[0]); + return inputs[0]; + }); mlir::OwningRewritePatternList patterns; // patterns.insert(); + + auto& ctx = getContext(); + mlir::ConversionTarget target(ctx); target.addDynamicallyLegalOp( + [](mlir::FuncOp op)->bool + { + return !check_for_plier_types(op.getType()); + }); + target.addDynamicallyLegalOp( + [&](plier::CastOp op)->bool + { + auto src_type = op.getOperand().getType(); + auto dst_type = type_converter.convertType(op.getType()); + llvm::errs() << "check cast "; + src_type.dump(); + llvm::errs() << " "; + dst_type.dump(); + llvm::errs() << "\n"; + return !is_supported_type(src_type) || !is_supported_type(dst_type); + }); +// target.addLegalDialect(); + target.addDynamicallyLegalDialect( [](mlir::Operation* op)->bool { - return !check_for_plier_types(mlir::cast(op).getType()); + auto res = !llvm::any_of(op->getOperandTypes(), &check_for_plier_types); + llvm::errs() << "Check op " << op->getName() << " " << res << "\n"; + return res; }); -// target.addDynamicallyLegalDialect( -// [](mlir::Operation* op)->bool -// { -// auto res = !check_op_for_plier_types(op); -// llvm::errs() << "Check op " << op->getName() << " " << res << "\n"; -// return res; -// }); - mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(), type_converter); + mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); + patterns.insert + (type_converter, &ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) { signalPassFailure(); + return; } } From a130bc55e2d3b7d28d3689bc256818eca25658c9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 17:15:43 +0300 Subject: [PATCH 099/259] cleanup --- mlir-compiler/src/pipelines/plier_to_std.cpp | 44 ++------------------ 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index e0ea6f6ddf0..d433e94cbce 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -520,12 +520,6 @@ struct CastOpLowering : public mlir::OpConversionPattern assert(nullptr != converter); auto src_type = operands[0].getType(); auto dst_type = converter->convertType(op.getType()); -// auto dst_type = op.getType(); - llvm::errs() << "CastOpLowering "; - src_type.dump(); - llvm::errs() << " "; - dst_type.dump(); - if (dst_type && is_supported_type(src_type) && is_supported_type(dst_type)) { auto new_op = do_cast(dst_type, op.getOperand(), rewriter); @@ -636,40 +630,18 @@ void PlierToStdPass::runOnOperation() type_converter.addSourceMaterialization( [](mlir::OpBuilder& builder, plier::Type type, mlir::ValueRange inputs, mlir::Location loc)->mlir::Value { - llvm::errs() << "SourceMaterialization "; - type.dump(); - llvm::errs() << "\n"; assert(inputs.size() == 1); - inputs[0].dump(); - llvm::errs() << "\n"; return builder.create(loc, type, inputs[0]); }); type_converter.addTargetMaterialization( - [](mlir::OpBuilder& builder, plier::Type type, mlir::ValueRange inputs, mlir::Location loc)->mlir::Value + [](mlir::OpBuilder& /*builder*/, plier::Type /*type*/, mlir::ValueRange inputs, mlir::Location /*loc*/)->mlir::Value { - llvm::errs() << "TargetMaterialization\n"; - type.dump(); - llvm::errs() << "\n"; assert(inputs.size() == 1); - inputs[0].dump(); - llvm::errs() << "\n"; -// return builder.create(loc, type, inputs[0]); + // TODO return inputs[0]; }); mlir::OwningRewritePatternList patterns; -// patterns.insert(&getContext(), type_converter); -// patterns.insert(&getContext()); - -// auto apply_conv = [&]() -// { -// (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); -// }; - -// apply_conv(); auto& ctx = getContext(); mlir::ConversionTarget target(ctx); @@ -683,20 +655,12 @@ void PlierToStdPass::runOnOperation() { auto src_type = op.getOperand().getType(); auto dst_type = type_converter.convertType(op.getType()); - llvm::errs() << "check cast "; - src_type.dump(); - llvm::errs() << " "; - dst_type.dump(); - llvm::errs() << "\n"; - return !is_supported_type(src_type) || !is_supported_type(dst_type); + return !dst_type || !is_supported_type(src_type) || !is_supported_type(dst_type); }); -// target.addLegalDialect(); target.addDynamicallyLegalDialect( [](mlir::Operation* op)->bool { - auto res = !llvm::any_of(op->getOperandTypes(), &check_for_plier_types); - llvm::errs() << "Check op " << op->getName() << " " << res << "\n"; - return res; + return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types); }); mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); From c3119e4431d810c979d4b4b663d69b8ecb3e48b3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 17:57:35 +0300 Subject: [PATCH 100/259] refac binop --- mlir-compiler/src/pipelines/plier_to_std.cpp | 42 ++++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index d433e94cbce..86ee3a8caa2 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -203,9 +203,11 @@ struct ConstOpLowering : public mlir::OpConversionPattern using mlir::OpConversionPattern::OpConversionPattern; mlir::LogicalResult matchAndRewrite( - plier::ConstOp op, llvm::ArrayRef /*operands*/, - mlir::ConversionPatternRewriter &rewriter) const + plier::ConstOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + (void)operands; + assert(operands.empty()); auto value = op.val(); if (!is_supported_type(value.getType())) { @@ -222,9 +224,8 @@ struct ReturnOpLowering : public mlir::OpConversionPattern mlir::LogicalResult matchAndRewrite( mlir::ReturnOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const + mlir::ConversionPatternRewriter &rewriter) const override { - llvm::errs() << "ReturnOpLowering\n"; auto func = mlir::cast(op.getParentOp()); auto res_types = func.getType().getResults(); assert(res_types.size() == operands.size()); @@ -234,10 +235,6 @@ struct ReturnOpLowering : public mlir::OpConversionPattern { auto src = std::get<0>(it); auto dst = std::get<1>(it); - src.getType().dump(); - llvm::errs() << "-"; - dst.dump(); - llvm::errs() << "\n"; if (src.getType() != dst) { auto new_op = rewriter.create(op.getLoc(), dst, src); @@ -363,32 +360,34 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& llvm_unreachable("Unhandled cast"); } -struct BinOpLowering : public mlir::OpRewritePattern +struct BinOpLowering : public mlir::OpConversionPattern { - using mlir::OpRewritePattern::OpRewritePattern; + using mlir::OpConversionPattern::OpConversionPattern; - mlir::LogicalResult matchAndRewrite(plier::BinOp op, - mlir::PatternRewriter& rewriter) const + mlir::LogicalResult matchAndRewrite( + plier::BinOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { - assert(op.getNumOperands() == 2); - auto type0 = op.getOperand(0).getType(); - auto type1 = op.getOperand(1).getType(); + assert(operands.size() == 2); + auto type0 = operands[0].getType(); + auto type1 = operands[1].getType(); if (!is_supported_type(type0) || !is_supported_type(type1)) { return mlir::failure(); } mlir::Type final_type; - std::array operands; + std::array converted_operands; if (type0 != type1) { final_type = coerce(type0, type1); - operands = {do_cast(final_type, op.getOperand(0), rewriter), - do_cast(final_type, op.getOperand(1), rewriter)}; + converted_operands = { + do_cast(final_type, operands[0], rewriter), + do_cast(final_type, operands[1], rewriter)}; } else { final_type = type0; - operands = {op.getOperand(0), op.getOperand(1)}; + converted_operands = {operands[0], operands[1]}; } assert(static_cast(final_type)); @@ -426,7 +425,7 @@ struct BinOpLowering : public mlir::OpRewritePattern { if (h.type == op.op()) { - (h.*mem)(op, rewriter, final_type, operands); + (h.*mem)(op, rewriter, final_type, converted_operands); return mlir::success(); } } @@ -664,7 +663,8 @@ void PlierToStdPass::runOnOperation() }); mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); - patterns.insert + patterns.insert (type_converter, &ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) From e69920d59025d0136024f77dd382675edd730546 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 18:15:51 +0300 Subject: [PATCH 101/259] some work --- mlir-compiler/src/pipelines/plier_to_std.cpp | 58 +++++++++++++++----- 1 file changed, 43 insertions(+), 15 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 86ee3a8caa2..2e269e7d9ba 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -250,10 +250,30 @@ struct ReturnOpLowering : public mlir::OpConversionPattern { rewriter.create(op.getLoc(), new_vals); rewriter.eraseOp(op); - llvm::errs() << "ReturnOpLowering 1\n"; return mlir::success(); } - llvm::errs() << "ReturnOpLowering 2\n"; + return mlir::failure(); + } +}; + +struct SelectOpLowering : public mlir::OpConversionPattern +{ + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::SelectOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override + { + assert(operands.size() == 3); + auto true_val = operands[1]; + auto false_val = operands[2]; + if (true_val.getType() == false_val.getType() && + true_val.getType() != op.getType()) + { + auto cond = operands[0]; + rewriter.replaceOpWithNewOp(op, cond, true_val, false_val); + return mlir::success(); + } return mlir::failure(); } }; @@ -445,13 +465,13 @@ struct BinOpLowering : public mlir::OpConversionPattern } }; -mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { - if (op.getNumOperands() != 2) + if (operands.size() != 2) { return mlir::failure(); } - auto val = op.getOperand(1); + auto val = operands[1]; bool success = false; auto replace_op = [&](mlir::Value val) { @@ -469,39 +489,46 @@ mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& r return mlir::success(success); } -using call_lowerer_func_t = mlir::LogicalResult(*)(plier::PyCallOp, mlir::PatternRewriter&); +using call_lowerer_func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef operands, mlir::PatternRewriter&); const constexpr std::pair builtin_calls[] = { {"", &lower_bool_cast}, }; -struct CallOpLowering : public mlir::OpRewritePattern +struct CallOpLowering : public mlir::OpConversionPattern { - using mlir::OpRewritePattern::OpRewritePattern; + using mlir::OpConversionPattern::OpConversionPattern; - mlir::LogicalResult - matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const override + mlir::LogicalResult matchAndRewrite( + plier::PyCallOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { - if (op.getNumOperands() == 0) + llvm::errs() << "CallOpLowering\n"; + if (operands.empty()) { + llvm::errs() << "CallOpLowering 1\n"; return mlir::failure(); } - auto func_type = op.getOperand(0).getType(); + auto func_type = operands[0].getType(); if (!func_type.isa()) { + llvm::errs() << "CallOpLowering 2\n"; return mlir::failure(); } auto name = func_type.cast().getName(); if (!name.consume_front("Function(") || !name.consume_back(")")) { + llvm::errs() << "CallOpLowering 3\n"; return mlir::failure(); } for (auto& c : builtin_calls) { if (c.first == name) { - return c.second(op, rewriter); + llvm::errs() << "CallOpLowering 4\n"; + return c.second(op, operands, rewriter); } } + llvm::errs() << "CallOpLowering 5\n"; return mlir::failure(); } }; @@ -659,12 +686,13 @@ void PlierToStdPass::runOnOperation() target.addDynamicallyLegalDialect( [](mlir::Operation* op)->bool { - return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types); + return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types) || + !llvm::any_of(op->getResultTypes(), &check_for_plier_types); }); mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); patterns.insert + BinOpLowering, CallOpLowering, SelectOpLowering> (type_converter, &ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) From 73fb697892d385bcde009107f9e737602498878f Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 18:40:54 +0300 Subject: [PATCH 102/259] refac --- mlir-compiler/src/pipelines/plier_to_std.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 2e269e7d9ba..163687ebc7f 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -691,9 +691,14 @@ void PlierToStdPass::runOnOperation() }); mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); - patterns.insert - (type_converter, &ctx); + patterns.insert< + ConstOpLowering, + CastOpLowering, + ReturnOpLowering, + BinOpLowering, + CallOpLowering, + SelectOpLowering + >(type_converter, &ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) { From 767d46824e844dfb67919e4b7011b5c921d4c522 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 19:15:12 +0300 Subject: [PATCH 103/259] call op --- mlir-compiler/src/pipelines/plier_to_std.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 163687ebc7f..6205ce4a3c7 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -132,6 +132,17 @@ mlir::Type map_tuple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) return mlir::TupleType::get(types, &ctx); } +mlir::Type map_func_type(mlir::MLIRContext& ctx, llvm::StringRef& name) +{ + if (name.consume_front("Function(") && + name.consume_front("") && // TODO unhardcode; + name.consume_front(")")) + { + return mlir::FunctionType::get({}, {}, &ctx); + } + return nullptr; +} + mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) { using func_t = mlir::Type(*)(mlir::MLIRContext& ctx, llvm::StringRef& name); @@ -143,6 +154,7 @@ mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) &map_pair_type, &map_unituple_type, &map_tuple_type, + &map_func_type, }; for (auto h : handlers) { @@ -683,6 +695,7 @@ void PlierToStdPass::runOnOperation() auto dst_type = type_converter.convertType(op.getType()); return !dst_type || !is_supported_type(src_type) || !is_supported_type(dst_type); }); + target.addLegalOp(); target.addDynamicallyLegalDialect( [](mlir::Operation* op)->bool { From ea264e709a20aba68a23e841e8c1217b72dd5783 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 21:01:29 +0300 Subject: [PATCH 104/259] conversion --- mlir-compiler/src/pipelines/plier_to_std.cpp | 121 ++++++++++++++++--- 1 file changed, 103 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 6205ce4a3c7..108fc1af332 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -290,6 +290,56 @@ struct SelectOpLowering : public mlir::OpConversionPattern } }; +struct CondBrOpLowering : public mlir::OpConversionPattern +{ + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult matchAndRewrite( + mlir::CondBranchOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override + { + assert(!operands.empty()); + auto cond = operands.front(); + operands = operands.drop_front(); + bool changed = false; + + auto process_operand = [&](mlir::Block& block, auto& ret) + { + for (auto arg : block.getArguments()) + { + assert(!operands.empty()); + auto val = operands.front(); + operands = operands.drop_front(); + auto src_type = val.getType(); + auto dst_type = arg.getType(); + if (src_type != dst_type) + { + ret.push_back(rewriter.create(op.getLoc(), dst_type, val)); + changed = true; + } + else + { + ret.push_back(val); + } + } + }; + + llvm::SmallVector true_vals; + llvm::SmallVector false_vals; + auto true_dest = op.getTrueDest(); + auto false_dest = op.getFalseDest(); + process_operand(*true_dest, true_vals); + process_operand(*false_dest, false_vals); + if (changed) + { + rewriter.create(op.getLoc(), cond, true_dest, true_vals, false_dest, false_vals); + rewriter.eraseOp(op); + return mlir::success(); + } + return mlir::failure(); + } +}; + mlir::Type coerce(mlir::Type type0, mlir::Type type1) { // TODO: proper rules @@ -514,33 +564,27 @@ struct CallOpLowering : public mlir::OpConversionPattern plier::PyCallOp op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const override { - llvm::errs() << "CallOpLowering\n"; if (operands.empty()) { - llvm::errs() << "CallOpLowering 1\n"; return mlir::failure(); } auto func_type = operands[0].getType(); if (!func_type.isa()) { - llvm::errs() << "CallOpLowering 2\n"; return mlir::failure(); } auto name = func_type.cast().getName(); if (!name.consume_front("Function(") || !name.consume_back(")")) { - llvm::errs() << "CallOpLowering 3\n"; return mlir::failure(); } for (auto& c : builtin_calls) { if (c.first == name) { - llvm::errs() << "CallOpLowering 4\n"; return c.second(op, operands, rewriter); } } - llvm::errs() << "CallOpLowering 5\n"; return mlir::failure(); } }; @@ -553,6 +597,7 @@ struct CastOpLowering : public mlir::OpConversionPattern plier::CastOp op, llvm::ArrayRef operands, mlir::ConversionPatternRewriter &rewriter) const { + assert(1 == operands.size()); auto converter = getTypeConverter(); assert(nullptr != converter); @@ -560,6 +605,11 @@ struct CastOpLowering : public mlir::OpConversionPattern auto dst_type = converter->convertType(op.getType()); if (dst_type && is_supported_type(src_type) && is_supported_type(dst_type)) { + if (src_type == dst_type) + { + rewriter.replaceOp(op, operands[0]); + return mlir::success(); + } auto new_op = do_cast(dst_type, op.getOperand(), rewriter); rewriter.replaceOp(op, new_op); return mlir::success(); @@ -657,27 +707,42 @@ bool check_for_plier_types(mlir::Type type) return false; } +bool check_op_for_plier_types(mlir::Value val) +{ + return check_for_plier_types(val.getType()); +} + +template +mlir::Value cast_materializer( + mlir::OpBuilder& builder, T type, mlir::ValueRange inputs, + mlir::Location loc) +{ + assert(inputs.size() == 1); + if (type == inputs[0].getType()) + { + return inputs[0]; + } + return builder.create(loc, type, inputs[0]); +} + void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; - type_converter.addConversion([](plier::Type type) { return type; }); +// type_converter.addConversion([](plier::Type type) { return type; }); type_converter.addConversion([](plier::Type type)->llvm::Optional { return map_plier_type(type); }); - type_converter.addSourceMaterialization( - [](mlir::OpBuilder& builder, plier::Type type, mlir::ValueRange inputs, mlir::Location loc)->mlir::Value - { - assert(inputs.size() == 1); - return builder.create(loc, type, inputs[0]); - }); + type_converter.addSourceMaterialization(&cast_materializer); type_converter.addTargetMaterialization( - [](mlir::OpBuilder& /*builder*/, plier::Type /*type*/, mlir::ValueRange inputs, mlir::Location /*loc*/)->mlir::Value + [](mlir::OpBuilder& /*builder*/, mlir::Type /*type*/, mlir::ValueRange inputs, mlir::Location /*loc*/)->mlir::Value { assert(inputs.size() == 1); // TODO return inputs[0]; }); +// type_converter.addArgumentMaterialization(&cast_materializer); +// type_converter.addArgumentMaterialization(&cast_materializer); mlir::OwningRewritePatternList patterns; @@ -699,18 +764,26 @@ void PlierToStdPass::runOnOperation() target.addDynamicallyLegalDialect( [](mlir::Operation* op)->bool { - return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types) || + return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types) && !llvm::any_of(op->getResultTypes(), &check_for_plier_types); }); + target.addDynamicallyLegalOp( + [](mlir::CondBranchOp op) + { + return !check_op_for_plier_types(op.getCondition()) && + !llvm::any_of(op.getTrueOperands(), &check_op_for_plier_types) && + !llvm::any_of(op.getFalseOperands(), &check_op_for_plier_types); + }); mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); patterns.insert< ConstOpLowering, - CastOpLowering, ReturnOpLowering, + SelectOpLowering, + CondBrOpLowering, + CastOpLowering, BinOpLowering, - CallOpLowering, - SelectOpLowering + CallOpLowering >(type_converter, &ctx); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) @@ -718,12 +791,24 @@ void PlierToStdPass::runOnOperation() signalPassFailure(); return; } + + patterns.clear(); + patterns.insert< + CastOpLowering + >(type_converter, &ctx); + // final casts cleanup, investigate how to get rid of that + if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) + { + signalPassFailure(); + return; + } } void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(std::make_unique()); + pm.addPass(mlir::createCanonicalizerPass()); } } From 715529e7d66244961a7bbfc8a8b7f119c2126c67 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sun, 18 Oct 2020 21:01:42 +0300 Subject: [PATCH 105/259] update --- mlir-compiler/test.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index d761cdbbdd8..44419e78f33 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -1,4 +1,5 @@ import numba +import numpy as np _tests_total = 0 _tests_passes = 0 @@ -7,6 +8,9 @@ def ret(a): return a +def const(): + return 42 + def sum1(a): return a + 42 @@ -40,12 +44,18 @@ def tuple(a,b,c): t = (a,b,c) return t[0] + t[1] + t[2] -def loop(n): +def arr_loop(): res = 0 - for i in range(n): + arr = [1,2,3] + for i in arr: res = res + i return res +def range_loop(n): + res = 0 + for i in range(n): + res = res + i + return res def test(func, params): global _tests_total @@ -67,8 +77,10 @@ def test(func, params): print('FAILED') _failed_tests.append(test_name) +print('=========================================================') test(ret, (7,)) +test(const, ()) test(sum1, (5,)) test(sum2, (3,4)) test(cond, (5,6)) @@ -79,7 +91,9 @@ def test(func, params): test(call, (1,2,3)) test(tuple, (1,2,3)) test(tuple, (1,2.0,3)) -test(loop, (8,)) +test(arr_loop, ()) +test(range_loop, (8,)) +test(sum2, (np.asarray([1,2,3]),np.asarray([4,5,6]))) print(f'Tests passed: {_tests_passes}/{_tests_total}') if (len(_failed_tests) != 0): From f21faeb5a015194f7a4fb4e2bed6db69d7c20cd5 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 19 Oct 2020 14:46:30 +0300 Subject: [PATCH 106/259] mlir dialect conversion is garbage --- mlir-compiler/src/pipelines/plier_to_std.cpp | 154 ++++++------------ .../src/rewrites/type_conversion.cpp | 4 +- .../src/rewrites/type_conversion.hpp | 4 +- 3 files changed, 58 insertions(+), 104 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 108fc1af332..2abf7785bcf 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -210,16 +210,15 @@ bool is_float(mlir::Type type) return type.isa(); } -struct ConstOpLowering : public mlir::OpConversionPattern +struct ConstOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + ConstOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - plier::ConstOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + plier::ConstOp op, mlir::PatternRewriter &rewriter) const override { - (void)operands; - assert(operands.empty()); auto value = op.val(); if (!is_supported_type(value.getType())) { @@ -230,14 +229,16 @@ struct ConstOpLowering : public mlir::OpConversionPattern } }; -struct ReturnOpLowering : public mlir::OpConversionPattern +struct ReturnOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + ReturnOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - mlir::ReturnOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + mlir::ReturnOp op, mlir::PatternRewriter &rewriter) const override { + auto operands = op.getOperands(); auto func = mlir::cast(op.getParentOp()); auto res_types = func.getType().getResults(); assert(res_types.size() == operands.size()); @@ -268,14 +269,16 @@ struct ReturnOpLowering : public mlir::OpConversionPattern } }; -struct SelectOpLowering : public mlir::OpConversionPattern +struct SelectOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + SelectOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - mlir::SelectOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override { + auto operands = op.getOperands(); assert(operands.size() == 3); auto true_val = operands[1]; auto false_val = operands[2]; @@ -290,14 +293,16 @@ struct SelectOpLowering : public mlir::OpConversionPattern } }; -struct CondBrOpLowering : public mlir::OpConversionPattern +struct CondBrOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + CondBrOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - mlir::CondBranchOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + mlir::CondBranchOp op, mlir::PatternRewriter &rewriter) const override { + auto operands = op.getOperands(); assert(!operands.empty()); auto cond = operands.front(); operands = operands.drop_front(); @@ -442,14 +447,16 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& llvm_unreachable("Unhandled cast"); } -struct BinOpLowering : public mlir::OpConversionPattern +struct BinOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + BinOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - plier::BinOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + plier::BinOp op, mlir::PatternRewriter &rewriter) const override { + auto operands = op.getOperands(); assert(operands.size() == 2); auto type0 = operands[0].getType(); auto type1 = operands[1].getType(); @@ -527,8 +534,9 @@ struct BinOpLowering : public mlir::OpConversionPattern } }; -mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& rewriter) { + auto operands = op.getOperands(); if (operands.size() != 2) { return mlir::failure(); @@ -551,19 +559,21 @@ mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter&); +using call_lowerer_func_t = mlir::LogicalResult(*)(plier::PyCallOp, mlir::PatternRewriter&); const constexpr std::pair builtin_calls[] = { {"", &lower_bool_cast}, }; -struct CallOpLowering : public mlir::OpConversionPattern +struct CallOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + CallOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( - plier::PyCallOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const override + plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override { + auto operands = op.getOperands(); if (operands.empty()) { return mlir::failure(); @@ -582,32 +592,29 @@ struct CallOpLowering : public mlir::OpConversionPattern { if (c.first == name) { - return c.second(op, operands, rewriter); + return c.second(op, rewriter); } } return mlir::failure(); } }; -struct CastOpLowering : public mlir::OpConversionPattern +struct CastOpLowering : public mlir::OpRewritePattern { - using mlir::OpConversionPattern::OpConversionPattern; + CastOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context): + OpRewritePattern(context), converter(typeConverter) {} mlir::LogicalResult matchAndRewrite( - plier::CastOp op, llvm::ArrayRef operands, - mlir::ConversionPatternRewriter &rewriter) const + plier::CastOp op, mlir::PatternRewriter &rewriter) const override { - - assert(1 == operands.size()); - auto converter = getTypeConverter(); - assert(nullptr != converter); - auto src_type = operands[0].getType(); - auto dst_type = converter->convertType(op.getType()); + auto src_type = op.getOperand().getType(); + auto dst_type = converter.convertType(op.getType()); if (dst_type && is_supported_type(src_type) && is_supported_type(dst_type)) { if (src_type == dst_type) { - rewriter.replaceOp(op, operands[0]); + rewriter.replaceOp(op, op.getOperand()); return mlir::success(); } auto new_op = do_cast(dst_type, op.getOperand(), rewriter); @@ -616,6 +623,9 @@ struct CastOpLowering : public mlir::OpConversionPattern } return mlir::failure(); } + +private: + mlir::TypeConverter& converter; }; mlir::Operation* change_op_ret_type(mlir::Operation* op, @@ -728,87 +738,31 @@ mlir::Value cast_materializer( void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; -// type_converter.addConversion([](plier::Type type) { return type; }); type_converter.addConversion([](plier::Type type)->llvm::Optional { return map_plier_type(type); }); - type_converter.addSourceMaterialization(&cast_materializer); - type_converter.addTargetMaterialization( - [](mlir::OpBuilder& /*builder*/, mlir::Type /*type*/, mlir::ValueRange inputs, mlir::Location /*loc*/)->mlir::Value - { - assert(inputs.size() == 1); - // TODO - return inputs[0]; - }); -// type_converter.addArgumentMaterialization(&cast_materializer); -// type_converter.addArgumentMaterialization(&cast_materializer); mlir::OwningRewritePatternList patterns; - auto& ctx = getContext(); - mlir::ConversionTarget target(ctx); - target.addDynamicallyLegalOp( - [](mlir::FuncOp op)->bool - { - return !check_for_plier_types(op.getType()); - }); - target.addDynamicallyLegalOp( - [&](plier::CastOp op)->bool - { - auto src_type = op.getOperand().getType(); - auto dst_type = type_converter.convertType(op.getType()); - return !dst_type || !is_supported_type(src_type) || !is_supported_type(dst_type); - }); - target.addLegalOp(); - target.addDynamicallyLegalDialect( - [](mlir::Operation* op)->bool - { - return !llvm::any_of(op->getOperandTypes(), &check_for_plier_types) && - !llvm::any_of(op->getResultTypes(), &check_for_plier_types); - }); - target.addDynamicallyLegalOp( - [](mlir::CondBranchOp op) - { - return !check_op_for_plier_types(op.getCondition()) && - !llvm::any_of(op.getTrueOperands(), &check_op_for_plier_types) && - !llvm::any_of(op.getFalseOperands(), &check_op_for_plier_types); - }); - - mlir::populateFuncOpTypeConversionPattern(patterns, &ctx, type_converter); patterns.insert< - ConstOpLowering, + FuncOpSignatureConversion, ReturnOpLowering, + ConstOpLowering, SelectOpLowering, CondBrOpLowering, CastOpLowering, BinOpLowering, CallOpLowering - >(type_converter, &ctx); - - if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) - { - signalPassFailure(); - return; - } + >(type_converter, &getContext()); - patterns.clear(); - patterns.insert< - CastOpLowering - >(type_converter, &ctx); - // final casts cleanup, investigate how to get rid of that - if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, patterns))) - { - signalPassFailure(); - return; - } + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createCanonicalizerPass()); pm.addPass(std::make_unique()); - pm.addPass(mlir::createCanonicalizerPass()); } } diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index af89bda511f..775f1552640 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -55,8 +55,8 @@ mlir::LogicalResult convertRegionTypes( } } -FuncOpSignatureConversion::FuncOpSignatureConversion( - mlir::MLIRContext* ctx, mlir::TypeConverter& conv) +FuncOpSignatureConversion::FuncOpSignatureConversion(mlir::TypeConverter& conv, + mlir::MLIRContext* ctx) : OpRewritePattern(ctx), converter(conv) {} mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( diff --git a/mlir-compiler/src/rewrites/type_conversion.hpp b/mlir-compiler/src/rewrites/type_conversion.hpp index d0c13f22016..6af35454ead 100644 --- a/mlir-compiler/src/rewrites/type_conversion.hpp +++ b/mlir-compiler/src/rewrites/type_conversion.hpp @@ -10,8 +10,8 @@ class TypeConverter; struct FuncOpSignatureConversion : public mlir::OpRewritePattern { - FuncOpSignatureConversion(mlir::MLIRContext* ctx, - mlir::TypeConverter& conv); + FuncOpSignatureConversion(mlir::TypeConverter& conv, + mlir::MLIRContext* ctx); /// Hook for derived classes to implement combined matching and rewriting. mlir::LogicalResult From 0af9f76e44d530dbb23109b7ce3383cf35ccb7fc Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 21 Oct 2020 18:17:49 +0300 Subject: [PATCH 107/259] plier getattr --- mlir-compiler/include/plier/PlierOps.td | 12 ++++++++++++ mlir-compiler/src/dialect.cpp | 8 +++++++- mlir-compiler/src/lowering.cpp | 9 +++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index ce3b2482c33..7e10d021b67 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -163,4 +163,16 @@ def DelOp : Plier_Op<"del", []> { AnyType:$value); } +def GetattrOp : Plier_Op<"getattr", []> { + let arguments = (ins + AnyType:$value, + StrAttr:$name); + + let results = (outs AnyType); + + let builders = [ + OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, ::mlir::StringRef name"> + ]; +} + #endif // PLIER_OPS diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index ffab9bb8dff..945e7dc1a46 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -211,7 +211,7 @@ void PairfirstOp::build(OpBuilder &builder, OperationState &state, //} void PairsecondOp::build(OpBuilder &builder, OperationState &state, - ::mlir::Value value) + ::mlir::Value value) { PairsecondOp::build(builder, state, PyType::getUndefined(state.getContext()), value); @@ -226,6 +226,12 @@ void PairsecondOp::build(OpBuilder &builder, OperationState &state, // return nullptr; //} +void GetattrOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value, mlir::StringRef name) { + GetattrOp::build(builder, state, PyType::getUndefined(state.getContext()), + value, name); +} + } #define GET_OP_CLASSES diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 5fe6bca34dd..0e9efbaef53 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -291,6 +291,7 @@ struct plier_lowerer {"iternext", &plier_lowerer::lower_simple}, {"pair_first", &plier_lowerer::lower_simple}, {"pair_second", &plier_lowerer::lower_simple}, + {"getattr", &plier_lowerer::lower_getattr}, }; for (auto& h : handlers) { @@ -409,6 +410,14 @@ struct plier_lowerer report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); } + mlir::Value lower_getattr(const py::handle& inst) + { + auto val = inst.attr("value"); + auto value = loadvar(val); + auto name = val.attr("name").cast(); + return builder.create(get_current_loc(), value, name); + } + void storevar(mlir::Value val, const py::handle& inst) { vars_map[inst.attr("name").cast()] = val; From 769f34a942aed8d474e2de5bc867b9ae32c7f16c Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 21 Oct 2020 18:22:21 +0300 Subject: [PATCH 108/259] some cleanup --- mlir-compiler/include/plier/PlierOps.td | 24 ++++++++++---------- mlir-compiler/include/plier/dialect.hpp | 6 ----- mlir-compiler/src/dialect.cpp | 16 ++++++------- mlir-compiler/src/pipelines/plier_to_std.cpp | 4 ++-- 4 files changed, 22 insertions(+), 28 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 7e10d021b67..e5efeee090b 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -26,7 +26,7 @@ def ArgOp : Plier_Op<"arg", [NoSideEffect]> { let hasFolder = 1; let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, unsigned index, StringRef name"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, unsigned index, ::mlir::StringRef name"> ]; } @@ -37,7 +37,7 @@ def ConstOp : Plier_Op<"const", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Attribute val"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Attribute val"> ]; } @@ -48,7 +48,7 @@ def GlobalOp : Plier_Op<"global", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, StringRef name"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::StringRef name"> ]; } @@ -61,7 +61,7 @@ def BinOp : Plier_Op<"binop", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value rhs, ::mlir::Value lhs, StringRef op"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value rhs, ::mlir::Value lhs, ::mlir::StringRef op"> ]; } @@ -83,7 +83,7 @@ def PyCallOp : Plier_Op<"call", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value func, " + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value func, " "::mlir::ValueRange args, " "::mlir::ArrayRef> kwargs"> ]; @@ -96,7 +96,7 @@ def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::ValueRange args"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::ValueRange args"> ]; } @@ -109,7 +109,7 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, " + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, " "::mlir::Value index_var, unsigned index"> ]; } @@ -121,7 +121,7 @@ def GetiterOp : Plier_Op<"getiter", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> ]; } @@ -132,7 +132,7 @@ def IternextOp : Plier_Op<"iternext", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> ]; } @@ -143,7 +143,7 @@ def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> ]; } @@ -154,7 +154,7 @@ def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> ]; } @@ -171,7 +171,7 @@ def GetattrOp : Plier_Op<"getattr", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"OpBuilder &b, OperationState &state, ::mlir::Value value, ::mlir::StringRef name"> + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, ::mlir::StringRef name"> ]; } diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 451652df992..f5081c48797 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -7,12 +7,6 @@ #include #include "plier/PlierOpsEnums.h.inc" - -namespace plier -{ -using namespace mlir; // TODO: remove -} - #include "plier/PlierOpsDialect.h.inc" #define GET_OP_CLASSES #include "plier/PlierOps.h.inc" diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 945e7dc1a46..b0387b060f3 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -60,7 +60,7 @@ PyType PyType::get(mlir::MLIRContext* context, llvm::StringRef name) return Base::get(context, name); } -PyType PyType::getUndefined(MLIRContext* context) +PyType PyType::getUndefined(mlir::MLIRContext* context) { return Base::get(context, ""); } @@ -119,7 +119,7 @@ mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) return nullptr; } -void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func, +void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value func, mlir::ValueRange args, mlir::ArrayRef> kwargs) { auto ctx = builder.getContext(); @@ -138,7 +138,7 @@ void PyCallOp::build(OpBuilder &builder, OperationState &state, mlir::Value func func, all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); } -void BuildTupleOp::build(OpBuilder &builder, OperationState &state, +void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::ValueRange args) { BuildTupleOp::build(builder, state, @@ -159,7 +159,7 @@ void BuildTupleOp::build(OpBuilder &builder, OperationState &state, // return mlir::failure(); //} -void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, +void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value, ::mlir::Value index_var, unsigned int index) { @@ -180,21 +180,21 @@ void StaticGetItemOp::build(OpBuilder &builder, OperationState &state, // return nullptr; //} -void GetiterOp::build(OpBuilder &builder, OperationState &state, +void GetiterOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value) { GetiterOp::build(builder, state, PyType::getUndefined(state.getContext()), value); } -void IternextOp::build(OpBuilder &builder, OperationState &state, +void IternextOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value) { IternextOp::build(builder, state, PyType::getUndefined(state.getContext()), value); } -void PairfirstOp::build(OpBuilder &builder, OperationState &state, +void PairfirstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value) { PairfirstOp::build(builder, state, PyType::getUndefined(state.getContext()), @@ -210,7 +210,7 @@ void PairfirstOp::build(OpBuilder &builder, OperationState &state, // return nullptr; //} -void PairsecondOp::build(OpBuilder &builder, OperationState &state, +void PairsecondOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value) { PairsecondOp::build(builder, state, diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 2abf7785bcf..7947f1f8621 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -648,7 +648,7 @@ struct ExpandTuples : public mlir::RewritePattern } mlir::LogicalResult - matchAndRewrite(plier::Operation* op, mlir::PatternRewriter& rewriter) const override + matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter) const override { if (op->getResultTypes().size() != 1 || !op->getResultTypes()[0].isa() || @@ -738,7 +738,7 @@ mlir::Value cast_materializer( void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; - type_converter.addConversion([](plier::Type type)->llvm::Optional + type_converter.addConversion([](mlir::Type type)->llvm::Optional { return map_plier_type(type); }); From ee97e1b75d42563c3abd768fc4d61e4b2b30e4ae Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 21 Oct 2020 20:59:14 +0300 Subject: [PATCH 109/259] some type conversion --- .../src/pipelines/plier_to_linalg.cpp | 90 +++++++++++++++---- mlir-compiler/src/pipelines/plier_to_std.cpp | 17 +++- mlir-compiler/src/pipelines/plier_to_std.hpp | 7 ++ 3 files changed, 94 insertions(+), 20 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index f1e11d73fe3..7a555484378 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -10,12 +10,71 @@ #include "plier/dialect.hpp" #include "pipelines/plier_to_std.hpp" +#include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" +#include + namespace { +bool parse_layout(llvm::StringRef& name) +{ + return name.consume_back("C"); // TODO +} + +template +bool consume_int_back(llvm::StringRef& name, T& result) +{ + unsigned len = 0; + auto tmp_name = name; + while (!tmp_name.empty() && std::isdigit(tmp_name.back())) + { + ++len; + tmp_name = tmp_name.drop_back(); + } + tmp_name = name.substr(name.size() - len); + if (!tmp_name.consumeInteger(10, result)) + { + name = name.substr(0, name.size() - len); + return true; + } + return false; +} + +mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, + llvm::StringRef& name) +{ + unsigned num_dims = 0; + if (name.consume_front("array(") && + name.consume_back(")") && + parse_layout(name) && + name.consume_back(", ") && + name.consume_back("d") && + consume_int_back(name, num_dims) && + name.consume_back(", ") && + !name.empty()) + { + if (auto type = conveter.convertType(plier::PyType::get(&ctx, name))) + { + llvm::SmallVector shape(num_dims, -1); + return mlir::MemRefType::get(shape, type); + } + } + return nullptr; +} + + +mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) +{ + if (type.isa()) + { + auto name = type.cast().getName(); + return map_array_type(*type.getContext(), converter, name); + } + return nullptr; +} struct PlierToLinalgPass : public mlir::PassWrapper> @@ -32,25 +91,24 @@ struct PlierToLinalgPass : void PlierToLinalgPass::runOnOperation() { -// mlir::TypeConverter type_converter; -// type_converter.addConversion([](plier::Type type)->llvm::Optional -// { -// return map_plier_type(type); -// }); + mlir::TypeConverter type_converter; + populate_std_type_converter(type_converter); + type_converter.addConversion([&](plier::PyType type)->llvm::Optional + { + auto ret = map_plier_type(type_converter, type); + if (!ret) + { + return llvm::None; + } + return ret; + }); mlir::OwningRewritePatternList patterns; -// patterns.insert(&getContext(), type_converter); -// patterns.insert(&getContext()); - - auto apply_conv = [&]() - { - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - }; + patterns.insert< + FuncOpSignatureConversion + >(type_converter, &getContext()); - apply_conv(); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 7947f1f8621..63f88f2ee41 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -738,10 +738,7 @@ mlir::Value cast_materializer( void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; - type_converter.addConversion([](mlir::Type type)->llvm::Optional - { - return map_plier_type(type); - }); + populate_std_type_converter(type_converter); mlir::OwningRewritePatternList patterns; @@ -766,6 +763,18 @@ void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) } } +void populate_std_type_converter(mlir::TypeConverter& converter) +{ + converter.addConversion([](mlir::Type type)->llvm::Optional + { + auto ret = map_plier_type(type); + if (!ret) + { + return llvm::None; + } + return ret; + }); +} void register_plier_to_std_pipeline(PipelineRegistry& registry) { diff --git a/mlir-compiler/src/pipelines/plier_to_std.hpp b/mlir-compiler/src/pipelines/plier_to_std.hpp index 1768d294a51..2965d10335a 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.hpp +++ b/mlir-compiler/src/pipelines/plier_to_std.hpp @@ -7,6 +7,13 @@ namespace llvm class StringRef; } +namespace mlir +{ +class TypeConverter; +} + +void populate_std_type_converter(mlir::TypeConverter& converter); + void register_plier_to_std_pipeline(PipelineRegistry& registry); llvm::StringRef plier_to_std_pipeline_name(); From ac435fefec8e289e6ebd20e8533bba9a18ad7566 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 21 Oct 2020 21:06:00 +0300 Subject: [PATCH 110/259] fix to conversion --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 ++ mlir-compiler/src/pipelines/plier_to_std.cpp | 2 ++ mlir-compiler/src/rewrites/type_conversion.cpp | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 7a555484378..61d67563ec2 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -92,6 +92,8 @@ struct PlierToLinalgPass : void PlierToLinalgPass::runOnOperation() { mlir::TypeConverter type_converter; + // Convert unknown types to itself + type_converter.addConversion([](mlir::Type type) { return type; }); populate_std_type_converter(type_converter); type_converter.addConversion([&](plier::PyType type)->llvm::Optional { diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 63f88f2ee41..233dced6f70 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -738,6 +738,8 @@ mlir::Value cast_materializer( void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; + // Convert unknown types to itself + type_converter.addConversion([](mlir::Type type) { return type; }); populate_std_type_converter(type_converter); mlir::OwningRewritePatternList patterns; diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 775f1552640..61ed1bdfc53 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -68,8 +68,8 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( mlir::TypeConverter::SignatureConversion result(type.getNumInputs()); llvm::SmallVector newResults; if (mlir::failed(converter.convertSignatureArgs(type.getInputs(), result)) || - mlir::failed(converter.convertTypes(type.getResults(), newResults)) || - mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) + mlir::failed(converter.convertTypes(type.getResults(), newResults)) || + mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) { return mlir::failure(); } From af76faa81b1b62cbc28a2ff90b1379bfa5ed1f07 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 22 Oct 2020 14:09:28 +0300 Subject: [PATCH 111/259] call op lowering --- .../src/pipelines/plier_to_linalg.cpp | 76 +++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 61d67563ec2..deea0237ea4 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -76,6 +76,78 @@ mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) return nullptr; } +llvm::StringRef extract_bound_func_name(llvm::StringRef name) +{ + auto len = name.find(' '); + return name.substr(0, len); +} + +struct CallOpLowering : public mlir::OpRewritePattern +{ + using check_t = mlir::LogicalResult(*)(llvm::StringRef, llvm::ArrayRef); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::StringRef, llvm::ArrayRef, mlir::PatternRewriter&); + using resolver_t = std::pair; + + CallOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context, + llvm::ArrayRef resolvers): + OpRewritePattern(context), resolvers(resolvers) {} + + mlir::LogicalResult matchAndRewrite( + plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override + { + auto operands = op.getOperands(); + if (operands.empty()) + { + return mlir::failure(); + } + auto func_type = operands[0].getType(); + if (!func_type.isa()) + { + return mlir::failure(); + } + auto name = func_type.cast().getName(); + llvm::SmallVector arg_types; + llvm::SmallVector args; + if (name.consume_front("Function(") && name.consume_back(")")) + { + llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); + llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); + // TODO kwargs + } + else if (name.consume_front("BoundFunction(") && name.consume_back(")")) + { + auto getattr = mlir::dyn_cast(operands[0].getDefiningOp()); + if (!getattr) + { + return mlir::failure(); + } + arg_types.push_back(getattr.getOperand().getType()); + args.push_back(getattr.getOperand()); + llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); + llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); + name = extract_bound_func_name(name); + // TODO kwargs + } + else + { + return mlir::failure(); + } + for (auto& c : resolvers) + { + if (mlir::succeeded(c.first(name, arg_types))) + { + return c.second(op, name, args, rewriter); + } + } + + return mlir::failure(); + } + +private: + llvm::ArrayRef resolvers; +}; + struct PlierToLinalgPass : public mlir::PassWrapper> { @@ -110,6 +182,10 @@ void PlierToLinalgPass::runOnOperation() FuncOpSignatureConversion >(type_converter, &getContext()); + patterns.insert< + CallOpLowering + >(type_converter, &getContext(), llvm::None); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } From 6c0f4b7277d4b4ce632868ef9ab1e99b18c04a39 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 23 Oct 2020 14:03:31 +0300 Subject: [PATCH 112/259] some work on linalg translation --- .../src/pipelines/plier_to_linalg.cpp | 49 ++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index deea0237ea4..4c9e161fef8 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -148,6 +149,47 @@ struct CallOpLowering : public mlir::OpRewritePattern llvm::ArrayRef resolvers; }; +mlir::LogicalResult numpy_check(llvm::StringRef name, llvm::ArrayRef types) +{ + return mlir::success(name == "array.sum"); // TODO +} + +mlir::LogicalResult numpy_rewrite( + plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, + mlir::PatternRewriter& rewriter) +{ + assert(args.size() == 1); + mlir::Value inputs[] = { args[0] }; + auto elem_type = inputs[0].getType().cast().getElementType(); + auto res_type = mlir::MemRefType::get({}, elem_type); + auto loc = op.getLoc(); + mlir::Value outputs[] = { rewriter.create(loc, res_type) }; + mlir::AffineMap map[] = { + mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), + mlir::AffineMap::get(1, 0, op.getContext()), + }; + mlir::StringRef iterators[] = { "reduction" }; + auto body = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + assert(args.size() == 2); + mlir::Value res = builder.create(loc, args[0], args[1]); + builder.create(loc, res); + }; + auto o = rewriter.create( + loc, + llvm::makeArrayRef(inputs), + llvm::makeArrayRef(outputs), + llvm::makeArrayRef(map), + llvm::makeArrayRef(iterators), + llvm::StringRef(), // doc + llvm::StringRef(), // library call + nullptr, // symbol source + body); + mlir::Value res = rewriter.create(loc, outputs[0]); + rewriter.replaceOp(op, res); + return mlir::failure(); +} + struct PlierToLinalgPass : public mlir::PassWrapper> { @@ -156,6 +198,7 @@ struct PlierToLinalgPass : { registry.insert(); registry.insert(); + registry.insert(); } void runOnOperation() override; @@ -182,9 +225,13 @@ void PlierToLinalgPass::runOnOperation() FuncOpSignatureConversion >(type_converter, &getContext()); + CallOpLowering::resolver_t resolvers[] = { + {numpy_check, numpy_rewrite} + }; + patterns.insert< CallOpLowering - >(type_converter, &getContext(), llvm::None); + >(type_converter, &getContext(), resolvers); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } From 2eb7a0c4c247575691cb3c1cc77b05d82040f97c Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 23 Oct 2020 19:12:15 +0300 Subject: [PATCH 113/259] some work on linalg --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/include/plier/PlierOps.td | 2 +- .../src/pipelines/plier_to_linalg.cpp | 44 ++++++++++++++++--- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index b0818370b45..c5b84dbfa64 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -64,6 +64,8 @@ target_link_libraries(${PROJECT_NAME} PRIVATE MLIRTargetLLVMIR MLIRTransforms MLIRStandardToLLVM + MLIRLinalgTransforms + MLIRSCFToStandard ) target_include_directories(${PROJECT_NAME} PRIVATE diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index e5efeee090b..9ee6eade8fb 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -163,7 +163,7 @@ def DelOp : Plier_Op<"del", []> { AnyType:$value); } -def GetattrOp : Plier_Op<"getattr", []> { +def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { let arguments = (ins AnyType:$value, StrAttr:$name); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 4c9e161fef8..2f83b9bf7d6 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -1,8 +1,12 @@ #include "pipelines/plier_to_linalg.hpp" #include +#include #include #include +#include +#include +#include #include #include #include @@ -160,7 +164,8 @@ mlir::LogicalResult numpy_rewrite( { assert(args.size() == 1); mlir::Value inputs[] = { args[0] }; - auto elem_type = inputs[0].getType().cast().getElementType(); +// auto elem_type = inputs[0].getType().cast().getElementType(); + auto elem_type = mlir::IntegerType::get(64, op.getContext()); auto res_type = mlir::MemRefType::get({}, elem_type); auto loc = op.getLoc(); mlir::Value outputs[] = { rewriter.create(loc, res_type) }; @@ -168,14 +173,15 @@ mlir::LogicalResult numpy_rewrite( mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::get(1, 0, op.getContext()), }; - mlir::StringRef iterators[] = { "reduction" }; - auto body = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + mlir::StringRef iterators[] = { "parallel" }; + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) { assert(args.size() == 2); - mlir::Value res = builder.create(loc, args[0], args[1]); + auto val = builder.create(loc, args[0], elem_type); + mlir::Value res = builder.create(loc, val, args[1]); builder.create(loc, res); }; - auto o = rewriter.create( + rewriter.create( loc, llvm::makeArrayRef(inputs), llvm::makeArrayRef(outputs), @@ -236,9 +242,37 @@ void PlierToLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } +struct LowerLinalgPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void LowerLinalgPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + patterns.insert> + (&getContext(), mlir::linalg::LinalgLoweringType::ParallelLoops); + + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); +} + void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); + pm.addPass(mlir::createLowerToCFGPass()); } } From 7329c032edc88d685497579305af394913912890 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 23 Oct 2020 19:21:11 +0300 Subject: [PATCH 114/259] simplify --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 2f83b9bf7d6..8866d5d5362 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -187,9 +187,6 @@ mlir::LogicalResult numpy_rewrite( llvm::makeArrayRef(outputs), llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), - llvm::StringRef(), // doc - llvm::StringRef(), // library call - nullptr, // symbol source body); mlir::Value res = rewriter.create(loc, outputs[0]); rewriter.replaceOp(op, res); From 5f974ca29fb276786ac2383469495a7ef4ee7643 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 23 Oct 2020 22:06:28 +0300 Subject: [PATCH 115/259] reduction --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 8866d5d5362..34ad5385941 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -173,7 +173,7 @@ mlir::LogicalResult numpy_rewrite( mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::get(1, 0, op.getContext()), }; - mlir::StringRef iterators[] = { "parallel" }; + mlir::StringRef iterators[] = { "reduction" }; auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) { assert(args.size() == 2); From 07c437a2f8f206c655cff08ebad2e674248875f9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Mon, 26 Oct 2020 13:25:19 +0300 Subject: [PATCH 116/259] flag --- numba/core/compiler.py | 5 ++++- numba/core/typed_passes.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/numba/core/compiler.py b/numba/core/compiler.py index dd811a39066..c7f44dd7403 100644 --- a/numba/core/compiler.py +++ b/numba/core/compiler.py @@ -35,6 +35,8 @@ ObjectModeBackEnd, CompileInterpMode) +from numba.core.lowering import _use_mlir + class Flags(utils.ConfigOptions): # These options are all false by default, but the defaults are # different with the @jit decorator (see targets.options.TargetOptions). @@ -504,7 +506,8 @@ def define_typed_pipeline(state, name="typed"): pm.add_pass(NopythonTypeInference, "nopython frontend") pm.add_pass(AnnotateTypes, "annotate types") - pm.add_pass(MlirBackend, "mlir backend") + if _use_mlir: + pm.add_pass(MlirBackend, "mlir backend") # strip phis pm.add_pass(PreLowerStripPhis, "remove phis nodes") diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 7edddf33ca5..5a5b077a091 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -20,6 +20,7 @@ build_definitions, compute_cfg_from_blocks) from numba.core import postproc +from numba.core.lowering import _use_mlir @contextmanager def fallback_context(state, msg): @@ -367,7 +368,9 @@ def run_pass(self, state): with targetctx.push_code_library(library): lower = lowering.Lower(targetctx, library, fndesc, interp, metadata=metadata) - setattr(lower, 'mlir_blob', state.mlir_blob) + if _use_mlir: + setattr(lower, 'mlir_blob', state.mlir_blob) + lower.lower() if not flags.no_cpython_wrapper: lower.create_cpython_wrapper(flags.release_gil) From 2ab03e1c17c0483f79c4d9801f27210bd5f272bf Mon Sep 17 00:00:00 2001 From: Butygin Date: Tue, 27 Oct 2020 22:01:57 +0300 Subject: [PATCH 117/259] some work on array param translation --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 222 +++++++++++++++++- 1 file changed, 220 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 8ecbd192f6d..c91e73617dc 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -1,6 +1,7 @@ #include "pipelines/lower_to_llvm.hpp" #include +#include #include #include #include @@ -9,6 +10,8 @@ #include #include +#include +#include #include #include #include @@ -87,8 +90,177 @@ mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) return mlir::LLVM::LLVMStructType::getLiteral(&type_helper.get_context(), elems); } +mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir::MemRefType type) +{ + assert(type); + auto ctx = type.getContext(); + auto i8p = mlir::LLVM::LLVMType::getInt8Ty(ctx).getPointerTo(); + auto i64 = mlir::LLVM::LLVMType::getIntNTy(ctx, 64); + auto data_type = converter.convertType(type.getElementType()).cast(); + assert(data_type); + auto shape_type = mlir::LLVM::LLVMArrayType::get(i64, static_cast(type.getRank())); + const mlir::LLVM::LLVMType members[] = { + i8p, // 0, meminfo + i8p, // 1, parent + i64, // 2, nitems + i64, // 3, itemsize + data_type.getPointerTo(), // 4, data + shape_type, // 5, shape + shape_type, // 6, strides + }; + return mlir::LLVM::LLVMStructType::getLiteral(ctx, members); +} + +template +void flatten_type(mlir::LLVM::LLVMType type, F&& func) +{ + if (auto struct_type = type.dyn_cast()) + { + for (auto elem : struct_type.getBody()) + { + flatten_type(elem, std::forward(func)); + } + } + else if (auto arr_type = type.dyn_cast()) + { + auto elem = arr_type.getElementType(); + auto size = arr_type.getNumElements(); + for (unsigned i = 0 ; i < size; ++i) + { + flatten_type(elem, std::forward(func)); + } + } + else + { + func(type); + } +} + +template +mlir::Value unflatten(mlir::LLVM::LLVMType type, mlir::Location loc, mlir::OpBuilder& builder, F&& next_func) +{ + namespace mllvm = mlir::LLVM; + if (auto struct_type = type.dyn_cast()) + { + mlir::Value val = builder.create(loc, struct_type); + for (auto elem : llvm::enumerate(struct_type.getBody())) + { + auto elem_index = builder.getI64ArrayAttr(static_cast(elem.index())); + auto elem_type = elem.value(); + auto elem_val = unflatten(elem_type, loc, builder, std::forward(next_func)); + val = builder.create(loc, val, elem_val, elem_index); + } + return val; + } + else if (auto arr_type = type.dyn_cast()) + { + auto elem_type = arr_type.getElementType(); + auto size = arr_type.getNumElements(); + mlir::Value val = builder.create(loc, arr_type); + for (unsigned i = 0 ; i < size; ++i) + { + auto elem_index = builder.getI64ArrayAttr(static_cast(i)); + auto elem_val = unflatten(elem_type, loc, builder, std::forward(next_func)); + val = builder.create(loc, val, elem_val, elem_index); + } + return val; + } + else + { + return next_func(); + } +} + +std::string gen_conversion_func_name(mlir::MemRefType memref_type) +{ + assert(memref_type); + std::string ret; + llvm::raw_string_ostream ss(ret); + ss << "__convert_memref_"; + memref_type.getElementType().print(ss); + ss.flush(); + return ret; +} + +const constexpr llvm::StringRef linkage_attr = "numba_linkage"; + +struct MemRefConversionCache +{ + mlir::FuncOp get_conversion_func( + mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, + mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) + { + assert(memref_type); + assert(src_type); + assert(dst_type); + auto it = cache.find(memref_type); + if (it != cache.end()) + { + auto func = it->second; + assert(func.getType().getNumResults() == 1); + assert(func.getType().getResult(0) == dst_type); + return func; + } + auto func_name = gen_conversion_func_name(memref_type); + auto func_type = mlir::FunctionType::get(src_type, dst_type, builder.getContext()); + auto loc = builder.getUnknownLoc(); + auto new_func = mlir::FuncOp::create(loc, func_name, func_type); + new_func.setAttr(linkage_attr, mlir::StringAttr::get("internal", builder.getContext())); + module.push_back(new_func); + cache.insert({memref_type, new_func}); + mlir::OpBuilder::InsertionGuard guard(builder); + auto block = new_func.addEntryBlock(); + builder.setInsertionPointToStart(block); + namespace mllvm = mlir::LLVM; + mlir::Value arg = block->getArgument(0); + auto extract = [&](unsigned index) + { + auto res_type = src_type.getBody()[index]; + auto i = builder.getI64ArrayAttr(index); + return builder.create(loc, res_type, arg, i); + }; + auto ptr = extract(4); + auto shape = extract(5); + auto strides = extract(6); + auto i64 = mllvm::LLVMIntegerType::get(builder.getContext(), 64); + auto offset = builder.create(loc, i64, builder.getI64IntegerAttr(0)); + mlir::Value res = builder.create(loc, dst_type); + auto insert = [&](unsigned index, mlir::Value val) + { + auto i = builder.getI64ArrayAttr(index); + res = builder.create(loc, res, val, i); + }; + insert(0, ptr); + insert(1, ptr); + insert(2, offset); + insert(3, shape); + insert(4, strides); + builder.create(loc, res); + return new_func; + } +private: + llvm::DenseMap cache; +}; + +llvm::StringRef get_linkage(mlir::Operation* op) +{ + assert(nullptr != op); + llvm::errs() << "get_linkage 1\n"; + if (auto attr = op->getAttr(linkage_attr).dyn_cast_or_null()) + { + llvm::errs() << "get_linkage 2\n"; + return attr.getValue(); + } + llvm::errs() << "get_linkage 3\n"; + return {}; +} + void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { + if (get_linkage(func) == "internal") + { + return; + } auto old_type = func.getType(); assert(old_type.getNumResults() == 1); auto& ctx = *old_type.getContext(); @@ -103,15 +275,61 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) auto add_arg = [&](mlir::Type type) { args.push_back(type); - func.getBody().insertArgument(index, type); + auto ret = func.getBody().insertArgument(index, type); ++index; + return ret; + }; + + MemRefConversionCache conversion_cache; + + mlir::OpBuilder builder(&ctx); + builder.setInsertionPointToStart(&func.getBody().front()); + + auto loc = builder.getUnknownLoc(); + llvm::SmallVector new_args; + auto process_arg = [&](mlir::Type type) + { + if (auto memref_type = type.dyn_cast()) + { + new_args.clear(); + auto arr_type = get_array_type(type_helper.get_type_converter(), memref_type); + flatten_type(arr_type, [&](mlir::Type new_type) + { + new_args.push_back(add_arg(new_type)); + }); + auto it = new_args.begin(); + mlir::Value desc = unflatten(arr_type, loc, builder, [&]() + { + auto ret = *it; + ++it; + return ret; + }); + + auto mod = mlir::cast(func.getParentOp()); + auto dst_type = type_helper.get_type_converter().convertType(memref_type); + assert(dst_type); + auto conv_func = conversion_cache.get_conversion_func(mod, builder, memref_type, arr_type, dst_type.cast()); + auto converted = builder.create(loc, conv_func, desc).getResult(0); + auto casted = builder.create(loc, memref_type, converted); + func.getBody().getArgument(index).replaceAllUsesWith(casted); + func.getBody().eraseArgument(index); + } + else + { + args.push_back(type); + ++index; + } }; add_arg(ptr(old_type.getResult(0))); add_arg(ptr(ptr(getExceptInfoType(type_helper)))); auto old_args = old_type.getInputs(); - std::copy(old_args.begin(), old_args.end(), std::back_inserter(args)); +// std::copy(old_args.begin(), old_args.end(), std::back_inserter(args)); + for (auto arg : old_args) + { + process_arg(arg); + } auto ret_type = mlir::IntegerType::get(32, &ctx); func.setType(mlir::FunctionType::get(args, ret_type, &ctx)); } From cd82f88bd338ed18f93ccee6245c399dd5ef1c10 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 28 Oct 2020 13:28:40 +0300 Subject: [PATCH 118/259] to llvm pass --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 46 +++++++++++++++++-- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index c91e73617dc..3afa01e17fb 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -245,13 +245,10 @@ struct MemRefConversionCache llvm::StringRef get_linkage(mlir::Operation* op) { assert(nullptr != op); - llvm::errs() << "get_linkage 1\n"; if (auto attr = op->getAttr(linkage_attr).dyn_cast_or_null()) { - llvm::errs() << "get_linkage 2\n"; return attr.getValue(); } - llvm::errs() << "get_linkage 3\n"; return {}; } @@ -496,11 +493,52 @@ struct PostLLVMLowering : } }; +// Copypasted from mlir +struct LLVMLoweringPass : public mlir::PassWrapper> { + LLVMLoweringPass(const mlir::LowerToLLVMOptions& opts): + options(opts) {} + + /// Run the dialect converter on the module. + void runOnOperation() override { + using namespace mlir; + if (options.useBarePtrCallConv && options.emitCWrappers) { + getOperation().emitError() + << "incompatible conversion options: bare-pointer calling convention " + "and C wrapper emission"; + signalPassFailure(); + return; + } + if (failed(LLVM::LLVMDialect::verifyDataLayoutString( + options.dataLayout.getStringRepresentation(), [this](const Twine &message) { + getOperation().emitError() << message.str(); + }))) { + signalPassFailure(); + return; + } + + ModuleOp m = getOperation(); + + LLVMTypeConverter typeConverter(&getContext(), options); + + OwningRewritePatternList patterns; + populateStdToLLVMConversionPatterns(typeConverter, patterns); + + LLVMConversionTarget target(getContext()); + if (failed(applyPartialConversion(m, target, patterns))) + signalPassFailure(); + m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), + StringAttr::get(options.dataLayout.getStringRepresentation(), m.getContext())); + } + +private: + mlir::LowerToLLVMOptions options; +}; + void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); - pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); + pm.addPass(std::make_unique(getLLVMOptions())); pm.addPass(std::make_unique()); } } From 833d247a5e29bf4fd9045f2276b8fb3bd3f8e03d Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 28 Oct 2020 13:57:24 +0300 Subject: [PATCH 119/259] lower temp casts, add test --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 22 +++++++++++++++++++ numba/mlir/tests/test_numpy.py | 19 ++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 numba/mlir/tests/test_numpy.py diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 3afa01e17fb..9ae985806ee 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -493,6 +493,27 @@ struct PostLLVMLowering : } }; +struct LowerCasts : public mlir::OpConversionPattern +{ + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(plier::CastOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1); + auto converter = getTypeConverter(); + assert(nullptr != converter); + auto src_type = operands[0].getType(); + auto dst_type = converter->convertType(op.getType()); + if (src_type == dst_type) + { + rewriter.replaceOp(op, operands[0]); + return mlir::success(); + } + return mlir::failure(); + } +}; + // Copypasted from mlir struct LLVMLoweringPass : public mlir::PassWrapper> { LLVMLoweringPass(const mlir::LowerToLLVMOptions& opts): @@ -522,6 +543,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper(typeConverter, &getContext()); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, patterns))) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py new file mode 100644 index 00000000000..28b1949df54 --- /dev/null +++ b/numba/mlir/tests/test_numpy.py @@ -0,0 +1,19 @@ +import numba +from numba import njit +from numpy.testing import assert_equal # for nans comparison +import numpy as np +from numba.tests.support import TestCase +import unittest + +class TestMlirBasic(TestCase): + + def test_sum(self): + def py_func(a): + return a.sum() + + jit_func = njit(py_func) + arr = np.asarray([1,2,3]) + assert_equal(py_func(arr), jit_func(arr)) + +if __name__ == '__main__': + unittest.main() From 1e2dfaff74b758b05c1935016344c130169cbe12 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 28 Oct 2020 21:55:16 +0300 Subject: [PATCH 120/259] refactor CallOpLowering --- .../src/pipelines/plier_to_linalg.cpp | 35 ++++++------------- 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 34ad5385941..3b160a2a0ec 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -83,20 +83,19 @@ mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) llvm::StringRef extract_bound_func_name(llvm::StringRef name) { + assert(!name.empty()); auto len = name.find(' '); return name.substr(0, len); } struct CallOpLowering : public mlir::OpRewritePattern { - using check_t = mlir::LogicalResult(*)(llvm::StringRef, llvm::ArrayRef); - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::StringRef, llvm::ArrayRef, mlir::PatternRewriter&); - using resolver_t = std::pair; + using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &/*typeConverter*/, mlir::MLIRContext *context, - llvm::ArrayRef resolvers): - OpRewritePattern(context), resolvers(resolvers) {} + resolver_t resolver): + OpRewritePattern(context), resolver(resolver) {} mlir::LogicalResult matchAndRewrite( plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override @@ -138,31 +137,23 @@ struct CallOpLowering : public mlir::OpRewritePattern { return mlir::failure(); } - for (auto& c : resolvers) - { - if (mlir::succeeded(c.first(name, arg_types))) - { - return c.second(op, name, args, rewriter); - } - } - return mlir::failure(); + return resolver(op, name, args, rewriter); } private: - llvm::ArrayRef resolvers; + resolver_t resolver; }; -mlir::LogicalResult numpy_check(llvm::StringRef name, llvm::ArrayRef types) -{ - return mlir::success(name == "array.sum"); // TODO -} - mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { assert(args.size() == 1); + if (name != "array.sum") + { + return mlir::failure(); // TODO + } mlir::Value inputs[] = { args[0] }; // auto elem_type = inputs[0].getType().cast().getElementType(); auto elem_type = mlir::IntegerType::get(64, op.getContext()); @@ -228,13 +219,9 @@ void PlierToLinalgPass::runOnOperation() FuncOpSignatureConversion >(type_converter, &getContext()); - CallOpLowering::resolver_t resolvers[] = { - {numpy_check, numpy_rewrite} - }; - patterns.insert< CallOpLowering - >(type_converter, &getContext(), resolvers); + >(type_converter, &getContext(), &numpy_rewrite); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } From dc317a25262bc7987455a8901f82dff043d1038c Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 28 Oct 2020 22:02:36 +0300 Subject: [PATCH 121/259] move call lowering to separate file --- mlir-compiler/CMakeLists.txt | 2 + .../src/pipelines/plier_to_linalg.cpp | 65 +------------------ mlir-compiler/src/rewrites/call_lowering.cpp | 59 +++++++++++++++++ mlir-compiler/src/rewrites/call_lowering.hpp | 25 +++++++ 4 files changed, 87 insertions(+), 64 deletions(-) create mode 100644 mlir-compiler/src/rewrites/call_lowering.cpp create mode 100644 mlir-compiler/src/rewrites/call_lowering.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index c5b84dbfa64..baa7e836c74 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -24,6 +24,7 @@ set(SOURCES_LIST src/pipelines/lower_to_llvm.cpp src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp + src/rewrites/call_lowering.cpp src/rewrites/type_conversion.cpp src/compiler.cpp src/dialect.cpp @@ -39,6 +40,7 @@ set(HEADERS_LIST src/pipelines/lower_to_llvm.hpp src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp + src/rewrites/call_lowering.hpp src/rewrites/type_conversion.hpp src/compiler.hpp src/lowering.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 3b160a2a0ec..ab39d50bbfc 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -15,6 +15,7 @@ #include "plier/dialect.hpp" #include "pipelines/plier_to_std.hpp" +#include "rewrites/call_lowering.hpp" #include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" @@ -81,70 +82,6 @@ mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) return nullptr; } -llvm::StringRef extract_bound_func_name(llvm::StringRef name) -{ - assert(!name.empty()); - auto len = name.find(' '); - return name.substr(0, len); -} - -struct CallOpLowering : public mlir::OpRewritePattern -{ - using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; - - CallOpLowering(mlir::TypeConverter &/*typeConverter*/, - mlir::MLIRContext *context, - resolver_t resolver): - OpRewritePattern(context), resolver(resolver) {} - - mlir::LogicalResult matchAndRewrite( - plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override - { - auto operands = op.getOperands(); - if (operands.empty()) - { - return mlir::failure(); - } - auto func_type = operands[0].getType(); - if (!func_type.isa()) - { - return mlir::failure(); - } - auto name = func_type.cast().getName(); - llvm::SmallVector arg_types; - llvm::SmallVector args; - if (name.consume_front("Function(") && name.consume_back(")")) - { - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - // TODO kwargs - } - else if (name.consume_front("BoundFunction(") && name.consume_back(")")) - { - auto getattr = mlir::dyn_cast(operands[0].getDefiningOp()); - if (!getattr) - { - return mlir::failure(); - } - arg_types.push_back(getattr.getOperand().getType()); - args.push_back(getattr.getOperand()); - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - name = extract_bound_func_name(name); - // TODO kwargs - } - else - { - return mlir::failure(); - } - - return resolver(op, name, args, rewriter); - } - -private: - resolver_t resolver; -}; - mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) diff --git a/mlir-compiler/src/rewrites/call_lowering.cpp b/mlir-compiler/src/rewrites/call_lowering.cpp new file mode 100644 index 00000000000..65f308fee0b --- /dev/null +++ b/mlir-compiler/src/rewrites/call_lowering.cpp @@ -0,0 +1,59 @@ +#include "call_lowering.hpp" + +namespace +{ +llvm::StringRef extract_bound_func_name(llvm::StringRef name) +{ + assert(!name.empty()); + auto len = name.find(' '); + return name.substr(0, len); +} +} + +CallOpLowering::CallOpLowering( + mlir::TypeConverter&, mlir::MLIRContext* context, + CallOpLowering::resolver_t resolver): + OpRewritePattern(context), resolver(resolver) {} + +mlir::LogicalResult CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const +{ + auto operands = op.getOperands(); + if (operands.empty()) + { + return mlir::failure(); + } + auto func_type = operands[0].getType(); + if (!func_type.isa()) + { + return mlir::failure(); + } + auto name = func_type.cast().getName(); + llvm::SmallVector arg_types; + llvm::SmallVector args; + if (name.consume_front("Function(") && name.consume_back(")")) + { + llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); + llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); + // TODO kwargs + } + else if (name.consume_front("BoundFunction(") && name.consume_back(")")) + { + auto getattr = mlir::dyn_cast(operands[0].getDefiningOp()); + if (!getattr) + { + return mlir::failure(); + } + arg_types.push_back(getattr.getOperand().getType()); + args.push_back(getattr.getOperand()); + llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); + llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); + name = extract_bound_func_name(name); + // TODO kwargs + } + else + { + return mlir::failure(); + } + + return resolver(op, name, args, rewriter); +} diff --git a/mlir-compiler/src/rewrites/call_lowering.hpp b/mlir-compiler/src/rewrites/call_lowering.hpp new file mode 100644 index 00000000000..0119aad450d --- /dev/null +++ b/mlir-compiler/src/rewrites/call_lowering.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "plier/dialect.hpp" + +#include + +namespace mlir +{ +class TypeConverter; +} + +struct CallOpLowering : public mlir::OpRewritePattern +{ + using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; + + CallOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + resolver_t resolver); + + mlir::LogicalResult matchAndRewrite( + plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override; + +private: + resolver_t resolver; +}; From 058461dbdecd9449e6d347f74383a1f8aa8666dd Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 28 Oct 2020 22:14:29 +0300 Subject: [PATCH 122/259] refactor std call lowering --- mlir-compiler/src/pipelines/plier_to_std.cpp | 59 ++++++-------------- 1 file changed, 17 insertions(+), 42 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 233dced6f70..b0578780307 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -11,6 +11,7 @@ #include "plier/dialect.hpp" +#include "rewrites/call_lowering.hpp" #include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" @@ -534,14 +535,13 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; -mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { - auto operands = op.getOperands(); - if (operands.size() != 2) + if (operands.size() != 1) { return mlir::failure(); } - auto val = operands[1]; + auto val = operands[0]; bool success = false; auto replace_op = [&](mlir::Value val) { @@ -559,45 +559,17 @@ mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, mlir::PatternRewriter& r return mlir::success(success); } -using call_lowerer_func_t = mlir::LogicalResult(*)(plier::PyCallOp, mlir::PatternRewriter&); -const constexpr std::pair builtin_calls[] = { - {"", &lower_bool_cast}, -}; - -struct CallOpLowering : public mlir::OpRewritePattern +mlir::LogicalResult basic_rewrite( + plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, + mlir::PatternRewriter& rewriter) { - CallOpLowering(mlir::TypeConverter &/*typeConverter*/, - mlir::MLIRContext *context): - OpRewritePattern(context) {} - - mlir::LogicalResult matchAndRewrite( - plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override + assert(args.size() == 1); + if (name == "") { - auto operands = op.getOperands(); - if (operands.empty()) - { - return mlir::failure(); - } - auto func_type = operands[0].getType(); - if (!func_type.isa()) - { - return mlir::failure(); - } - auto name = func_type.cast().getName(); - if (!name.consume_front("Function(") || !name.consume_back(")")) - { - return mlir::failure(); - } - for (auto& c : builtin_calls) - { - if (c.first == name) - { - return c.second(op, rewriter); - } - } - return mlir::failure(); + return lower_bool_cast(op, args, rewriter); } -}; + return mlir::failure(); +} struct CastOpLowering : public mlir::OpRewritePattern { @@ -751,10 +723,13 @@ void PlierToStdPass::runOnOperation() SelectOpLowering, CondBrOpLowering, CastOpLowering, - BinOpLowering, - CallOpLowering + BinOpLowering >(type_converter, &getContext()); + patterns.insert< + CallOpLowering + >(type_converter, &getContext(), &basic_rewrite); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } From e07ac56524279389aee8b68d496e840c6665bade Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 123/259] fix --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 1 - mlir-compiler/src/pipelines/plier_to_std.cpp | 1 - 2 files changed, 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index ab39d50bbfc..fb006c98b7c 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -86,7 +86,6 @@ mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - assert(args.size() == 1); if (name != "array.sum") { return mlir::failure(); // TODO diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index b0578780307..567fef6f6d9 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -563,7 +563,6 @@ mlir::LogicalResult basic_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - assert(args.size() == 1); if (name == "") { return lower_bool_cast(op, args, rewriter); From 10854ed0ef0d9a99a500af717b96823562cfa692 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 124/259] refac --- .../src/pipelines/plier_to_linalg.cpp | 75 ++++++++++++------- 1 file changed, 46 insertions(+), 29 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index fb006c98b7c..b74dfdbf336 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -82,41 +82,58 @@ mlir::Type map_plier_type(mlir::TypeConverter& converter, mlir::Type type) return nullptr; } +bool check_numpy_args(llvm::ArrayRef args, unsigned expected_count) +{ + if (args.size() != expected_count) + { + return false; + } + for (auto arg : args) + { + auto type = arg.getType(); + if (!type.isa() && !type.isa()) + { + return false; + } + } + return true; +} + mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - if (name != "array.sum") + if (name == "array.sum" && check_numpy_args(args, 1)) { - return mlir::failure(); // TODO + mlir::Value inputs[] = { args[0] }; + // auto elem_type = inputs[0].getType().cast().getElementType(); + auto elem_type = mlir::IntegerType::get(64, op.getContext()); + auto res_type = mlir::MemRefType::get({}, elem_type); + auto loc = op.getLoc(); + mlir::Value outputs[] = { rewriter.create(loc, res_type) }; + mlir::AffineMap map[] = { + mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), + mlir::AffineMap::get(1, 0, op.getContext()), + }; + mlir::StringRef iterators[] = { "reduction" }; + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + assert(args.size() == 2); + auto val = builder.create(loc, args[0], elem_type); + mlir::Value res = builder.create(loc, val, args[1]); + builder.create(loc, res); + }; + rewriter.create( + loc, + llvm::makeArrayRef(inputs), + llvm::makeArrayRef(outputs), + llvm::makeArrayRef(map), + llvm::makeArrayRef(iterators), + body); + mlir::Value res = rewriter.create(loc, outputs[0]); + rewriter.replaceOp(op, res); + return mlir::success(); } - mlir::Value inputs[] = { args[0] }; -// auto elem_type = inputs[0].getType().cast().getElementType(); - auto elem_type = mlir::IntegerType::get(64, op.getContext()); - auto res_type = mlir::MemRefType::get({}, elem_type); - auto loc = op.getLoc(); - mlir::Value outputs[] = { rewriter.create(loc, res_type) }; - mlir::AffineMap map[] = { - mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - mlir::AffineMap::get(1, 0, op.getContext()), - }; - mlir::StringRef iterators[] = { "reduction" }; - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) - { - assert(args.size() == 2); - auto val = builder.create(loc, args[0], elem_type); - mlir::Value res = builder.create(loc, val, args[1]); - builder.create(loc, res); - }; - rewriter.create( - loc, - llvm::makeArrayRef(inputs), - llvm::makeArrayRef(outputs), - llvm::makeArrayRef(map), - llvm::makeArrayRef(iterators), - body); - mlir::Value res = rewriter.create(loc, outputs[0]); - rewriter.replaceOp(op, res); return mlir::failure(); } From c69dcb92fdcae8bbf6d8991b4e0dd92571338f10 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 125/259] numpy add --- .../src/pipelines/plier_to_linalg.cpp | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index b74dfdbf336..2dfcd57afd8 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -103,6 +103,49 @@ mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { + if (name == "" && check_numpy_args(args, 2)) + { + mlir::Value inputs[] = { args[0], args[1] }; + auto elem_type = args[0].getType().cast().getElementType(); + mlir::Type res_type = mlir::RankedTensorType::get({-1}, elem_type); + auto loc = op.getLoc(); + mlir::AffineMap map[] = { + mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), + mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), + mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), + }; + mlir::StringRef iterators[] = { "parallel" }; + +// mlir::Value size = rewriter.create(loc, args[0], 0); +// mlir::Value init = rewriter.create( +// loc, +// res_type, +// mlir::ValueRange(size), +// [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) +// { +// assert(args.size() == 1); +// auto val = builder.create(loc, mlir::IntegerAttr::get(elem_type, 0)); +// builder.create(loc, val); +// }); + + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + assert(args.size() == 2); + mlir::Value res = builder.create(loc, args[0], args[1]); + builder.create(loc, res); + }; + auto res = rewriter.create( + loc, + mlir::TypeRange(res_type), + mlir::ValueRange(inputs), + llvm::None, // outputs + llvm::None, // init, + llvm::makeArrayRef(map), + llvm::makeArrayRef(iterators), + body).getResult(0); + rewriter.replaceOp(op, res); + return mlir::success(); + } if (name == "array.sum" && check_numpy_args(args, 1)) { mlir::Value inputs[] = { args[0] }; From a25d6240239e64acbd1f97629f9f745fed6725cd Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 126/259] somme work --- .../src/pipelines/plier_to_linalg.cpp | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 2dfcd57afd8..97d97c12d1b 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -107,8 +108,10 @@ mlir::LogicalResult numpy_rewrite( { mlir::Value inputs[] = { args[0], args[1] }; auto elem_type = args[0].getType().cast().getElementType(); - mlir::Type res_type = mlir::RankedTensorType::get({-1}, elem_type); + mlir::Type res_type = mlir::MemRefType::get({-1}, elem_type); auto loc = op.getLoc(); + mlir::Value size = rewriter.create(loc, args[0], 0); + mlir::Value outputs[] = { rewriter.create(loc, res_type, size) }; mlir::AffineMap map[] = { mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), @@ -130,20 +133,18 @@ mlir::LogicalResult numpy_rewrite( auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) { - assert(args.size() == 2); + assert(args.size() == 3); mlir::Value res = builder.create(loc, args[0], args[1]); builder.create(loc, res); }; - auto res = rewriter.create( + rewriter.create( loc, - mlir::TypeRange(res_type), mlir::ValueRange(inputs), - llvm::None, // outputs - llvm::None, // init, + mlir::ValueRange(outputs), llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), - body).getResult(0); - rewriter.replaceOp(op, res); + body); + rewriter.replaceOp(op, outputs[0]); return mlir::success(); } if (name == "array.sum" && check_numpy_args(args, 1)) @@ -168,8 +169,8 @@ mlir::LogicalResult numpy_rewrite( }; rewriter.create( loc, - llvm::makeArrayRef(inputs), - llvm::makeArrayRef(outputs), + mlir::ValueRange(inputs), + mlir::ValueRange(outputs), llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), body); From a90b9c23dd297564297f0554688223d3a5399dc7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 127/259] plier getitem --- mlir-compiler/include/plier/PlierOps.td | 13 +++++++++++++ mlir-compiler/src/dialect.cpp | 7 +++++++ mlir-compiler/src/lowering.cpp | 8 ++++++++ 3 files changed, 28 insertions(+) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 9ee6eade8fb..4cba5de5dab 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -100,6 +100,19 @@ def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { ]; } +def GetItemOp : Plier_Op<"getitem", []> { + let arguments = (ins + AnyType:$value, + AnyType:$index); + + let results = (outs AnyType); + + let builders = [ + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, " + "::mlir::Value index"> +]; +} + def StaticGetItemOp : Plier_Op<"static_getitem", []> { let arguments = (ins AnyType:$value, diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index b0387b060f3..528dc912c80 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -159,6 +159,13 @@ void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, // return mlir::failure(); //} +void GetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + ::mlir::Value value, ::mlir::Value index) +{ + GetItemOp::build(builder, state, + PyType::getUndefined(state.getContext()), value, index); +} + void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value, ::mlir::Value index_var, unsigned int index) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 0e9efbaef53..1ca7618ae31 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -286,6 +286,7 @@ struct plier_lowerer {"call", &plier_lowerer::lower_call}, {"phi", &plier_lowerer::lower_phi}, {"build_tuple", &plier_lowerer::lower_build_tuple}, + {"getitem", &plier_lowerer::lower_getitem}, {"static_getitem", &plier_lowerer::lower_static_getitem}, {"getiter", &plier_lowerer::lower_simple}, {"iternext", &plier_lowerer::lower_simple}, @@ -317,6 +318,13 @@ struct plier_lowerer return builder.create(get_current_loc(), res_type, value); } + mlir::Value lower_getitem(const py::handle& inst) + { + auto value = loadvar(inst.attr("value")); + auto index = loadvar(inst.attr("index")); + return builder.create(get_current_loc(), value, index); + } + mlir::Value lower_static_getitem(const py::handle& inst) { auto value = loadvar(inst.attr("value")); From d6faf60f99934324b84add068e46e3a4d69fa506 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 128/259] simple getitem --- .../src/pipelines/plier_to_linalg.cpp | 33 +++++++++++++++++++ numba/mlir/tests/test_numpy.py | 10 ++++++ 2 files changed, 43 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 97d97c12d1b..a4e931a3ef6 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -181,6 +181,35 @@ mlir::LogicalResult numpy_rewrite( return mlir::failure(); } +struct GetitemOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + plier::GetItemOp op, mlir::PatternRewriter &rewriter) const override + { + assert(op.getNumOperands() == 2); + auto val = op.getOperand(0); + auto index = op.getOperand(1); + if (!val.getType().isa()) + { + return mlir::failure(); + } + if (!index.getType().isa() && !index.getType().isa()) + { + return mlir::failure(); + } + auto loc = op.getLoc(); + if (index.getType().isa()) + { + index = rewriter.create(loc, index, mlir::IndexType::get(op.getContext())); + } + mlir::Value res = rewriter.create(loc, val, index); + rewriter.replaceOp(op, res); + return mlir::success(); + } +}; + struct PlierToLinalgPass : public mlir::PassWrapper> { @@ -220,6 +249,10 @@ void PlierToLinalgPass::runOnOperation() CallOpLowering >(type_converter, &getContext(), &numpy_rewrite); + patterns.insert< + GetitemOpLowering + >(&getContext()); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); } diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 28b1949df54..2e4f70d950d 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -7,6 +7,16 @@ class TestMlirBasic(TestCase): + def test_getitem(self): + def py_func(a, b): + return a[b] + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + for i in range(3): + assert_equal(py_func(arr, i), jit_func(arr, i)) + + @unittest.skip def test_sum(self): def py_func(a): return a.sum() From 3ef8fe093682a4094599fb561c2c101299aba859 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 129/259] fix --- mlir-compiler/src/rewrites/call_lowering.hpp | 2 +- numba/mlir/tests/test_numpy.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir-compiler/src/rewrites/call_lowering.hpp b/mlir-compiler/src/rewrites/call_lowering.hpp index 0119aad450d..cafee219fc4 100644 --- a/mlir-compiler/src/rewrites/call_lowering.hpp +++ b/mlir-compiler/src/rewrites/call_lowering.hpp @@ -11,7 +11,7 @@ class TypeConverter; struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; + using resolver_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::StringRef, llvm::ArrayRef, mlir::PatternRewriter&); CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 2e4f70d950d..9d5bad98f73 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -16,7 +16,6 @@ def py_func(a, b): for i in range(3): assert_equal(py_func(arr, i), jit_func(arr, i)) - @unittest.skip def test_sum(self): def py_func(a): return a.sum() From cd976fd7ccae83207a2611929802c57ac40b6a41 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 130/259] numpy static getitem --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 10 ++++++---- numba/mlir/tests/test_numpy.py | 8 ++++++++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index a4e931a3ef6..3f7d4eabe2b 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -181,12 +181,13 @@ mlir::LogicalResult numpy_rewrite( return mlir::failure(); } -struct GetitemOpLowering : public mlir::OpRewritePattern +template +struct GetitemOpLowering : public mlir::OpRewritePattern { - using mlir::OpRewritePattern::OpRewritePattern; + using mlir::OpRewritePattern::OpRewritePattern; mlir::LogicalResult matchAndRewrite( - plier::GetItemOp op, mlir::PatternRewriter &rewriter) const override + T op, mlir::PatternRewriter &rewriter) const override { assert(op.getNumOperands() == 2); auto val = op.getOperand(0); @@ -250,7 +251,8 @@ void PlierToLinalgPass::runOnOperation() >(type_converter, &getContext(), &numpy_rewrite); patterns.insert< - GetitemOpLowering + GetitemOpLowering, + GetitemOpLowering >(&getContext()); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 9d5bad98f73..d20044d7dc1 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -7,6 +7,14 @@ class TestMlirBasic(TestCase): + def test_staticgetitem(self): + def py_func(a): + return a[1] + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + assert_equal(py_func(arr), jit_func(arr)) + def test_getitem(self): def py_func(a, b): return a[b] From f1c549e6ca45784442a79f5e94b7ebb88820a6e4 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 131/259] fix --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 3f7d4eabe2b..dfb5929848a 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -192,16 +192,17 @@ struct GetitemOpLowering : public mlir::OpRewritePattern assert(op.getNumOperands() == 2); auto val = op.getOperand(0); auto index = op.getOperand(1); - if (!val.getType().isa()) + if (!val.getType().template isa()) { return mlir::failure(); } - if (!index.getType().isa() && !index.getType().isa()) + if (!index.getType().template isa() && + !index.getType().template isa()) { return mlir::failure(); } auto loc = op.getLoc(); - if (index.getType().isa()) + if (index.getType().template isa()) { index = rewriter.create(loc, index, mlir::IndexType::get(op.getContext())); } From ccfc055d8ed9da8ce7bb65146d580abf45c8d508 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 29 Oct 2020 13:26:17 +0300 Subject: [PATCH 132/259] fix uninitialized val --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 16 ++++++++++++++++ mlir-compiler/test.py | 16 ++++++++++++++++ numba/mlir/tests/test_numpy.py | 9 +++++++++ 3 files changed, 41 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index dfb5929848a..7a204a497a1 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -100,6 +100,20 @@ bool check_numpy_args(llvm::ArrayRef args, unsigned expected_count) return true; } +mlir::Attribute get_zero(mlir::Type type) +{ + assert(type); + if (auto int_type = type.dyn_cast()) + { + return mlir::IntegerAttr::get(type, 0); + } + if (auto float_type = type.dyn_cast()) + { + return mlir::FloatAttr::get(type, 0.0); + } + llvm_unreachable("get_zero: usupported type"); +} + mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) @@ -155,6 +169,8 @@ mlir::LogicalResult numpy_rewrite( auto res_type = mlir::MemRefType::get({}, elem_type); auto loc = op.getLoc(); mlir::Value outputs[] = { rewriter.create(loc, res_type) }; + auto zero = rewriter.create(loc, get_zero(elem_type)); + rewriter.create(loc, zero, outputs[0]); mlir::AffineMap map[] = { mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::get(1, 0, op.getContext()), diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 44419e78f33..09f792aed8b 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -57,6 +57,18 @@ def range_loop(n): res = res + i return res +def np_getitem(a, b): + return a[b] + +def np_getitem2(a, b, c): + return b[c] + +def np_sum(a): + return a.sum() + +def np_add(a, b): + return np.add(a, b).sum() + def test(func, params): global _tests_total global _tests_passes @@ -94,6 +106,10 @@ def test(func, params): test(arr_loop, ()) test(range_loop, (8,)) test(sum2, (np.asarray([1,2,3]),np.asarray([4,5,6]))) +test(np_getitem, (np.asarray([1,2,3]),1)) +test(np_getitem2, (np.asarray([1,2,3]),np.asarray([4,5,6]),1)) +test(np_sum, (np.asarray([1,2,3]),)) +test(np_add, (np.asarray([1,2,3]),np.asarray([4,5,6]))) print(f'Tests passed: {_tests_passes}/{_tests_total}') if (len(_failed_tests) != 0): diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index d20044d7dc1..02f5444372c 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -32,5 +32,14 @@ def py_func(a): arr = np.asarray([1,2,3]) assert_equal(py_func(arr), jit_func(arr)) + def test_sum_add(self): + def py_func(a, b): + return np.add(a, b).sum() + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3]) + arr2 = np.asarray([4,5,6]) + assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + if __name__ == '__main__': unittest.main() From 7e77b8233846744b1a06d7f03c642f39ffd1545e Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 4 Nov 2020 17:23:33 +0300 Subject: [PATCH 133/259] refactor cast --- mlir-compiler/src/pipelines/plier_to_std.cpp | 34 ++++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 567fef6f6d9..0d24605f136 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -445,7 +445,7 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& } } - llvm_unreachable("Unhandled cast"); + return nullptr; } struct BinOpLowering : public mlir::OpRewritePattern @@ -572,31 +572,42 @@ mlir::LogicalResult basic_rewrite( struct CastOpLowering : public mlir::OpRewritePattern { + using cast_t = std::function; + CastOpLowering(mlir::TypeConverter &typeConverter, - mlir::MLIRContext *context): - OpRewritePattern(context), converter(typeConverter) {} + mlir::MLIRContext *context, + cast_t cast_func = nullptr): + OpRewritePattern(context), converter(typeConverter), + cast_func(std::move(cast_func)) {} mlir::LogicalResult matchAndRewrite( plier::CastOp op, mlir::PatternRewriter &rewriter) const override { - auto src_type = op.getOperand().getType(); + auto src = op.getOperand(); + auto src_type = src.getType(); auto dst_type = converter.convertType(op.getType()); - if (dst_type && is_supported_type(src_type) && is_supported_type(dst_type)) + if (dst_type) { if (src_type == dst_type) { - rewriter.replaceOp(op, op.getOperand()); + rewriter.replaceOp(op, src); return mlir::success(); } - auto new_op = do_cast(dst_type, op.getOperand(), rewriter); - rewriter.replaceOp(op, new_op); - return mlir::success(); + if (nullptr != cast_func) + { + if (auto new_op = cast_func(dst_type, src, rewriter)) + { + rewriter.replaceOp(op, new_op); + return mlir::success(); + } + } } return mlir::failure(); } private: mlir::TypeConverter& converter; + cast_t cast_func; }; mlir::Operation* change_op_ret_type(mlir::Operation* op, @@ -721,10 +732,13 @@ void PlierToStdPass::runOnOperation() ConstOpLowering, SelectOpLowering, CondBrOpLowering, - CastOpLowering, BinOpLowering >(type_converter, &getContext()); + patterns.insert< + CastOpLowering + >(type_converter, &getContext(), &do_cast); + patterns.insert< CallOpLowering >(type_converter, &getContext(), &basic_rewrite); From 644bafb59d622717aa03a93a97982e10b9c056d4 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 4 Nov 2020 17:28:14 +0300 Subject: [PATCH 134/259] move cast lowering to separate file --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/pipelines/plier_to_std.cpp | 41 +------------------- mlir-compiler/src/rewrites/cast_lowering.cpp | 34 ++++++++++++++++ mlir-compiler/src/rewrites/cast_lowering.hpp | 28 +++++++++++++ 4 files changed, 65 insertions(+), 40 deletions(-) create mode 100644 mlir-compiler/src/rewrites/cast_lowering.cpp create mode 100644 mlir-compiler/src/rewrites/cast_lowering.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index baa7e836c74..acca1296b1f 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -25,6 +25,7 @@ set(SOURCES_LIST src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp src/rewrites/call_lowering.cpp + src/rewrites/cast_lowering.cpp src/rewrites/type_conversion.cpp src/compiler.cpp src/dialect.cpp @@ -41,6 +42,7 @@ set(HEADERS_LIST src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp src/rewrites/call_lowering.hpp + src/rewrites/cast_lowering.hpp src/rewrites/type_conversion.hpp src/compiler.hpp src/lowering.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 0d24605f136..cf5c186ddd5 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -12,6 +12,7 @@ #include "plier/dialect.hpp" #include "rewrites/call_lowering.hpp" +#include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" @@ -570,46 +571,6 @@ mlir::LogicalResult basic_rewrite( return mlir::failure(); } -struct CastOpLowering : public mlir::OpRewritePattern -{ - using cast_t = std::function; - - CastOpLowering(mlir::TypeConverter &typeConverter, - mlir::MLIRContext *context, - cast_t cast_func = nullptr): - OpRewritePattern(context), converter(typeConverter), - cast_func(std::move(cast_func)) {} - - mlir::LogicalResult matchAndRewrite( - plier::CastOp op, mlir::PatternRewriter &rewriter) const override - { - auto src = op.getOperand(); - auto src_type = src.getType(); - auto dst_type = converter.convertType(op.getType()); - if (dst_type) - { - if (src_type == dst_type) - { - rewriter.replaceOp(op, src); - return mlir::success(); - } - if (nullptr != cast_func) - { - if (auto new_op = cast_func(dst_type, src, rewriter)) - { - rewriter.replaceOp(op, new_op); - return mlir::success(); - } - } - } - return mlir::failure(); - } - -private: - mlir::TypeConverter& converter; - cast_t cast_func; -}; - mlir::Operation* change_op_ret_type(mlir::Operation* op, mlir::PatternRewriter& rewriter, llvm::ArrayRef types) diff --git a/mlir-compiler/src/rewrites/cast_lowering.cpp b/mlir-compiler/src/rewrites/cast_lowering.cpp new file mode 100644 index 00000000000..c2a4e2e0a57 --- /dev/null +++ b/mlir-compiler/src/rewrites/cast_lowering.cpp @@ -0,0 +1,34 @@ +#include "rewrites/cast_lowering.hpp" + +#include + +CastOpLowering::CastOpLowering( + mlir::TypeConverter& typeConverter, mlir::MLIRContext* context, + CastOpLowering::cast_t cast_func): + OpRewritePattern(context), converter(typeConverter), + cast_func(std::move(cast_func)) {} + +mlir::LogicalResult CastOpLowering::matchAndRewrite( + plier::CastOp op, mlir::PatternRewriter& rewriter) const +{ + auto src = op.getOperand(); + auto src_type = src.getType(); + auto dst_type = converter.convertType(op.getType()); + if (dst_type) + { + if (src_type == dst_type) + { + rewriter.replaceOp(op, src); + return mlir::success(); + } + if (nullptr != cast_func) + { + if (auto new_op = cast_func(dst_type, src, rewriter)) + { + rewriter.replaceOp(op, new_op); + return mlir::success(); + } + } + } + return mlir::failure(); +} diff --git a/mlir-compiler/src/rewrites/cast_lowering.hpp b/mlir-compiler/src/rewrites/cast_lowering.hpp new file mode 100644 index 00000000000..bb3fcc30dc0 --- /dev/null +++ b/mlir-compiler/src/rewrites/cast_lowering.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "plier/dialect.hpp" + +#include + +namespace mlir +{ +class TypeConverter; +} + +struct CastOpLowering : public mlir::OpRewritePattern +{ + using cast_t = std::function; + + CastOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context, + cast_t cast_func = nullptr); + + mlir::LogicalResult matchAndRewrite( + plier::CastOp op, mlir::PatternRewriter &rewriter) const override; + +private: + mlir::TypeConverter& converter; + cast_t cast_func; +}; From d2484d2ce3a8d56bbac1f51717c8e33d17697adb Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 4 Nov 2020 17:29:36 +0300 Subject: [PATCH 135/259] fix --- mlir-compiler/src/rewrites/call_lowering.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/rewrites/call_lowering.hpp b/mlir-compiler/src/rewrites/call_lowering.hpp index cafee219fc4..93046a5795a 100644 --- a/mlir-compiler/src/rewrites/call_lowering.hpp +++ b/mlir-compiler/src/rewrites/call_lowering.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "plier/dialect.hpp" #include @@ -11,7 +13,7 @@ class TypeConverter; struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::StringRef, llvm::ArrayRef, mlir::PatternRewriter&); + using resolver_t = std::function, mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, From df87b44eab12c127310dd8413d8e86205bc8c25a Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 4 Nov 2020 17:39:22 +0300 Subject: [PATCH 136/259] some work --- .../src/pipelines/plier_to_linalg.cpp | 8 +++++-- .../src/rewrites/type_conversion.cpp | 24 +++++++++++++++---- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 7a204a497a1..bcd832c01c3 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #include #include #include @@ -17,6 +17,7 @@ #include "pipelines/plier_to_std.hpp" #include "rewrites/call_lowering.hpp" +#include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" @@ -67,6 +68,7 @@ mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, { llvm::SmallVector shape(num_dims, -1); return mlir::MemRefType::get(shape, type); +// return mlir::RankedTensorType::get(shape, type); } } return nullptr; @@ -260,7 +262,8 @@ void PlierToLinalgPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert< - FuncOpSignatureConversion + FuncOpSignatureConversion, + CastOpLowering >(type_converter, &getContext()); patterns.insert< @@ -304,6 +307,7 @@ void LowerLinalgPass::runOnOperation() void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(mlir::createBufferPlacementPass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); } diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 61ed1bdfc53..f261856d3f9 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -2,10 +2,13 @@ #include +#include "plier/dialect.hpp" + namespace { mlir::LogicalResult setBlockSig( - mlir::Block& block, const mlir::TypeConverter::SignatureConversion& conversion) + mlir::Block& block, mlir::OpBuilder& builder, + const mlir::TypeConverter::SignatureConversion& conversion) { if (conversion.getConvertedTypes().size() != block.getNumArguments()) { @@ -15,7 +18,16 @@ mlir::LogicalResult setBlockSig( { auto arg = std::get<0>(it); auto type = std::get<1>(it); - arg.setType(type); + if (arg.getType() != type) + { + builder.setInsertionPointToStart(&block); + auto res = builder.create(builder.getUnknownLoc(), arg.getType(), arg); + arg.replaceUsesWithIf(res, [&](mlir::OpOperand& op) + { + return op.getOwner() != res; + }); + arg.setType(type); + } } return mlir::success(); } @@ -23,18 +35,22 @@ mlir::LogicalResult setBlockSig( mlir::LogicalResult convertRegionTypes( mlir::Region *region, mlir::TypeConverter &converter, bool apply) { + assert(nullptr != region); if (region->empty()) { return mlir::failure(); } + mlir::OpBuilder builder(region->getContext()); + // Convert the arguments of each block within the region. auto sig = converter.convertBlockSignature(®ion->front()); assert(static_cast(sig)); if (apply) { - auto res = setBlockSig(region->front(), *sig); + auto res = setBlockSig(region->front(), builder, *sig); assert(mlir::succeeded(res)); + (void)res; } for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) { @@ -45,7 +61,7 @@ mlir::LogicalResult convertRegionTypes( } if (apply) { - if (mlir::failed(setBlockSig(block, *sig))) + if (mlir::failed(setBlockSig(block, builder, *sig))) { return mlir::failure(); } From d09b17f95132e03c6b135800bc0d436300d69b1b Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 5 Nov 2020 14:42:30 +0300 Subject: [PATCH 137/259] move to mlir master --- mlir-compiler/llvm-sha.txt | 1 + mlir-compiler/src/compiler.cpp | 6 +++-- mlir-compiler/src/lowering.cpp | 22 +++++++++---------- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 20 +++++------------ .../src/pipelines/plier_to_linalg.cpp | 7 +++--- mlir-compiler/src/pipelines/plier_to_std.cpp | 3 ++- 6 files changed, 28 insertions(+), 31 deletions(-) create mode 100644 mlir-compiler/llvm-sha.txt diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt new file mode 100644 index 00000000000..0a3ec6923fc --- /dev/null +++ b/mlir-compiler/llvm-sha.txt @@ -0,0 +1 @@ +2ce5b8f78c65a1aa4b1017807f59cc28e222040c diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 640cebf0edb..6b0124ac245 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -19,9 +19,9 @@ class CompilerContext::CompilerContextImpl CompilerContextImpl(mlir::MLIRContext& ctx, const CompilerContext::Settings& settings, const PipelineRegistry& registry): - pm(&ctx, settings.verify) + pm(&ctx) { - registry.populate_pass_manager(pm); + pm.enableVerifier(settings.verify); if (settings.pass_statistics) { @@ -36,6 +36,8 @@ class CompilerContext::CompilerContextImpl ctx.enableMultithreading(false); pm.enableIRPrinting(); } + + registry.populate_pass_manager(pm); } void run(mlir::ModuleOp& module) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 1ca7618ae31..a21547d913d 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -41,15 +41,15 @@ std::string serialize_mod(const llvm::Module& mod) return ret; } -template -std::string to_str(T& obj) -{ - std::string ret; - llvm::raw_string_ostream stream(ret); - obj.print(stream); - stream.flush(); - return ret; -} +//template +//std::string to_str(T& obj) +//{ +// std::string ret; +// llvm::raw_string_ostream stream(ret); +// obj.print(stream); +// stream.flush(); +// return ret; +//} std::vector> get_blocks(const py::object& func) { @@ -610,8 +610,8 @@ void create_pipeline(PipelineRegistry& registry) py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) { - mlir::registerDialect(); - mlir::registerDialect(); +// mlir::registerDialect(); +// mlir::registerDialect(); mlir::MLIRContext context; auto mod = plier_lowerer(context).lower(compilation_context, func_ir); PipelineRegistry registry; diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 9ae985806ee..9f151f4629f 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -454,18 +455,13 @@ struct PreLLVMLowering : public mlir::PassWrapper(&getContext(), type_helper.get_type_converter()); - apply_conv(); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -481,15 +477,11 @@ struct PostLLVMLowering : void runOnFunction() override final { mlir::OwningRewritePatternList patterns; - auto apply_conv = [&]() - { - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); - }; // Remove redundant bitcasts we have created on PreLowering patterns.insert(&getContext()); - apply_conv(); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -546,7 +538,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper(typeConverter, &getContext()); LLVMConversionTarget target(getContext()); - if (failed(applyPartialConversion(m, target, patterns))) + if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m.setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), StringAttr::get(options.dataLayout.getStringRepresentation(), m.getContext())); @@ -559,9 +551,9 @@ struct LLVMLoweringPass : public mlir::PassWrapper()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); } } diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index bcd832c01c3..e4eb7f6863c 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "plier/dialect.hpp" @@ -275,7 +276,7 @@ void PlierToLinalgPass::runOnOperation() GetitemOpLowering >(&getContext()); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } struct LowerLinalgPass : @@ -301,13 +302,13 @@ void LowerLinalgPass::runOnOperation() (&getContext(), mlir::linalg::LinalgLoweringType::ParallelLoops); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); - pm.addPass(mlir::createBufferPlacementPass()); +// pm.addPass(mlir::createBufferPlacementPass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index cf5c186ddd5..c531a352092 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include @@ -704,7 +705,7 @@ void PlierToStdPass::runOnOperation() CallOpLowering >(type_converter, &getContext(), &basic_rewrite); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) From 3a0f09b6ae0c220822c114180bf1113469657ce5 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 5 Nov 2020 16:54:37 +0300 Subject: [PATCH 138/259] operate linalg on tensors --- .../src/pipelines/plier_to_linalg.cpp | 83 ++++++++++--------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index e4eb7f6863c..cdf7e332d5a 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -68,8 +69,8 @@ mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, if (auto type = conveter.convertType(plier::PyType::get(&ctx, name))) { llvm::SmallVector shape(num_dims, -1); - return mlir::MemRefType::get(shape, type); -// return mlir::RankedTensorType::get(shape, type); +// return mlir::MemRefType::get(shape, type); + return mlir::RankedTensorType::get(shape, type); } } return nullptr; @@ -117,18 +118,29 @@ mlir::Attribute get_zero(mlir::Type type) llvm_unreachable("get_zero: usupported type"); } +mlir::Type get_elem_type(mlir::Type type) +{ + if (auto memref = type.dyn_cast()) + { + return memref.getElementType(); + } + if (auto tensor = type.dyn_cast()) + { + return tensor.getElementType(); + } + llvm_unreachable("get_elem_type: unknown type"); +} + mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { if (name == "" && check_numpy_args(args, 2)) { - mlir::Value inputs[] = { args[0], args[1] }; - auto elem_type = args[0].getType().cast().getElementType(); - mlir::Type res_type = mlir::MemRefType::get({-1}, elem_type); auto loc = op.getLoc(); - mlir::Value size = rewriter.create(loc, args[0], 0); - mlir::Value outputs[] = { rewriter.create(loc, res_type, size) }; + mlir::Value inputs[] = { args[0], args[1] }; + auto elem_type = get_elem_type(args[0].getType()); + mlir::Type res_type = mlir::RankedTensorType::get(-1, elem_type); mlir::AffineMap map[] = { mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), @@ -136,47 +148,35 @@ mlir::LogicalResult numpy_rewrite( }; mlir::StringRef iterators[] = { "parallel" }; -// mlir::Value size = rewriter.create(loc, args[0], 0); -// mlir::Value init = rewriter.create( -// loc, -// res_type, -// mlir::ValueRange(size), -// [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) -// { -// assert(args.size() == 1); -// auto val = builder.create(loc, mlir::IntegerAttr::get(elem_type, 0)); -// builder.create(loc, val); -// }); - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) { - assert(args.size() == 3); + assert(args.size() == 2); mlir::Value res = builder.create(loc, args[0], args[1]); builder.create(loc, res); }; - rewriter.create( + auto res = rewriter.create( loc, + mlir::TypeRange(res_type), mlir::ValueRange(inputs), - mlir::ValueRange(outputs), + mlir::ValueRange(), // outputs + mlir::ValueRange(), // init llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), - body); - rewriter.replaceOp(op, outputs[0]); + body).getResult(0); + rewriter.replaceOp(op, res); return mlir::success(); } if (name == "array.sum" && check_numpy_args(args, 1)) { + auto loc = op.getLoc(); mlir::Value inputs[] = { args[0] }; - // auto elem_type = inputs[0].getType().cast().getElementType(); auto elem_type = mlir::IntegerType::get(64, op.getContext()); - auto res_type = mlir::MemRefType::get({}, elem_type); - auto loc = op.getLoc(); - mlir::Value outputs[] = { rewriter.create(loc, res_type) }; - auto zero = rewriter.create(loc, get_zero(elem_type)); - rewriter.create(loc, zero, outputs[0]); + auto res_type = mlir::RankedTensorType::get(1, elem_type); + mlir::Value zero = rewriter.create(loc, get_zero(elem_type)); + mlir::Value init = rewriter.create(loc, zero); mlir::AffineMap map[] = { mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - mlir::AffineMap::get(1, 0, op.getContext()), + mlir::AffineMap::get(1, 0, mlir::getAffineConstantExpr(0, op.getContext())), }; mlir::StringRef iterators[] = { "reduction" }; auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) @@ -186,14 +186,17 @@ mlir::LogicalResult numpy_rewrite( mlir::Value res = builder.create(loc, val, args[1]); builder.create(loc, res); }; - rewriter.create( + auto val = rewriter.create( loc, + mlir::TypeRange(res_type), mlir::ValueRange(inputs), - mlir::ValueRange(outputs), + mlir::ValueRange(), // outputs + mlir::ValueRange(init), llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), - body); - mlir::Value res = rewriter.create(loc, outputs[0]); + body).getResult(0); + mlir::Value index = rewriter.create(loc, 0); + mlir::Value res = rewriter.create(loc, val, index); rewriter.replaceOp(op, res); return mlir::success(); } @@ -298,8 +301,10 @@ void LowerLinalgPass::runOnOperation() { mlir::OwningRewritePatternList patterns; - patterns.insert> - (&getContext(), mlir::linalg::LinalgLoweringType::ParallelLoops); + patterns.insert< + mlir::linalg::LinalgLoweringPattern, + mlir::linalg::LinalgLoweringPattern + >(&getContext(), mlir::linalg::LinalgLoweringType::ParallelLoops); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); @@ -308,7 +313,9 @@ void LowerLinalgPass::runOnOperation() void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); -// pm.addPass(mlir::createBufferPlacementPass()); + pm.addPass(mlir::createLinalgBufferizePass()); + pm.addNestedPass(mlir::createStdBufferizePass()); + pm.addPass(mlir::createFuncBufferizePass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); } From a4ee0ecadc283fce3dd17449a33b06c8528ef8c4 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 5 Nov 2020 20:28:09 +0300 Subject: [PATCH 139/259] getitem lowering with tensors --- .../src/pipelines/plier_to_linalg.cpp | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index cdf7e332d5a..ba3315c637f 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -214,7 +214,10 @@ struct GetitemOpLowering : public mlir::OpRewritePattern assert(op.getNumOperands() == 2); auto val = op.getOperand(0); auto index = op.getOperand(1); - if (!val.getType().template isa()) + auto type = val.getType(); + bool is_memref = type.template isa(); + bool is_tensor = type.template isa(); + if (!is_memref && !is_tensor) { return mlir::failure(); } @@ -228,7 +231,19 @@ struct GetitemOpLowering : public mlir::OpRewritePattern { index = rewriter.create(loc, index, mlir::IndexType::get(op.getContext())); } - mlir::Value res = rewriter.create(loc, val, index); + mlir::Value res; + if (is_memref) + { + res = rewriter.create(loc, val, index); + } + else if (is_tensor) + { + res = rewriter.create(loc, val, index); + } + else + { + llvm_unreachable("Invalid getitem"); + } rewriter.replaceOp(op, res); return mlir::success(); } @@ -313,9 +328,9 @@ void LowerLinalgPass::runOnOperation() void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); - pm.addPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); + pm.addPass(mlir::createLinalgBufferizePass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); } From 540e2812a2b1c7bf3867d48630fd2ef6c6830a6b Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 02:35:16 +0300 Subject: [PATCH 140/259] update llvm --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 0a3ec6923fc..18732e73d11 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -2ce5b8f78c65a1aa4b1017807f59cc28e222040c +f253823398dd2894ee5d9333c541c534b7a407fb diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index ba3315c637f..a1156b0f7f9 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -328,9 +328,9 @@ void LowerLinalgPass::runOnOperation() void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); - pm.addPass(mlir::createLinalgBufferizePass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); } From 7f4ad2e1e470e707a6049ac0ecfda23eab41cf71 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 02:43:58 +0300 Subject: [PATCH 141/259] move createLowerToCFGPass to low lowering --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 2 ++ mlir-compiler/src/pipelines/plier_to_linalg.cpp | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 9f151f4629f..9a1fe776376 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -550,6 +551,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper()); pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index a1156b0f7f9..b8aad323f4e 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -332,7 +332,6 @@ void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); pm.addPass(std::make_unique()); - pm.addPass(mlir::createLowerToCFGPass()); } } From 285067a75a5fc7eb564ef8bf86e674fdeed03063 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 16:17:44 +0300 Subject: [PATCH 142/259] fusing pass --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index b8aad323f4e..d00e1a1e3cf 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -328,9 +328,13 @@ void LowerLinalgPass::runOnOperation() void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + + pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); + pm.addPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); + pm.addPass(std::make_unique()); } } From 399d337ca78924cfeafb76a8253407bec3b45f70 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 16:55:19 +0300 Subject: [PATCH 143/259] some optimization passes --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index d00e1a1e3cf..e106d20ba31 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -335,6 +335,12 @@ void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); + pm.addNestedPass(mlir::createPromoteBuffersToStackPass(1024)); + pm.addNestedPass(mlir::createBufferHoistingPass()); + pm.addNestedPass(mlir::createBufferLoopHoistingPass()); + pm.addNestedPass(mlir::createCopyRemovalPass()); + pm.addNestedPass(mlir::createBufferDeallocationPass()); + pm.addPass(std::make_unique()); } } From d322574113ed1dcd6c0c8bc25f3746b14b71b410 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 17:14:20 +0300 Subject: [PATCH 144/259] test --- numba/mlir/tests/test_numpy.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 02f5444372c..760093f7f82 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -41,5 +41,16 @@ def py_func(a, b): arr2 = np.asarray([4,5,6]) assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + def test_sum_add2(self): + def py_func(a, b, c): + t = np.add(a, b) + return np.add(t, c).sum() + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3]) + arr2 = np.asarray([4,5,6]) + arr3 = np.asarray([7,8,9]) + assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) + if __name__ == '__main__': unittest.main() From ac688841f78fa0555dea57d220f6c0e32c338133 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 18:02:28 +0300 Subject: [PATCH 145/259] asd --- mlir-compiler/test.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 09f792aed8b..735121e59c3 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -69,6 +69,10 @@ def np_sum(a): def np_add(a, b): return np.add(a, b).sum() +def np_add2(a, b, c): + t = np.add(a, b) + return np.add(t, c).sum() + def test(func, params): global _tests_total global _tests_passes @@ -110,6 +114,7 @@ def test(func, params): test(np_getitem2, (np.asarray([1,2,3]),np.asarray([4,5,6]),1)) test(np_sum, (np.asarray([1,2,3]),)) test(np_add, (np.asarray([1,2,3]),np.asarray([4,5,6]))) +test(np_add2, (np.asarray([1,2,3]),np.asarray([4,5,6]),np.asarray([7,8,9]))) print(f'Tests passed: {_tests_passes}/{_tests_total}') if (len(_failed_tests) != 0): From 46a335ce48f42e559c9961bcd8d5b3701419c6a0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 6 Nov 2020 19:04:36 +0300 Subject: [PATCH 146/259] fix type conversion --- .../src/rewrites/type_conversion.cpp | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index f261856d3f9..7f9be14e420 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -1,6 +1,7 @@ #include "rewrites/type_conversion.hpp" #include +#include #include "plier/dialect.hpp" @@ -14,6 +15,7 @@ mlir::LogicalResult setBlockSig( { return mlir::failure(); } + unsigned i = 0; for (auto it : llvm::zip(block.getArguments(), conversion.getConvertedTypes())) { auto arg = std::get<0>(it); @@ -26,8 +28,41 @@ mlir::LogicalResult setBlockSig( { return op.getOwner() != res; }); + + for (auto& use : block.getUses()) + { + auto op = use.getOwner(); + builder.setInsertionPoint(op); + if (auto br = mlir::dyn_cast(op)) + { + assert(&block == br.dest()); + auto src = br.destOperands()[i]; + auto new_op = builder.create(op->getLoc(), type, src); + br.destOperandsMutable().slice(i, 1).assign(new_op); + } + else if (auto cond_br = mlir::dyn_cast(op)) + { + if (&block == cond_br.trueDest()) + { + auto src = cond_br.trueDestOperands()[i]; + auto new_op = builder.create(op->getLoc(), type, src); + cond_br.trueDestOperandsMutable().slice(i, 1).assign(new_op); + } + if (&block == cond_br.falseDest()) + { + auto src = cond_br.falseDestOperands()[i]; + auto new_op = builder.create(op->getLoc(), type, src); + cond_br.falseDestOperandsMutable().slice(i, 1).assign(new_op); + } + } + else + { + llvm_unreachable("setBlockSig: unknown operation type"); + } + } arg.setType(type); } + ++i; } return mlir::success(); } From 4090e9f3f27c727418b35b95af849cf393d285d5 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 147/259] work on loop lowering --- mlir-compiler/src/pipelines/plier_to_std.cpp | 149 ++++++++++++++++++- 1 file changed, 147 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index c531a352092..b225e57a679 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1,7 +1,9 @@ #include "pipelines/plier_to_std.hpp" +#include #include #include +#include #include #include #include @@ -537,6 +539,141 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; +template +Op get_next_op(llvm::iterator_range& iters) +{ + if (iters.empty()) + { + return nullptr; + } + auto res = mlir::dyn_cast(iters.begin()); + if (res) + { + auto next = std::next(iters.begin()); + iters = {next, iters.end()}; + } + return res; +} + +mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref()> get_bounds) +{ + auto getiter_block = getiter.getOperation()->getBlock(); + auto get_next_block = [](mlir::Block* block)->mlir::Block* + { + assert(nullptr != block); + if (auto br = mlir::dyn_cast_or_null(block->getTerminator())) + { + return br.dest(); + } + return nullptr; + }; + auto iternext_block = get_next_block(getiter_block); + if (nullptr == iternext_block) + { + return mlir::failure(); + } + auto iters = llvm::iterator_range(*iternext_block); + auto iternext = get_next_op(iters); + auto pairfirst = get_next_op(iters); + auto pairsecond = get_next_op(iters); + while (get_next_op(iters)) {} // skip casts + auto cond_br = get_next_op(iters); + auto skip_casts = [](mlir::Value op) + { + while (auto cast = op.dyn_cast_or_null()) + { + op = cast.getOperand(); + } + return op; + }; + if (!iternext || !pairfirst || !pairsecond || !cond_br || + skip_casts(cond_br.condition()) != pairsecond) + { + return mlir::failure(); + } + auto body_block = cond_br.trueDest(); + auto post_block = cond_br.falseDest(); + assert(nullptr != body_block); + assert(nullptr != post_block); + if (get_next_block(body_block) != iternext_block) + { + return mlir::failure(); + } + + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) + { + mlir::BlockAndValueMapping mapper; + assert(iternext_block->getNumArguments() == iterargs.size()); + for (auto it : llvm::zip(iternext_block->getArguments(), iterargs)) + { + mapper.map(std::get<0>(it), std::get<1>(it)); + } + auto index = builder.create(loc, pairfirst.getType(), iv); + mapper.map(pairfirst, index); + for (auto& op : body_block->without_terminator()) + { + auto new_op = builder.clone(op, mapper); + mapper.map(op, new_op); + } + + auto term_operands = mlir::cast(body_block->getTerminator()).destOperands(); + llvm::SmallVector yield_vars; + yield_vars.reserve(term_operands.size()); + for (auto arg : term_operands) + { + yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + } + builder.create(loc, yield_vars); + }; + + auto loc = getiter.getLoc(); + + auto index_cast = [&](mlir::Value val)->mlir::Value + { + if (!val.getType().isa()) + { + return builder.create(loc, val); + } + return val; + }; + + auto bounds = get_bounds(); + std::get<0>(bounds) = index_cast(std::get<0>(bounds)); + std::get<1>(bounds) = index_cast(std::get<1>(bounds)); + std::get<2>(bounds) = index_cast(std::get<2>(bounds)); + mlir::OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(getiter); + auto loop_op = builder.create( + loc, + std::get<0>(bounds), // lower bound + std::get<1>(bounds), // upper bound + std::get<2>(bounds), // step + llvm::None, // iterArgs + body + ); + assert(loop_op.getNumResults() == iternext_block->getNumArguments()); + for (auto arg : llvm::zip(iternext_block->getArguments(), loop_op.getResults())) + { + std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); + } + builder.eraseBlock(body_block); + builder.eraseBlock(iternext_block); + builder.eraseOp(getiter); + auto term = mlir::cast(getiter_block->getTerminator()); + term.setSuccessor(post_block); + return mlir::success(); +} + +mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +{ + if ((operands.size() < 1 || operands.size() > 3) || + !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) + { + return mlir::failure(); + } + return mlir::failure(); +} + mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { if (operands.size() != 1) @@ -565,9 +702,17 @@ mlir::LogicalResult basic_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - if (name == "") + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + std::pair handlers[] = { + {"", lower_bool_cast}, + {"range", lower_range}, + }; + for (auto& handler : handlers) { - return lower_bool_cast(op, args, rewriter); + if (handler.first == name) + { + return handler.second(op, args, rewriter); + } } return mlir::failure(); } From f009f463427439af55677ce12a667af0f329ff57 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 148/259] range lowering --- mlir-compiler/include/plier/PlierOps.td | 2 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 93 +++++++++++++++----- numba/mlir/tests/test_basic.py | 30 +++++++ 3 files changed, 100 insertions(+), 25 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 4cba5de5dab..0e9e81ab27c 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -127,7 +127,7 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { ]; } -def GetiterOp : Plier_Op<"getiter", []> { +def GetiterOp : Plier_Op<"getiter", [NoSideEffect]> { let arguments = (ins AnyType:$value); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index b225e57a679..d616ebb210b 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -215,6 +216,11 @@ bool is_float(mlir::Type type) return type.isa(); } +bool is_index(mlir::Type type) +{ + return type.isa(); +} + struct ConstOpLowering : public mlir::OpRewritePattern { ConstOpLowering(mlir::TypeConverter &/*typeConverter*/, @@ -418,6 +424,11 @@ mlir::Value float_int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRe return rewriter.create(val.getLoc(), val, dst_type); } +mlir::Value index_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +{ + return rewriter.create(val.getLoc(), val, dst_type); +} + mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) { auto src_type = val.getType(); @@ -439,6 +450,8 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& {&is_int, &is_int, &int_cast}, {&is_int, &is_float, &int_float_cast}, {&is_float, &is_int, &float_int_cast}, + {&is_index, &is_int, &index_cast}, + {&is_int, &is_index, &index_cast}, }; for (auto& h : handlers) @@ -555,7 +568,9 @@ Op get_next_op(llvm::iterator_range& iters) return res; } -mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref()> get_bounds) +mlir::LogicalResult lower_loop( + plier::GetiterOp getiter, mlir::PatternRewriter& builder, + llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds) { auto getiter_block = getiter.getOperation()->getBlock(); auto get_next_block = [](mlir::Block* block)->mlir::Block* @@ -572,6 +587,7 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& { return mlir::failure(); } + auto iters = llvm::iterator_range(*iternext_block); auto iternext = get_next_op(iters); auto pairfirst = get_next_op(iters); @@ -580,12 +596,13 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& auto cond_br = get_next_op(iters); auto skip_casts = [](mlir::Value op) { - while (auto cast = op.dyn_cast_or_null()) + while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) { op = cast.getOperand(); } return op; }; + if (!iternext || !pairfirst || !pairsecond || !cond_br || skip_casts(cond_br.condition()) != pairsecond) { @@ -600,6 +617,7 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& return mlir::failure(); } + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) { mlir::BlockAndValueMapping mapper; @@ -610,10 +628,10 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& } auto index = builder.create(loc, pairfirst.getType(), iv); mapper.map(pairfirst, index); + for (auto& op : body_block->without_terminator()) { - auto new_op = builder.clone(op, mapper); - mapper.map(op, new_op); + builder.clone(op, mapper); } auto term_operands = mlir::cast(body_block->getTerminator()).destOperands(); @@ -632,23 +650,24 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& { if (!val.getType().isa()) { - return builder.create(loc, val); + return builder.create(loc, val, mlir::IndexType::get(val.getContext())); } return val; }; - auto bounds = get_bounds(); - std::get<0>(bounds) = index_cast(std::get<0>(bounds)); - std::get<1>(bounds) = index_cast(std::get<1>(bounds)); - std::get<2>(bounds) = index_cast(std::get<2>(bounds)); + auto term = mlir::cast(getiter_block->getTerminator()); + auto bounds = get_bounds(builder, loc); + auto lower_bound = index_cast(std::get<0>(bounds)); + auto upper_bound = index_cast(std::get<1>(bounds)); + auto step = index_cast(std::get<2>(bounds)); mlir::OpBuilder::InsertionGuard g(builder); - builder.setInsertionPoint(getiter); + builder.setInsertionPointAfter(getiter); auto loop_op = builder.create( loc, - std::get<0>(bounds), // lower bound - std::get<1>(bounds), // upper bound - std::get<2>(bounds), // step - llvm::None, // iterArgs + lower_bound, + upper_bound, + step, + term.destOperands(), // iterArgs body ); assert(loop_op.getNumResults() == iternext_block->getNumArguments()); @@ -656,11 +675,9 @@ mlir::LogicalResult lower_loop(plier::GetiterOp getiter, mlir::PatternRewriter& { std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); } - builder.eraseBlock(body_block); - builder.eraseBlock(iternext_block); - builder.eraseOp(getiter); - auto term = mlir::cast(getiter_block->getTerminator()); - term.setSuccessor(post_block); + + builder.create(loc, post_block); + builder.eraseOp(term); return mlir::success(); } @@ -671,7 +688,28 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef { return mlir::failure(); } - return mlir::failure(); + mlir::Value val(op); + if (!val.getUsers().empty()) + { + auto user = mlir::dyn_cast(*val.getUsers().begin()); + auto get_bounds = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto lower_bound = (operands.size() >= 2 ? operands[0] : builder.create(loc, 0)); + auto upper_bound = (operands.size() >= 2 ? operands[1] : operands[0]); + auto step = (operands.size() == 3 ? operands[2] : builder.create(loc, 1)); + return std::make_tuple(lower_bound, upper_bound, step); + }; + if (!user || mlir::failed(lower_loop(user,rewriter, get_bounds))) + { + return mlir::failure(); + } + } + + if (val.getUsers().empty()) + { + rewriter.eraseOp(op); + } + return mlir::success(); } mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) @@ -705,7 +743,7 @@ mlir::LogicalResult basic_rewrite( using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); std::pair handlers[] = { {"", lower_bool_cast}, - {"range", lower_range}, + {"", lower_range}, }; for (auto& handler : handlers) { @@ -831,6 +869,8 @@ void PlierToStdPass::runOnOperation() type_converter.addConversion([](mlir::Type type) { return type; }); populate_std_type_converter(type_converter); + auto context = &getContext(); + mlir::OwningRewritePatternList patterns; patterns.insert< @@ -840,15 +880,20 @@ void PlierToStdPass::runOnOperation() SelectOpLowering, CondBrOpLowering, BinOpLowering - >(type_converter, &getContext()); + >(type_converter, context); patterns.insert< CastOpLowering - >(type_converter, &getContext(), &do_cast); + >(type_converter, context, &do_cast); patterns.insert< CallOpLowering - >(type_converter, &getContext(), &basic_rewrite); + >(type_converter, context, &basic_rewrite); + + for (auto *op : context->getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, context); + } (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 5f1488cd08b..d9060de2518 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -90,6 +90,36 @@ def py_func(a, b, c): for a, b, c in itertools.product(_test_values, _test_values, _test_values): assert_equal(py_func(a, b, c), jit_func(a, b, c)) + def test_range1(self): + def py_func(a): + res = 0 + for i in range(a): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + + def test_range2(self): + def py_func(a, b): + res = 0 + for i in range(a, b): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20), jit_func(10, 20)) + + def test_range3(self): + def py_func(a, b, c): + res = 0 + for i in range(a, b, c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + if __name__ == '__main__': unittest.main() From 25979c98081ba38005edd9b307918050fbd32ae2 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 149/259] fix for nested range --- mlir-compiler/src/pipelines/plier_to_std.cpp | 14 +++++++++++--- numba/mlir/tests/test_basic.py | 12 ++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index d616ebb210b..a1dd4eb47ab 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -617,7 +617,6 @@ mlir::LogicalResult lower_loop( return mlir::failure(); } - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) { mlir::BlockAndValueMapping mapper; @@ -660,7 +659,7 @@ mlir::LogicalResult lower_loop( auto lower_bound = index_cast(std::get<0>(bounds)); auto upper_bound = index_cast(std::get<1>(bounds)); auto step = index_cast(std::get<2>(bounds)); - mlir::OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointAfter(getiter); auto loop_op = builder.create( loc, @@ -676,8 +675,17 @@ mlir::LogicalResult lower_loop( std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); } - builder.create(loc, post_block); + auto iternext_term = mlir::cast(iternext_block->getTerminator()); + + builder.create(loc, post_block, iternext_term.falseDestOperands()); builder.eraseOp(term); + + iternext_block->dropAllDefinedValueUses(); + iternext_block->erase(); + body_block->dropAllDefinedValueUses(); + body_block->erase(); + builder.eraseOp(getiter); + return mlir::success(); } diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index d9060de2518..53e04222637 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -120,6 +120,18 @@ def py_func(a, b, c): jit_func = njit(py_func) assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + def test_range_nested(self): + def py_func(a, b, c): + res = 0 + for i in range(a): + for j in range(b): + for k in range(c): + res = res + i + j * 10 + k * 100 + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + if __name__ == '__main__': unittest.main() From 6a1ea495a4fd72250aef3c7dbe8b40308275182c Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 150/259] update tests sandbox --- mlir-compiler/test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index 735121e59c3..f5254a7312d 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -53,8 +53,18 @@ def arr_loop(): def range_loop(n): res = 0 + res1 = 2 for i in range(n): res = res + i + res1 = res1 + i * 2 + return res + res1 + +def range_loop_nested(a, b, c): + res = 0 + for i in range(a): + for j in range(b): + for k in range(c): + res = res + i + j * 10 + k * 100 return res def np_getitem(a, b): @@ -109,6 +119,7 @@ def test(func, params): test(tuple, (1,2.0,3)) test(arr_loop, ()) test(range_loop, (8,)) +test(range_loop_nested, (8,9,10)) test(sum2, (np.asarray([1,2,3]),np.asarray([4,5,6]))) test(np_getitem, (np.asarray([1,2,3]),1)) test(np_getitem2, (np.asarray([1,2,3]),np.asarray([4,5,6]),1)) From 421d5a91cf5010f23cce66f96fd418f4a2b20971 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 151/259] fix --- mlir-compiler/src/pipelines/plier_to_std.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index a1dd4eb47ab..d50ce8c6206 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include @@ -681,8 +680,9 @@ mlir::LogicalResult lower_loop( builder.eraseOp(term); iternext_block->dropAllDefinedValueUses(); - iternext_block->erase(); body_block->dropAllDefinedValueUses(); + + iternext_block->erase(); body_block->erase(); builder.eraseOp(getiter); From d8db043d5e946e4165627713e806b8d59a6de526 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 152/259] cfg to scf::if --- mlir-compiler/src/pipelines/plier_to_std.cpp | 146 +++++++++++++++++-- mlir-compiler/test.py | 6 +- numba/mlir/tests/test_basic.py | 14 ++ 3 files changed, 154 insertions(+), 12 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index d50ce8c6206..43e13832117 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -551,6 +551,139 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; +mlir::Block* get_next_block(mlir::Block* block) +{ + assert(nullptr != block); + if (auto br = mlir::dyn_cast_or_null(block->getTerminator())) + { + return br.dest(); + } + return nullptr; +}; + +void erase_blocks(llvm::ArrayRef blocks) +{ + for (auto block : blocks) + { + assert(nullptr != block); + block->dropAllDefinedValueUses(); + } + for (auto block : blocks) + { + block->erase(); + } +} + +struct ScfIfRewrite : public mlir::OpRewritePattern +{ + ScfIfRewrite(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::CondBranchOp op, mlir::PatternRewriter &rewriter) const override + { + auto true_block = op.getTrueDest(); + auto post_block = get_next_block(true_block); + if (nullptr == post_block) + { + return mlir::failure(); + } + auto false_block = op.getFalseDest(); + if (false_block != post_block && + get_next_block(false_block) != post_block) + { + return mlir::failure(); + } + auto cond = op.condition(); + + mlir::BlockAndValueMapping mapper; + llvm::SmallVector yield_vals; + auto copy_block = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Block& block) + { + mapper.clear(); + for (auto& op : block.without_terminator()) + { + builder.clone(op, mapper); + } + auto term = mlir::cast(block.getTerminator()); + yield_vals.clear(); + yield_vals.reserve(term.getNumOperands()); + for (auto op : term.getOperands()) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); + }; + + auto true_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + copy_block(builder, loc, *true_block); + }; + + bool has_else = false_block != post_block; + auto res_types = mlir::cast(true_block->getTerminator()).getOperandTypes(); + mlir::scf::IfOp if_op; + if (has_else) + { + auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + copy_block(builder, loc, *false_block); + }; + if_op = rewriter.create( + op.getLoc(), + res_types, + cond, + true_body, + false_body); + } + else + { + if (res_types.empty()) + { + if_op = rewriter.create( + op.getLoc(), + res_types, + cond, + true_body); + } + else + { + auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto res = op.getFalseOperands(); + yield_vals.clear(); + yield_vals.reserve(res.size()); + for (auto op : res) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); + }; + if_op = rewriter.create( + op.getLoc(), + res_types, + cond, + true_body, + false_body); + } + } + + rewriter.create(op.getLoc(), post_block, if_op.getResults()); + rewriter.eraseOp(op); + + if (true_block->getUsers().empty()) + { + erase_blocks(true_block); + } + if (false_block->getUsers().empty()) + { + erase_blocks(false_block); + } + return mlir::success(); + } +}; + template Op get_next_op(llvm::iterator_range& iters) { @@ -572,15 +705,7 @@ mlir::LogicalResult lower_loop( llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds) { auto getiter_block = getiter.getOperation()->getBlock(); - auto get_next_block = [](mlir::Block* block)->mlir::Block* - { - assert(nullptr != block); - if (auto br = mlir::dyn_cast_or_null(block->getTerminator())) - { - return br.dest(); - } - return nullptr; - }; + auto iternext_block = get_next_block(getiter_block); if (nullptr == iternext_block) { @@ -887,7 +1012,8 @@ void PlierToStdPass::runOnOperation() ConstOpLowering, SelectOpLowering, CondBrOpLowering, - BinOpLowering + BinOpLowering, + ScfIfRewrite >(type_converter, context); patterns.insert< diff --git a/mlir-compiler/test.py b/mlir-compiler/test.py index f5254a7312d..9ef63842f92 100644 --- a/mlir-compiler/test.py +++ b/mlir-compiler/test.py @@ -55,8 +55,10 @@ def range_loop(n): res = 0 res1 = 2 for i in range(n): - res = res + i - res1 = res1 + i * 2 + if i > 5: + res = res + i + else: + res1 = res1 + i * 2 return res + res1 def range_loop_nested(a, b, c): diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 53e04222637..55216e6bbdd 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -120,6 +120,20 @@ def py_func(a, b, c): jit_func = njit(py_func) assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + def test_range_if(self): + def py_func(n): + res = 0 + res1 = 2 + for i in range(n): + if i > 5: + res = res + i + else: + res1 = res1 + i * 2 + return res + res1 + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + def test_range_nested(self): def py_func(a, b, c): res = 0 From 3a212f5b044e732b2a6c03eacb92493a6c0bc5ee Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 153/259] update llvm --- mlir-compiler/llvm-sha.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 18732e73d11..4a1f9a6027f 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -f253823398dd2894ee5d9333c541c534b7a407fb +b0de3f67874ac3eff465cb2ef8ab6081292625c3 From 920b5f997725b06bedf56955dfacbfba47bf7529 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 154/259] is_blocks_different --- mlir-compiler/src/pipelines/plier_to_std.cpp | 24 ++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 43e13832117..43ca0801d71 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -574,6 +574,24 @@ void erase_blocks(llvm::ArrayRef blocks) } } +bool is_blocks_different(llvm::ArrayRef blocks) +{ + for (auto it : llvm::enumerate(blocks)) + { + auto block1 = it.value(); + assert(nullptr != block1); + for (auto block2 : blocks.drop_front(it.index() + 1)) + { + assert(nullptr != block2); + if (block1 == block2) + { + return false; + } + } + } + return true; +} + struct ScfIfRewrite : public mlir::OpRewritePattern { ScfIfRewrite(mlir::TypeConverter &/*typeConverter*/, @@ -595,6 +613,12 @@ struct ScfIfRewrite : public mlir::OpRewritePattern { return mlir::failure(); } + + auto start_block = op.getOperation()->getBlock(); + if (!is_blocks_different({start_block, true_block, post_block})) + { + return mlir::failure(); + } auto cond = op.condition(); mlir::BlockAndValueMapping mapper; From 86807cabfaa01f5bcb7eb8864d9376c6004a780b Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 155/259] transform loops to scf.while --- mlir-compiler/src/pipelines/plier_to_std.cpp | 255 ++++++++++++++++++- 1 file changed, 253 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 43ca0801d71..f02e49a4eec 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -176,6 +176,7 @@ mlir::Type map_plier_type_name(mlir::MLIRContext& ctx, llvm::StringRef& name) mlir::Type map_plier_type(mlir::Type type) { + assert(type); if (!type.isa()) { return type; @@ -186,6 +187,7 @@ mlir::Type map_plier_type(mlir::Type type) bool is_supported_type(mlir::Type type) { + assert(type); return type.isIntOrFloat(); } @@ -207,16 +209,19 @@ void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir:: bool is_int(mlir::Type type) { + assert(type); return type.isa(); } bool is_float(mlir::Type type) { + assert(type); return type.isa(); } bool is_index(mlir::Type type) { + assert(type); return type.isa(); } @@ -361,6 +366,7 @@ mlir::Type coerce(mlir::Type type0, mlir::Type type1) assert(type0 != type1); auto get_bits_count = [](mlir::Type type)->unsigned { + assert(type); if (type.isa()) { return type.cast().getWidth(); @@ -430,6 +436,7 @@ mlir::Value index_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewrit mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) { + assert(dst_type); auto src_type = val.getType(); if (src_type == dst_type) { @@ -708,6 +715,248 @@ struct ScfIfRewrite : public mlir::OpRewritePattern } }; +mlir::scf::WhileOp create_while( + mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterArgs, + llvm::function_ref beforeBuilder, + llvm::function_ref afterBuilder) +{ + mlir::OperationState state(loc, mlir::scf::WhileOp::getOperationName()); + state.addOperands(iterArgs); + + { + mlir::OpBuilder::InsertionGuard g(builder); + auto add_region = [&](mlir::ValueRange args)->mlir::Block* + { + auto reg = state.addRegion(); + auto block = builder.createBlock(reg); + for (auto arg : args) + { + block->addArgument(arg.getType()); + } + return block; + }; + + auto beforeBlock = add_region(iterArgs); + beforeBuilder(builder, state.location, beforeBlock->getArguments()); + auto cond = mlir::cast(beforeBlock->getTerminator()); + state.addTypes(cond.args().getTypes()); + + auto afterblock = add_region(cond.args()); + afterBuilder(builder, state.location, afterblock->getArguments()); + } + return mlir::cast(builder.createOperation(state)); +} + +bool is_inside_block(mlir::Operation* op, mlir::Block* block) +{ + assert(nullptr != op); + assert(nullptr != block); + do + { + if (op->getBlock() == block) + { + return true; + } + } + while((op = op->getParentOp())); + return false; +} + +struct ScfWhileRewrite : public mlir::OpRewritePattern +{ + ScfWhileRewrite(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::BranchOp op, mlir::PatternRewriter &rewriter) const override + { + auto before_block = op.dest(); + auto before_term = mlir::dyn_cast(before_block->getTerminator()); + if (!before_term) + { + return mlir::failure(); + } + auto start_block = op.getOperation()->getBlock(); + auto after_block = before_term.trueDest(); + auto post_block = before_term.falseDest(); + if (get_next_block(after_block) != before_block || + !is_blocks_different({start_block, before_block, after_block, post_block})) + { + return mlir::failure(); + } + + auto check_outside_vals = [&](mlir::Operation* op)->mlir::WalkResult + { + for (auto user : op->getUsers()) + { + if (!is_inside_block(user, before_block) && + !is_inside_block(user, after_block)) + { + return mlir::WalkResult::interrupt(); + } + } + return mlir::WalkResult::advance(); + }; + + if (after_block->walk(check_outside_vals).wasInterrupted()) + { + return mlir::failure(); + } + + mlir::BlockAndValueMapping mapper; + llvm::SmallVector yield_vars; + auto before_block_args = before_block->getArguments(); + llvm::SmallVector orig_vars(before_block_args.begin(), before_block_args.end()); + + auto before_body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterargs) + { + mapper.map(before_block_args, iterargs); + yield_vars.resize(before_block_args.size()); + for (auto& op : before_block->without_terminator()) + { + auto new_op = builder.clone(op, mapper); + for (auto user : op.getUsers()) + { + user->isBeforeInBlock(user); + if (!is_inside_block(user, before_block)) + { + for (auto it : llvm::zip(op.getResults(), new_op->getResults())) + { + orig_vars.emplace_back(std::get<0>(it)); + yield_vars.emplace_back(std::get<1>(it)); + } + break; + } + } + } + + llvm::transform(before_block->getArguments(), yield_vars.begin(), + [&](mlir::Value val) { return mapper.lookupOrDefault(val); }); + + auto term = mlir::cast(before_block->getTerminator()); + for (auto arg : term.falseDestOperands()) + { + orig_vars.emplace_back(arg); + yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + } + auto cond = mapper.lookupOrDefault(term.condition()); + builder.create(loc, cond, yield_vars); + }; + auto after_body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterargs) + { + mapper.clear(); + assert(orig_vars.size() == iterargs.size()); + mapper.map(orig_vars, iterargs); + for (auto& op : after_block->without_terminator()) + { + builder.clone(op, mapper); + } + yield_vars.clear(); + auto term = mlir::cast(after_block->getTerminator()); + for (auto arg : term.getOperands()) + { + yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + } + builder.create(loc, yield_vars); + }; + + auto while_op = create_while( + rewriter, + op.getLoc(), + op.getOperands(), + before_body, + after_body); + + assert(orig_vars.size() == while_op.getNumResults()); + for (auto arg : llvm::zip(orig_vars, while_op.getResults())) + { + std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); + } + + rewriter.create(op.getLoc(), post_block, before_term.falseDestOperands()); + rewriter.eraseOp(op); + erase_blocks({before_block, after_block}); + + return mlir::success(); + } +}; + +struct FixupWhileTypes : public mlir::OpRewritePattern +{ + FixupWhileTypes(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + mlir::scf::WhileOp op, mlir::PatternRewriter &rewriter) const override + { + bool changed = false; + mlir::OpBuilder::InsertionGuard g(rewriter); + auto before_block = &op.before().front(); + rewriter.startRootUpdate(op); + rewriter.setInsertionPointToStart(before_block); + assert(before_block->getNumArguments() == op.getNumOperands()); + auto loc = rewriter.getUnknownLoc(); + for (auto it : llvm::zip(op.getOperandTypes(), before_block->getArguments())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + auto term = mlir::cast(before_block->getTerminator()); + auto after_types = term.args().getTypes(); + + auto after_block = &op.after().front(); + rewriter.setInsertionPointToStart(after_block); + assert(after_block->getNumArguments() == term.args().size()); + for (auto it : llvm::zip(after_types, after_block->getArguments())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + rewriter.setInsertionPointAfter(op); + assert(op.getNumResults() == term.args().size()); + for (auto it : llvm::zip(after_types, op.getResults())) + { + auto new_type = std::get<0>(it); + auto arg = std::get<1>(it); + auto old_type = arg.getType(); + if (old_type != new_type) + { + rewriter.create(loc, old_type, arg); + arg.setType(new_type); + changed = true; + } + } + + if (changed) + { + rewriter.finalizeRootUpdate(op); + } + else + { + rewriter.cancelRootUpdate(op); + } + return mlir::success(changed); + } +}; + template Op get_next_op(llvm::iterator_range& iters) { @@ -900,7 +1149,7 @@ mlir::LogicalResult basic_rewrite( using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); std::pair handlers[] = { {"", lower_bool_cast}, - {"", lower_range}, +// {"", lower_range}, }; for (auto& handler : handlers) { @@ -1037,7 +1286,9 @@ void PlierToStdPass::runOnOperation() SelectOpLowering, CondBrOpLowering, BinOpLowering, - ScfIfRewrite + ScfIfRewrite, + ScfWhileRewrite, + FixupWhileTypes >(type_converter, context); patterns.insert< From 34ce7ab29d161978ab5f669e276f3e07faed6902 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 156/259] lower scf.for from scf.while --- mlir-compiler/src/pipelines/plier_to_std.cpp | 195 ++++++++++--------- 1 file changed, 104 insertions(+), 91 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index f02e49a4eec..086be9beaab 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -977,114 +977,127 @@ mlir::LogicalResult lower_loop( plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds) { - auto getiter_block = getiter.getOperation()->getBlock(); - - auto iternext_block = get_next_block(getiter_block); - if (nullptr == iternext_block) + llvm::SmallVector to_process; + for (auto user : getiter.getOperation()->getUsers()) { - return mlir::failure(); - } - - auto iters = llvm::iterator_range(*iternext_block); - auto iternext = get_next_op(iters); - auto pairfirst = get_next_op(iters); - auto pairsecond = get_next_op(iters); - while (get_next_op(iters)) {} // skip casts - auto cond_br = get_next_op(iters); - auto skip_casts = [](mlir::Value op) - { - while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) + if( auto while_op = mlir::dyn_cast(user->getParentOp())) { - op = cast.getOperand(); + to_process.emplace_back(while_op); } - return op; - }; - - if (!iternext || !pairfirst || !pairsecond || !cond_br || - skip_casts(cond_br.condition()) != pairsecond) - { - return mlir::failure(); - } - auto body_block = cond_br.trueDest(); - auto post_block = cond_br.falseDest(); - assert(nullptr != body_block); - assert(nullptr != post_block); - if (get_next_block(body_block) != iternext_block) - { - return mlir::failure(); } - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) + bool changed = false; + for (auto while_op : to_process) { - mlir::BlockAndValueMapping mapper; - assert(iternext_block->getNumArguments() == iterargs.size()); - for (auto it : llvm::zip(iternext_block->getArguments(), iterargs)) - { - mapper.map(std::get<0>(it), std::get<1>(it)); - } - auto index = builder.create(loc, pairfirst.getType(), iv); - mapper.map(pairfirst, index); - - for (auto& op : body_block->without_terminator()) + auto& before_block = while_op.before().front(); + auto iters = llvm::iterator_range(before_block); + auto iternext = get_next_op(iters); + auto pairfirst = get_next_op(iters); + auto pairsecond = get_next_op(iters); + while (get_next_op(iters)) {} // skip casts + auto before_term = get_next_op(iters); + + auto skip_casts = [](mlir::Value op) { - builder.clone(op, mapper); - } - - auto term_operands = mlir::cast(body_block->getTerminator()).destOperands(); - llvm::SmallVector yield_vars; - yield_vars.reserve(term_operands.size()); - for (auto arg : term_operands) + while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) + { + op = cast.getOperand(); + } + return op; + }; + if (!iternext || !pairfirst || !pairsecond || !before_term || + skip_casts(before_term.condition()) != pairsecond) { - yield_vars.emplace_back(mapper.lookupOrDefault(arg)); + continue; } - builder.create(loc, yield_vars); - }; - auto loc = getiter.getLoc(); + auto& after_block = while_op.after().front(); - auto index_cast = [&](mlir::Value val)->mlir::Value - { - if (!val.getType().isa()) + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) { - return builder.create(loc, val, mlir::IndexType::get(val.getContext())); - } - return val; - }; + mlir::BlockAndValueMapping mapper; + assert(before_block.getNumArguments() == iterargs.size()); + assert(after_block.getNumArguments() == before_term.args().size()); + mapper.map(before_block.getArguments(), iterargs); + for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) + { + auto block_arg = std::get<0>(it); + auto term_arg = std::get<1>(it); + if (term_arg == pairfirst) // iter arg + { + auto index = builder.create(loc, pairfirst.getType(), iv); + mapper.map(block_arg, index); + } + else + { + mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); + } + } - auto term = mlir::cast(getiter_block->getTerminator()); - auto bounds = get_bounds(builder, loc); - auto lower_bound = index_cast(std::get<0>(bounds)); - auto upper_bound = index_cast(std::get<1>(bounds)); - auto step = index_cast(std::get<2>(bounds)); - - builder.setInsertionPointAfter(getiter); - auto loop_op = builder.create( - loc, - lower_bound, - upper_bound, - step, - term.destOperands(), // iterArgs - body - ); - assert(loop_op.getNumResults() == iternext_block->getNumArguments()); - for (auto arg : llvm::zip(iternext_block->getArguments(), loop_op.getResults())) - { - std::get<0>(arg).replaceAllUsesWith(std::get<1>(arg)); - } + for (auto& op : after_block) // with terminator + { + builder.clone(op, mapper); + } + }; - auto iternext_term = mlir::cast(iternext_block->getTerminator()); + auto loc = getiter.getLoc(); - builder.create(loc, post_block, iternext_term.falseDestOperands()); - builder.eraseOp(term); + auto index_cast = [&](mlir::Value val)->mlir::Value + { + if (!val.getType().isa()) + { + return builder.create(loc, val, mlir::IndexType::get(val.getContext())); + } + return val; + }; - iternext_block->dropAllDefinedValueUses(); - body_block->dropAllDefinedValueUses(); + auto bounds = get_bounds(builder, loc); + auto lower_bound = index_cast(std::get<0>(bounds)); + auto upper_bound = index_cast(std::get<1>(bounds)); + auto step = index_cast(std::get<2>(bounds)); + + builder.setInsertionPoint(while_op); + auto loop_op = builder.create( + loc, + lower_bound, + upper_bound, + step, + while_op.getOperands(), // iterArgs + body + ); + + assert(while_op.getNumResults() >= loop_op.getNumResults()); + builder.updateRootInPlace(while_op, [&]() + { + assert(while_op.getNumResults() == before_term.args().size()); + for (auto it : llvm::zip(while_op.getResults(), before_term.args())) + { + auto old_res = std::get<0>(it); + auto operand = std::get<1>(it); + for (auto it2 : llvm::enumerate(before_block.getArguments())) + { + auto arg = it2.value(); + if (arg == operand) + { + assert(it2.index() < loop_op.getNumResults()); + auto new_res = loop_op.getResult(static_cast(it2.index())); + old_res.replaceAllUsesWith(new_res); + } + } + } + }); - iternext_block->erase(); - body_block->erase(); - builder.eraseOp(getiter); + assert(while_op.getOperation()->getUsers().empty()); + builder.eraseOp(while_op); + changed = true; + } - return mlir::success(); + if (getiter.getOperation()->getUsers().empty()) + { + builder.eraseOp(getiter); + changed = true; + } + return mlir::success(changed); } mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) @@ -1149,7 +1162,7 @@ mlir::LogicalResult basic_rewrite( using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); std::pair handlers[] = { {"", lower_bool_cast}, -// {"", lower_range}, + {"", lower_range}, }; for (auto& handler : handlers) { From adc0b32c5003aa42b5a033779c6889cb67da8510 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 157/259] remove some unsued code --- mlir-compiler/src/pipelines/plier_to_std.cpp | 71 ------------------- .../src/rewrites/type_conversion.cpp | 34 --------- .../src/rewrites/type_conversion.hpp | 13 ---- 3 files changed, 118 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 086be9beaab..476f9773247 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1184,58 +1184,6 @@ mlir::Operation* change_op_ret_type(mlir::Operation* op, return rewriter.createOperation(state); } -struct ExpandTuples : public mlir::RewritePattern -{ - ExpandTuples(mlir::MLIRContext* ctx): - RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), - dialect(ctx->getLoadedDialect()) - { - assert(nullptr != dialect); - } - - mlir::LogicalResult - matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter) const override - { - if (op->getResultTypes().size() != 1 || - !op->getResultTypes()[0].isa() || - (op->getDialect() != dialect)) - { - return mlir::failure(); - } - auto types = op->getResultTypes()[0].cast().getTypes(); - - auto new_op = change_op_ret_type(op, rewriter, types); - auto new_op_results = new_op->getResults(); - - llvm::SmallVector users(op->getUsers()); - llvm::SmallVector new_operands; - for (auto user_op : users) - { - new_operands.clear(); - for (auto arg : user_op->getOperands()) - { - if (arg.getDefiningOp() == op) - { - std::copy(new_op_results.begin(), new_op_results.end(), std::back_inserter(new_operands)); - } - else - { - new_operands.push_back(arg); - } - } - rewriter.updateRootInPlace(user_op, [&]() - { - user_op->setOperands(new_operands); - }); - } - rewriter.eraseOp(op); - return mlir::success(); - } - -private: - mlir::Dialect* dialect = nullptr; -}; - struct PlierToStdPass : public mlir::PassWrapper> { @@ -1249,25 +1197,6 @@ struct PlierToStdPass : void runOnOperation() override; }; -bool check_for_plier_types(mlir::Type type) -{ - if (type.isa()) - { - return true; - } - if (auto ftype = type.dyn_cast()) - { - return llvm::any_of(ftype.getResults(), &check_for_plier_types) || - llvm::any_of(ftype.getInputs(), &check_for_plier_types); - } - return false; -} - -bool check_op_for_plier_types(mlir::Value val) -{ - return check_for_plier_types(val.getType()); -} - template mlir::Value cast_materializer( mlir::OpBuilder& builder, T type, mlir::ValueRange inputs, diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 7f9be14e420..c985b6e842d 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -134,37 +134,3 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( }); return mlir::success(); } - -OpTypeConversion::OpTypeConversion(mlir::MLIRContext*, mlir::TypeConverter& conv): - RewritePattern(0, mlir::Pattern::MatchAnyOpTypeTag()), - converter(conv) {} - -mlir::LogicalResult OpTypeConversion::matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter) const -{ - bool changed = false; - llvm::SmallVector new_types; - for (auto type : op->getResultTypes()) - { - if (auto new_type = converter.convertType(type)) - { - changed = changed || (new_type != type); - new_types.push_back(new_type); - } - else - { - new_types.push_back(type); - } - } - - if (changed) - { - rewriter.updateRootInPlace(op, [&] - { - for (unsigned i = 0; i < static_cast(new_types.size()); ++i) - { - op->getResult(i).setType(new_types[i]); - } - }); - } - return mlir::success(changed); -} diff --git a/mlir-compiler/src/rewrites/type_conversion.hpp b/mlir-compiler/src/rewrites/type_conversion.hpp index 6af35454ead..63b9b585bb7 100644 --- a/mlir-compiler/src/rewrites/type_conversion.hpp +++ b/mlir-compiler/src/rewrites/type_conversion.hpp @@ -20,16 +20,3 @@ struct FuncOpSignatureConversion : public mlir::OpRewritePattern private: mlir::TypeConverter& converter; }; - -struct OpTypeConversion : public mlir::RewritePattern -{ - OpTypeConversion(mlir::MLIRContext* ctx, - mlir::TypeConverter& conv); - - /// Hook for derived classes to implement combined matching and rewriting. - mlir::LogicalResult - matchAndRewrite(mlir::Operation* op, mlir::PatternRewriter &rewriter) const override; - -private: - mlir::TypeConverter& converter; -}; From 2c9759c0def12bb76165657cb70550e0adb13a83 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 158/259] refactor for loop lowering --- mlir-compiler/src/pipelines/plier_to_std.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 476f9773247..b63b16f4854 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -973,9 +973,10 @@ Op get_next_op(llvm::iterator_range& iters) return res; } -mlir::LogicalResult lower_loop( +mlir::LogicalResult lower_while_to_for( plier::GetiterOp getiter, mlir::PatternRewriter& builder, - llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds) + llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, + llvm::function_ref get_iter_val) { llvm::SmallVector to_process; for (auto user : getiter.getOperation()->getUsers()) @@ -1025,8 +1026,8 @@ mlir::LogicalResult lower_loop( auto term_arg = std::get<1>(it); if (term_arg == pairfirst) // iter arg { - auto index = builder.create(loc, pairfirst.getType(), iv); - mapper.map(block_arg, index); + auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); + mapper.map(block_arg, iter_val); } else { @@ -1118,7 +1119,11 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef auto step = (operands.size() == 3 ? operands[2] : builder.create(loc, 1)); return std::make_tuple(lower_bound, upper_bound, step); }; - if (!user || mlir::failed(lower_loop(user,rewriter, get_bounds))) + auto get_index = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::Type dst_type, mlir::Value index) + { + return builder.create(loc, dst_type, index); + }; + if (!user || mlir::failed(lower_while_to_for(user,rewriter, get_bounds, get_index))) { return mlir::failure(); } From 948a716c76a44cad24e9dfbf5c900e28e9987e8b Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 7 Nov 2020 00:44:39 +0300 Subject: [PATCH 159/259] move loop utils to separate file --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/pipelines/plier_to_std.cpp | 145 +---------------- mlir-compiler/src/transforms/loop_utils.cpp | 159 +++++++++++++++++++ mlir-compiler/src/transforms/loop_utils.hpp | 23 +++ 4 files changed, 185 insertions(+), 144 deletions(-) create mode 100644 mlir-compiler/src/transforms/loop_utils.cpp create mode 100644 mlir-compiler/src/transforms/loop_utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index acca1296b1f..8a1801ae27e 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -27,6 +27,7 @@ set(SOURCES_LIST src/rewrites/call_lowering.cpp src/rewrites/cast_lowering.cpp src/rewrites/type_conversion.cpp + src/transforms/loop_utils.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp @@ -44,6 +45,7 @@ set(HEADERS_LIST src/rewrites/call_lowering.hpp src/rewrites/cast_lowering.hpp src/rewrites/type_conversion.hpp + src/transforms/loop_utils.hpp src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index b63b16f4854..7c080eb3a69 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -17,6 +17,7 @@ #include "rewrites/call_lowering.hpp" #include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" +#include "transforms/loop_utils.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" @@ -957,150 +958,6 @@ struct FixupWhileTypes : public mlir::OpRewritePattern } }; -template -Op get_next_op(llvm::iterator_range& iters) -{ - if (iters.empty()) - { - return nullptr; - } - auto res = mlir::dyn_cast(iters.begin()); - if (res) - { - auto next = std::next(iters.begin()); - iters = {next, iters.end()}; - } - return res; -} - -mlir::LogicalResult lower_while_to_for( - plier::GetiterOp getiter, mlir::PatternRewriter& builder, - llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, - llvm::function_ref get_iter_val) -{ - llvm::SmallVector to_process; - for (auto user : getiter.getOperation()->getUsers()) - { - if( auto while_op = mlir::dyn_cast(user->getParentOp())) - { - to_process.emplace_back(while_op); - } - } - - bool changed = false; - for (auto while_op : to_process) - { - auto& before_block = while_op.before().front(); - auto iters = llvm::iterator_range(before_block); - auto iternext = get_next_op(iters); - auto pairfirst = get_next_op(iters); - auto pairsecond = get_next_op(iters); - while (get_next_op(iters)) {} // skip casts - auto before_term = get_next_op(iters); - - auto skip_casts = [](mlir::Value op) - { - while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) - { - op = cast.getOperand(); - } - return op; - }; - if (!iternext || !pairfirst || !pairsecond || !before_term || - skip_casts(before_term.condition()) != pairsecond) - { - continue; - } - - auto& after_block = while_op.after().front(); - - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) - { - mlir::BlockAndValueMapping mapper; - assert(before_block.getNumArguments() == iterargs.size()); - assert(after_block.getNumArguments() == before_term.args().size()); - mapper.map(before_block.getArguments(), iterargs); - for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) - { - auto block_arg = std::get<0>(it); - auto term_arg = std::get<1>(it); - if (term_arg == pairfirst) // iter arg - { - auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); - mapper.map(block_arg, iter_val); - } - else - { - mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); - } - } - - for (auto& op : after_block) // with terminator - { - builder.clone(op, mapper); - } - }; - - auto loc = getiter.getLoc(); - - auto index_cast = [&](mlir::Value val)->mlir::Value - { - if (!val.getType().isa()) - { - return builder.create(loc, val, mlir::IndexType::get(val.getContext())); - } - return val; - }; - - auto bounds = get_bounds(builder, loc); - auto lower_bound = index_cast(std::get<0>(bounds)); - auto upper_bound = index_cast(std::get<1>(bounds)); - auto step = index_cast(std::get<2>(bounds)); - - builder.setInsertionPoint(while_op); - auto loop_op = builder.create( - loc, - lower_bound, - upper_bound, - step, - while_op.getOperands(), // iterArgs - body - ); - - assert(while_op.getNumResults() >= loop_op.getNumResults()); - builder.updateRootInPlace(while_op, [&]() - { - assert(while_op.getNumResults() == before_term.args().size()); - for (auto it : llvm::zip(while_op.getResults(), before_term.args())) - { - auto old_res = std::get<0>(it); - auto operand = std::get<1>(it); - for (auto it2 : llvm::enumerate(before_block.getArguments())) - { - auto arg = it2.value(); - if (arg == operand) - { - assert(it2.index() < loop_op.getNumResults()); - auto new_res = loop_op.getResult(static_cast(it2.index())); - old_res.replaceAllUsesWith(new_res); - } - } - } - }); - - assert(while_op.getOperation()->getUsers().empty()); - builder.eraseOp(while_op); - changed = true; - } - - if (getiter.getOperation()->getUsers().empty()) - { - builder.eraseOp(getiter); - changed = true; - } - return mlir::success(changed); -} - mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { if ((operands.size() < 1 || operands.size() > 3) || diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp new file mode 100644 index 00000000000..6708de35257 --- /dev/null +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -0,0 +1,159 @@ +#include "transforms/loop_utils.hpp" + +#include + +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" + +namespace +{ +template +Op get_next_op(llvm::iterator_range& iters) +{ + if (iters.empty()) + { + return nullptr; + } + auto res = mlir::dyn_cast(iters.begin()); + if (res) + { + auto next = std::next(iters.begin()); + iters = {next, iters.end()}; + } + return res; +} +} + +mlir::LogicalResult lower_while_to_for( + plier::GetiterOp getiter, mlir::PatternRewriter& builder, + llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, + llvm::function_ref get_iter_val) +{ + llvm::SmallVector to_process; + for (auto user : getiter.getOperation()->getUsers()) + { + if( auto while_op = mlir::dyn_cast(user->getParentOp())) + { + to_process.emplace_back(while_op); + } + } + + bool changed = false; + for (auto while_op : to_process) + { + auto& before_block = while_op.before().front(); + auto iters = llvm::iterator_range(before_block); + auto iternext = get_next_op(iters); + auto pairfirst = get_next_op(iters); + auto pairsecond = get_next_op(iters); + while (get_next_op(iters)) {} // skip casts + auto before_term = get_next_op(iters); + + auto skip_casts = [](mlir::Value op) + { + while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) + { + op = cast.getOperand(); + } + return op; + }; + if (!iternext || !pairfirst || !pairsecond || !before_term || + skip_casts(before_term.condition()) != pairsecond) + { + continue; + } + + auto& after_block = while_op.after().front(); + + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) + { + mlir::BlockAndValueMapping mapper; + assert(before_block.getNumArguments() == iterargs.size()); + assert(after_block.getNumArguments() == before_term.args().size()); + mapper.map(before_block.getArguments(), iterargs); + for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) + { + auto block_arg = std::get<0>(it); + auto term_arg = std::get<1>(it); + if (term_arg == pairfirst) // iter arg + { + auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); + mapper.map(block_arg, iter_val); + } + else + { + mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); + } + } + + for (auto& op : after_block) // with terminator + { + builder.clone(op, mapper); + } + }; + + auto loc = getiter.getLoc(); + + auto index_cast = [&](mlir::Value val)->mlir::Value + { + if (!val.getType().isa()) + { + return builder.create(loc, val, mlir::IndexType::get(val.getContext())); + } + return val; + }; + + auto bounds = get_bounds(builder, loc); + auto lower_bound = index_cast(std::get<0>(bounds)); + auto upper_bound = index_cast(std::get<1>(bounds)); + auto step = index_cast(std::get<2>(bounds)); + + builder.setInsertionPoint(while_op); + auto loop_op = builder.create( + loc, + lower_bound, + upper_bound, + step, + while_op.getOperands(), // iterArgs + body + ); + + assert(while_op.getNumResults() >= loop_op.getNumResults()); + builder.updateRootInPlace(while_op, [&]() + { + assert(while_op.getNumResults() == before_term.args().size()); + for (auto it : llvm::zip(while_op.getResults(), before_term.args())) + { + auto old_res = std::get<0>(it); + auto operand = std::get<1>(it); + for (auto it2 : llvm::enumerate(before_block.getArguments())) + { + auto arg = it2.value(); + if (arg == operand) + { + assert(it2.index() < loop_op.getNumResults()); + auto new_res = loop_op.getResult(static_cast(it2.index())); + old_res.replaceAllUsesWith(new_res); + } + } + } + }); + + assert(while_op.getOperation()->getUsers().empty()); + builder.eraseOp(while_op); + changed = true; + } + + if (getiter.getOperation()->getUsers().empty()) + { + builder.eraseOp(getiter); + changed = true; + } + return mlir::success(changed); +} + diff --git a/mlir-compiler/src/transforms/loop_utils.hpp b/mlir-compiler/src/transforms/loop_utils.hpp new file mode 100644 index 00000000000..58ae9f635b9 --- /dev/null +++ b/mlir-compiler/src/transforms/loop_utils.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include + +namespace mlir +{ +struct LogicalResult; +class PatternRewriter; +class Value; +class Location; +class OpBuilder; +class Type; +} + +namespace plier +{ +class GetiterOp; +} + +mlir::LogicalResult lower_while_to_for( + plier::GetiterOp getiter, mlir::PatternRewriter& builder, + llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, + llvm::function_ref get_iter_val); From 1dec257008dfb8871817be0cad3b81de6e24f56b Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 12 Nov 2020 15:05:15 +0300 Subject: [PATCH 160/259] unary ops --- mlir-compiler/include/plier/PlierOps.td | 12 ++++++ mlir-compiler/src/dialect.cpp | 6 +++ mlir-compiler/src/lowering.cpp | 24 ++++++++--- mlir-compiler/src/pipelines/plier_to_std.cpp | 44 ++++++++++++++++++++ numba/mlir/tests/test_basic.py | 11 +++++ 5 files changed, 91 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 0e9e81ab27c..b5d9138f41c 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -65,6 +65,18 @@ def BinOp : Plier_Op<"binop", []> { ]; } +def UnaryOp : Plier_Op<"unary", []> { + let arguments = (ins + AnyType:$value, + StrAttr:$op); + + let results = (outs AnyType); + + let builders = [ + OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, ::mlir::StringRef op"> + ]; +} + def CastOp : Plier_Op<"cast", []> { let arguments = (ins AnyType:$value); diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 528dc912c80..e5879696713 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -108,6 +108,12 @@ void BinOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, rhs, op); } +void UnaryOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value, mlir::StringRef op) { + UnaryOp::build(builder, state, PyType::getUndefined(state.getContext()), + value, op); +} + mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) { auto op_type = getOperand().getType(); diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index a21547d913d..f33ce346e7e 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -113,8 +113,10 @@ struct inst_handles }; static const constexpr OpId ops_names[] = { - {"+", "add"}, - {"-", "sub"}, + {"+", "add"}, // binary + {"+", "pos"}, // unary + {"-", "sub"}, // binary + {"-", "neg"}, // unary {"*", "mul"}, {">", "gt"}, @@ -282,6 +284,7 @@ struct plier_lowerer using func_t = mlir::Value (plier_lowerer::*)(const py::handle&); const std::pair handlers[] = { {"binop", &plier_lowerer::lower_binop}, + {"unary", &plier_lowerer::lower_unary}, {"cast", &plier_lowerer::lower_cast}, {"call", &plier_lowerer::lower_call}, {"phi", &plier_lowerer::lower_phi}, @@ -401,17 +404,26 @@ struct plier_lowerer auto rhs_name = expr.attr("rhs"); auto lhs = loadvar(lhs_name); auto rhs = loadvar(rhs_name); - return resolve_op(lhs, rhs, op); + auto op_name = resolve_op(op); + return builder.create(get_current_loc(), lhs, rhs, op_name); } - mlir::Value resolve_op(mlir::Value lhs, mlir::Value rhs, const py::handle& op) + mlir::Value lower_unary(const py::handle& expr) + { + auto op = expr.attr("fn"); + auto val_name = expr.attr("value"); + auto val = loadvar(val_name); + auto op_name = resolve_op(op); + return builder.create(get_current_loc(), val, op_name); + } + + llvm::StringRef resolve_op(const py::handle& op) { for (auto elem : llvm::zip(insts.ops_names, insts.ops_handles)) { if (op.is(std::get<1>(elem))) { - auto op_name = std::get<0>(elem).op; - return builder.create(get_current_loc(), lhs, rhs, op_name); + return std::get<0>(elem).op; } } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 7c080eb3a69..7319aede6e0 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -559,6 +559,49 @@ struct BinOpLowering : public mlir::OpRewritePattern } }; +mlir::Value negate(mlir::Value val, mlir::Location loc, mlir::PatternRewriter &rewriter) +{ + auto type = val.getType(); + if (auto itype = type.dyn_cast()) + { + // TODO: not int negation? + auto zero = rewriter.create(loc, mlir::IntegerAttr::get(itype, 0)); + return rewriter.create(loc, zero, val); + } + if (type.isa()) + { + return rewriter.create(loc, val); + } + llvm_unreachable("negate: unsupported type"); +} + +struct UnaryOpLowering : public mlir::OpRewritePattern +{ + UnaryOpLowering(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + plier::UnaryOp op, mlir::PatternRewriter &rewriter) const override + { + auto arg = op.getOperand(); + auto type = arg.getType(); + if (!is_supported_type(type)) + { + return mlir::failure(); + } + if (op.op() == "+") + { + rewriter.replaceOp(op, arg); + return mlir::success(); + } + assert(op.op() == "-"); + auto new_val = negate(arg, op.getLoc(), rewriter); + rewriter.replaceOp(op, new_val); + return mlir::success(); + } +}; + mlir::Block* get_next_block(mlir::Block* block) { assert(nullptr != block); @@ -1090,6 +1133,7 @@ void PlierToStdPass::runOnOperation() SelectOpLowering, CondBrOpLowering, BinOpLowering, + UnaryOpLowering, ScfIfRewrite, ScfWhileRewrite, FixupWhileTypes diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 55216e6bbdd..40e8dca7cfb 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -32,6 +32,17 @@ def test_ops(self): for a, b in itertools.product(_test_values, _test_values): assert_equal(py_func(a, b), jit_func(a, b)) + def test_unary_ops(self): + py_funcs = [ + lambda a: +a, + lambda a: -a, + ] + + for py_func in py_funcs: + jit_func = njit(py_func) + for a in _test_values: + assert_equal(py_func(a), jit_func(a)) + def test_cmp_ops(self): py_funcs = [ lambda a, b: a if a > b else b, From aade610fe7709e1b88385e61a8b7c7659756e229 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 12 Nov 2020 15:05:15 +0300 Subject: [PATCH 161/259] div --- mlir-compiler/src/lowering.cpp | 2 + mlir-compiler/src/pipelines/plier_to_std.cpp | 58 +++++++++++++------- numba/mlir/tests/test_basic.py | 8 ++- 3 files changed, 47 insertions(+), 21 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index f33ce346e7e..389473ef3b7 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -118,6 +118,8 @@ struct inst_handles {"-", "sub"}, // binary {"-", "neg"}, // unary {"*", "mul"}, + {"/", "truediv"}, + {"//", "floordiv"}, {">", "gt"}, {">=", "ge"}, diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 7319aede6e0..b02e9c31412 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -192,22 +193,6 @@ bool is_supported_type(mlir::Type type) return type.isIntOrFloat(); } -template -void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) -{ - assert(nullptr != op); - rewriter.replaceOpWithNewOp(op, new_type, operands); -} - -template -void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) -{ - assert(nullptr != op); - auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(64, op->getContext()), Pred); - mlir::Type new_type = mlir::IntegerType::get(1, op->getContext()); - rewriter.replaceOpWithNewOp(op, new_type, pred_attr, operands[0], operands[1]); -} - bool is_int(mlir::Type type) { assert(type); @@ -472,11 +457,36 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& return nullptr; } +template +void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) +{ + assert(nullptr != op); + rewriter.replaceOpWithNewOp(op, new_type, operands); +} + +void replace_itruediv_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) +{ + assert(nullptr != op); + auto lhs = do_cast(new_type, operands[0], rewriter); + auto rhs = do_cast(new_type, operands[1], rewriter); + rewriter.replaceOpWithNewOp(op, lhs, rhs); +} + +template +void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + assert(nullptr != op); + auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(64, op->getContext()), Pred); + mlir::Type new_type = mlir::IntegerType::get(1, op->getContext()); + rewriter.replaceOpWithNewOp(op, new_type, pred_attr, operands[0], operands[1]); +} + + struct BinOpLowering : public mlir::OpRewritePattern { - BinOpLowering(mlir::TypeConverter &/*typeConverter*/, + BinOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context): - OpRewritePattern(context) {} + OpRewritePattern(context), converter(typeConverter) {} mlir::LogicalResult matchAndRewrite( plier::BinOp op, mlir::PatternRewriter &rewriter) const override @@ -489,6 +499,11 @@ struct BinOpLowering : public mlir::OpRewritePattern { return mlir::failure(); } + auto res_type = converter.convertType(op.getType()); + if (!res_type || !is_supported_type(res_type)) + { + return mlir::failure(); + } mlir::Type final_type; std::array converted_operands; if (type0 != type1) @@ -517,6 +532,7 @@ struct BinOpLowering : public mlir::OpRewritePattern {"+", &replace_op, &replace_op}, {"-", &replace_op, &replace_op}, {"*", &replace_op, &replace_op}, + {"/", &replace_itruediv_op, &replace_op}, {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, @@ -539,7 +555,7 @@ struct BinOpLowering : public mlir::OpRewritePattern { if (h.type == op.op()) { - (h.*mem)(op, rewriter, final_type, converted_operands); + (h.*mem)(op, rewriter, res_type, converted_operands); return mlir::success(); } } @@ -557,6 +573,8 @@ struct BinOpLowering : public mlir::OpRewritePattern } return mlir::failure(); } +private: + mlir::TypeConverter& converter; }; mlir::Value negate(mlir::Value val, mlir::Location loc, mlir::PatternRewriter &rewriter) @@ -1147,6 +1165,8 @@ void PlierToStdPass::runOnOperation() CallOpLowering >(type_converter, context, &basic_rewrite); + mlir::populateStdExpandDivsRewritePatterns(context, patterns); + for (auto *op : context->getRegisteredOperations()) { op->getCanonicalizationPatterns(patterns, context); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 40e8dca7cfb..5206bedb354 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -24,13 +24,17 @@ def test_ops(self): lambda a, b: a + b, lambda a, b: a - b, lambda a, b: a * b, - # TODO: div + lambda a, b: a / b, + # TODO: floordiv ] for py_func in py_funcs: jit_func = njit(py_func) for a, b in itertools.product(_test_values, _test_values): - assert_equal(py_func(a, b), jit_func(a, b)) + try: + assert_equal(py_func(a, b), jit_func(a, b)) + except ZeroDivisionError: + pass def test_unary_ops(self): py_funcs = [ From 42f3e391a92aff4be415cb23def10ae0f6f0e24e Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 12 Nov 2020 15:05:15 +0300 Subject: [PATCH 162/259] get_const_val float --- mlir-compiler/src/lowering.cpp | 4 ++++ numba/mlir/tests/test_basic.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 389473ef3b7..d74c6ca078b 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -493,6 +493,10 @@ struct plier_lowerer { return builder.getI64IntegerAttr(val.cast()); } + if (py::isinstance(val)) + { + return builder.getF64FloatAttr(val.cast()); + } report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); } diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 5206bedb354..cc2b6b1ad0b 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -66,6 +66,8 @@ def test_const_ops(self): py_funcs = [ lambda a: a + 42, lambda a: 43 + a, + lambda a: a + 42.5, + lambda a: 43.5 + a, ] for py_func in py_funcs: From 09cdbaf3560485de37cd88188e496dd7ab3779dc Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 12 Nov 2020 15:05:15 +0300 Subject: [PATCH 163/259] setitem from numba ir --- mlir-compiler/include/plier/PlierOps.td | 9 +++++++++ mlir-compiler/src/lowering.cpp | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index b5d9138f41c..607b0025289 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -139,6 +139,15 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { ]; } +def SetItemOp : Plier_Op<"setitem", []> { + let arguments = (ins + AnyType:$target, + AnyType:$index, + AnyType:$value); + + let builders = []; +} + def GetiterOp : Plier_Op<"getiter", [NoSideEffect]> { let arguments = (ins AnyType:$value); diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index d74c6ca078b..92c8e6651f3 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -78,6 +78,7 @@ struct inst_handles Return = mod.attr("Return"); Branch = mod.attr("Branch"); Jump = mod.attr("Jump"); + SetItem = mod.attr("SetItem"); Arg = mod.attr("Arg"); Expr = mod.attr("Expr"); @@ -99,6 +100,7 @@ struct inst_handles py::handle Return; py::handle Branch; py::handle Jump; + py::handle SetItem; py::handle Arg; py::handle Expr; @@ -226,6 +228,10 @@ struct plier_lowerer auto val = lower_assign(inst, target); storevar(val, target); } + else if (py::isinstance(inst, insts.SetItem)) + { + setitem(inst.attr("target"), inst.attr("index"), inst.attr("value")); + } else if (py::isinstance(inst, insts.Del)) { delvar(inst.attr("value")); @@ -440,6 +446,11 @@ struct plier_lowerer return builder.create(get_current_loc(), value, name); } + void setitem(const py::handle& target, const py::handle& index, const py::handle& value) + { + builder.create(get_current_loc(), loadvar(target), loadvar(index), loadvar(value)); + } + void storevar(mlir::Value val, const py::handle& inst) { vars_map[inst.attr("name").cast()] = val; From 8de54d109987d9cc2639e4812e96847e2d7f4c5c Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 12 Nov 2020 15:05:15 +0300 Subject: [PATCH 164/259] functions returning None --- mlir-compiler/include/plier/dialect.hpp | 1 + mlir-compiler/src/dialect.cpp | 5 +++++ mlir-compiler/src/lowering.cpp | 19 ++++++++++++------ mlir-compiler/src/pipelines/lower_to_llvm.cpp | 15 ++++++++------ .../src/pipelines/plier_to_linalg.cpp | 2 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 20 +++++++++++++------ mlir-compiler/src/pipelines/plier_to_std.hpp | 3 ++- .../src/rewrites/type_conversion.cpp | 6 +++--- numba/mlir/tests/test_basic.py | 12 +++++++++++ 9 files changed, 60 insertions(+), 23 deletions(-) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index f5081c48797..1d57f496729 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -27,6 +27,7 @@ class PyType : public mlir::Type::TypeBase<::plier::PyType, mlir::Type, static PyType get(mlir::MLIRContext *context, mlir::StringRef name); static PyType getUndefined(mlir::MLIRContext *context); + static PyType getNone(mlir::MLIRContext *context); mlir::StringRef getName() const; }; diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index e5879696713..c0db90800e9 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -65,6 +65,11 @@ PyType PyType::getUndefined(mlir::MLIRContext* context) return Base::get(context, ""); } +PyType PyType::getNone(mlir::MLIRContext* context) +{ + return Base::get(context, "none"); +} + llvm::StringRef PyType::getName() const { return getImpl()->name; diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 92c8e6651f3..81172016e8f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -273,8 +273,7 @@ struct plier_lowerer } if (py::isinstance(value, insts.Const)) { - auto val = get_const_val(value.attr("value")); - return builder.create(get_current_loc(), val); + return get_const(value.attr("value")); } if (py::isinstance(value, insts.Global)) { @@ -498,17 +497,25 @@ struct plier_lowerer builder.create(get_current_loc(), mlir::None, block); } - mlir::Attribute get_const_val(const py::handle& val) + mlir::Value get_const(const py::handle& val) { + auto get_val = [&](mlir::Attribute attr) + { + return builder.create(get_current_loc(), attr); + }; if (py::isinstance(val)) { - return builder.getI64IntegerAttr(val.cast()); + return get_val(builder.getI64IntegerAttr(val.cast())); } if (py::isinstance(val)) { - return builder.getF64FloatAttr(val.cast()); + return get_val(builder.getF64FloatAttr(val.cast())); + } + if (py::isinstance(val)) + { + return get_val(builder.getUnitAttr()); } - report_error(llvm::Twine("get_const_val unhandled type \"") + py::str(val.get_type()).cast() + "\""); + report_error(llvm::Twine("get_const unhandled type \"") + py::str(val.get_type()).cast() + "\""); } mlir::FunctionType get_func_type(const py::handle& fnargs, const py::handle& restype) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 9a1fe776376..2b08c82605d 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -261,7 +261,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) return; } auto old_type = func.getType(); - assert(old_type.getNumResults() == 1); + assert(old_type.getNumResults() <= 1); auto& ctx = *old_type.getContext(); llvm::SmallVector args; @@ -320,11 +320,11 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) } }; - add_arg(ptr(old_type.getResult(0))); + auto orig_ret_type = (old_type.getNumResults() != 0 ? old_type.getResult(0) : type_helper.ptr(type_helper.i(8))); + add_arg(ptr(orig_ret_type)); add_arg(ptr(ptr(getExceptInfoType(type_helper)))); auto old_args = old_type.getInputs(); -// std::copy(old_args.begin(), old_args.end(), std::back_inserter(args)); for (auto arg : old_args) { process_arg(arg); @@ -350,16 +350,19 @@ struct ReturnOpLowering : public mlir::OpRewritePattern rewriter.replaceOpWithNewOp(op, ret); }; + rewriter.setInsertionPoint(op); + auto addr = op.getParentRegion()->front().getArgument(0); if (op.getNumOperands() == 0) { - rewriter.setInsertionPoint(op); + assert(addr.getType().isa()); + auto null_type = addr.getType().cast().getElementType(); + auto ll_val = rewriter.create(op.getLoc(), null_type); + rewriter.create(op.getLoc(), ll_val, addr); insert_ret(); return mlir::success(); } else if (op.getNumOperands() == 1) { - rewriter.setInsertionPoint(op); - auto addr = op.getParentRegion()->front().getArgument(0); auto val = op.getOperand(0); auto ll_ret_type = type_converter.convertType(val.getType()); assert(static_cast(ll_ret_type)); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index e106d20ba31..33a693e1335 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -268,7 +268,7 @@ void PlierToLinalgPass::runOnOperation() mlir::TypeConverter type_converter; // Convert unknown types to itself type_converter.addConversion([](mlir::Type type) { return type; }); - populate_std_type_converter(type_converter); + populate_std_type_converter(getContext(), type_converter); type_converter.addConversion([&](plier::PyType type)->llvm::Optional { auto ret = map_plier_type(type_converter, type); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index b02e9c31412..3d2895fc1ff 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -242,8 +242,8 @@ struct ReturnOpLowering : public mlir::OpRewritePattern auto operands = op.getOperands(); auto func = mlir::cast(op.getParentOp()); auto res_types = func.getType().getResults(); - assert(res_types.size() == operands.size()); - bool converted = false; + assert(res_types.size() == operands.size() || res_types.empty()); + bool converted = (res_types.size() != operands.size()); llvm::SmallVector new_vals; for (auto it : llvm::zip(operands, res_types)) { @@ -1138,9 +1138,9 @@ void PlierToStdPass::runOnOperation() mlir::TypeConverter type_converter; // Convert unknown types to itself type_converter.addConversion([](mlir::Type type) { return type; }); - populate_std_type_converter(type_converter); auto context = &getContext(); + populate_std_type_converter(*context, type_converter); mlir::OwningRewritePatternList patterns; @@ -1182,16 +1182,24 @@ void populate_plier_to_std_pipeline(mlir::OpPassManager& pm) } } -void populate_std_type_converter(mlir::TypeConverter& converter) +void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter& converter) { - converter.addConversion([](mlir::Type type)->llvm::Optional + auto none_type = plier::PyType::getNone(&context); + converter.addConversion( + [none_type](mlir::Type type, llvm::SmallVectorImpl& ret_types) + ->llvm::Optional { + if (type == none_type) + { + return mlir::success(); + } auto ret = map_plier_type(type); if (!ret) { return llvm::None; } - return ret; + ret_types.push_back(ret); + return mlir::success(); }); } diff --git a/mlir-compiler/src/pipelines/plier_to_std.hpp b/mlir-compiler/src/pipelines/plier_to_std.hpp index 2965d10335a..80afadac4c9 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.hpp +++ b/mlir-compiler/src/pipelines/plier_to_std.hpp @@ -9,10 +9,11 @@ class StringRef; namespace mlir { +class MLIRContext; class TypeConverter; } -void populate_std_type_converter(mlir::TypeConverter& converter); +void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter& converter); void register_plier_to_std_pipeline(PipelineRegistry& registry); diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index c985b6e842d..4419a79691e 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -25,9 +25,9 @@ mlir::LogicalResult setBlockSig( builder.setInsertionPointToStart(&block); auto res = builder.create(builder.getUnknownLoc(), arg.getType(), arg); arg.replaceUsesWithIf(res, [&](mlir::OpOperand& op) - { - return op.getOwner() != res; - }); + { + return op.getOwner() != res; + }); for (auto& use : block.getUses()) { diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index cc2b6b1ad0b..ea73b58b17c 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -85,6 +85,18 @@ def py_func(a): for val in _test_values: assert_equal(py_func(val), jit_func(val)) + def test_ret_none(self): + def py_func1(): + return None + + def py_func2(): + pass + + jit_func1 = njit(py_func1) + jit_func2 = njit(py_func2) + assert_equal(py_func1(), jit_func1()) + assert_equal(py_func2(), jit_func2()) + def test_jump(self): def py_func(a, b): c = 3 From ee2173558e30e630d56fa0af267af46909f143a3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 18 Nov 2020 22:29:19 +0300 Subject: [PATCH 165/259] loop continue --- mlir-compiler/src/pipelines/plier_to_std.cpp | 177 ++++++++++--------- numba/mlir/tests/test_basic.py | 14 ++ 2 files changed, 112 insertions(+), 79 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 3d2895fc1ff..795541b6872 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -670,110 +670,129 @@ struct ScfIfRewrite : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( mlir::CondBranchOp op, mlir::PatternRewriter &rewriter) const override { - auto true_block = op.getTrueDest(); - auto post_block = get_next_block(true_block); - if (nullptr == post_block) + auto getDest = [&](bool true_dest) { - return mlir::failure(); - } - auto false_block = op.getFalseDest(); - if (false_block != post_block && - get_next_block(false_block) != post_block) - { - return mlir::failure(); - } - - auto start_block = op.getOperation()->getBlock(); - if (!is_blocks_different({start_block, true_block, post_block})) + return true_dest ? op.getTrueDest() : op.getFalseDest(); + }; + auto getOperands = [&](bool true_dest) { - return mlir::failure(); - } - auto cond = op.condition(); - - mlir::BlockAndValueMapping mapper; - llvm::SmallVector yield_vals; - auto copy_block = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Block& block) + return true_dest ? op.getTrueOperands() : op.getFalseOperands(); + }; + auto loc = op.getLoc(); + for (bool reverse : {false, true}) { - mapper.clear(); - for (auto& op : block.without_terminator()) + auto true_block = getDest(!reverse); + auto post_block = get_next_block(true_block); + if (nullptr == post_block) { - builder.clone(op, mapper); + continue; } - auto term = mlir::cast(block.getTerminator()); - yield_vals.clear(); - yield_vals.reserve(term.getNumOperands()); - for (auto op : term.getOperands()) + auto false_block = getDest(reverse); + if (false_block != post_block && + get_next_block(false_block) != post_block) { - yield_vals.emplace_back(mapper.lookupOrDefault(op)); + continue; } - builder.create(loc, yield_vals); - }; - auto true_body = [&](mlir::OpBuilder& builder, mlir::Location loc) - { - copy_block(builder, loc, *true_block); - }; + auto start_block = op.getOperation()->getBlock(); + if (!is_blocks_different({start_block, true_block, post_block})) + { + continue; + } + mlir::Value cond = op.condition(); + if (reverse) + { + auto i1 = mlir::IntegerType::get(1, op.getContext()); + auto one = rewriter.create(loc, mlir::IntegerAttr::get(i1, 1)); + cond = rewriter.create(loc, cond, one); + } - bool has_else = false_block != post_block; - auto res_types = mlir::cast(true_block->getTerminator()).getOperandTypes(); - mlir::scf::IfOp if_op; - if (has_else) - { - auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + mlir::BlockAndValueMapping mapper; + llvm::SmallVector yield_vals; + auto copy_block = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Block& block) { - copy_block(builder, loc, *false_block); + mapper.clear(); + for (auto& op : block.without_terminator()) + { + builder.clone(op, mapper); + } + auto term = mlir::cast(block.getTerminator()); + yield_vals.clear(); + yield_vals.reserve(term.getNumOperands()); + for (auto op : term.getOperands()) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); }; - if_op = rewriter.create( - op.getLoc(), - res_types, - cond, - true_body, - false_body); - } - else - { - if (res_types.empty()) + + auto true_body = [&](mlir::OpBuilder& builder, mlir::Location loc) { - if_op = rewriter.create( - op.getLoc(), - res_types, - cond, - true_body); - } - else + copy_block(builder, loc, *true_block); + }; + + bool has_else = false_block != post_block; + auto res_types = mlir::cast(true_block->getTerminator()).getOperandTypes(); + mlir::scf::IfOp if_op; + if (has_else) { auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) { - auto res = op.getFalseOperands(); - yield_vals.clear(); - yield_vals.reserve(res.size()); - for (auto op : res) - { - yield_vals.emplace_back(mapper.lookupOrDefault(op)); - } - builder.create(loc, yield_vals); + copy_block(builder, loc, *false_block); }; if_op = rewriter.create( - op.getLoc(), + loc, res_types, cond, true_body, false_body); } - } + else + { + if (res_types.empty()) + { + if_op = rewriter.create( + loc, + res_types, + cond, + true_body); + } + else + { + auto false_body = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto res = getOperands(reverse); + yield_vals.clear(); + yield_vals.reserve(res.size()); + for (auto op : res) + { + yield_vals.emplace_back(mapper.lookupOrDefault(op)); + } + builder.create(loc, yield_vals); + }; + if_op = rewriter.create( + loc, + res_types, + cond, + true_body, + false_body); + } + } - rewriter.create(op.getLoc(), post_block, if_op.getResults()); - rewriter.eraseOp(op); + rewriter.create(loc, post_block, if_op.getResults()); + rewriter.eraseOp(op); - if (true_block->getUsers().empty()) - { - erase_blocks(true_block); - } - if (false_block->getUsers().empty()) - { - erase_blocks(false_block); + if (true_block->getUsers().empty()) + { + erase_blocks(true_block); + } + if (false_block->getUsers().empty()) + { + erase_blocks(false_block); + } + return mlir::success(); } - return mlir::success(); + return mlir::failure(); } }; diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index ea73b58b17c..10dcfc56112 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -163,6 +163,20 @@ def py_func(n): jit_func = njit(py_func) assert_equal(py_func(10), jit_func(10)) + def test_range_continue(self): + def py_func(n): + res = 0 + res1 = 2 + for i in range(n): + res = res + i + if i < 5: + continue + res1 = res1 + i * 2 + return res + res1 + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + def test_range_nested(self): def py_func(a, b, c): res = 0 From 011499b2c607204f45ce55f39900d03c9c7b445a Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 18 Nov 2020 23:55:15 +0300 Subject: [PATCH 166/259] range index usage after loop --- mlir-compiler/src/transforms/loop_utils.cpp | 18 ++++++++++++++++++ numba/mlir/tests/test_basic.py | 10 ++++++++++ 2 files changed, 28 insertions(+) diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp index 6708de35257..fd08842f233 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -27,6 +27,16 @@ Op get_next_op(llvm::iterator_range& iters) } return res; } + +mlir::Value get_last_iter_value( + mlir::PatternRewriter& builder, mlir::Location loc, + mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value step) +{ + auto len = builder.create(loc, upper_bound, lower_bound); + auto count = builder.create(loc, len, step); + auto inc = builder.create(loc, count, step); + return builder.create(loc, lower_bound, inc); +} } mlir::LogicalResult lower_while_to_for( @@ -139,8 +149,16 @@ mlir::LogicalResult lower_while_to_for( assert(it2.index() < loop_op.getNumResults()); auto new_res = loop_op.getResult(static_cast(it2.index())); old_res.replaceAllUsesWith(new_res); + break; } } + if (operand == pairfirst && !old_res.getUsers().empty()) + { + auto val = get_last_iter_value(builder, loc, lower_bound, upper_bound, step); + auto new_res = builder.create(loc, old_res.getType(), val); + old_res.replaceAllUsesWith(new_res); + } + assert(old_res.getUsers().empty()); } }); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 10dcfc56112..ce0fd76cd58 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -149,6 +149,16 @@ def py_func(a, b, c): jit_func = njit(py_func) assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + def test_range_use_index_after(self): + def py_func(n): + res = 0 + for i in range(0, n, 2): + res = res + i + return res + i + + jit_func = njit(py_func) + assert_equal(py_func(9), jit_func(9)) + def test_range_if(self): def py_func(n): res = 0 From 25209f4bfd85a3f73d7534b4017bb69f086a2e58 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 19 Nov 2020 00:04:50 +0300 Subject: [PATCH 167/259] test --- numba/mlir/tests/test_basic.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index ce0fd76cd58..7655d630122 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -173,6 +173,23 @@ def py_func(n): jit_func = njit(py_func) assert_equal(py_func(10), jit_func(10)) + def test_range_ifs(self): + def py_func(n): + res = 0 + for i in range(n): + if i == 2: + res = res + 2 + elif i == 7: + res = res + 5 + elif i == 99: + res = res + 99 + else: + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + def test_range_continue(self): def py_func(n): res = 0 From cd128c1cfbd0f9fb145abfaa262948e2ad72a7b3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 17:14:44 +0300 Subject: [PATCH 168/259] update llvm --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 4a1f9a6027f..a4dda885053 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -b0de3f67874ac3eff465cb2ef8ab6081292625c3 +0caa82e2ac53b2ff475531086dfe648fb2d6158a diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 33a693e1335..fc7a62ee8f4 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -331,7 +331,7 @@ void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); - pm.addPass(mlir::createLinalgBufferizePass()); + pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 795541b6872..a205a43cf88 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1184,7 +1184,7 @@ void PlierToStdPass::runOnOperation() CallOpLowering >(type_converter, context, &basic_rewrite); - mlir::populateStdExpandDivsRewritePatterns(context, patterns); + mlir::populateStdExpandOpsPatterns(context, patterns); for (auto *op : context->getRegisteredOperations()) { From cfc63966607d81873c98aabf993f8497093fe38e Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 19:34:47 +0300 Subject: [PATCH 169/259] add jumps param --- mlir-compiler/src/pipeline_registry.cpp | 1 + mlir-compiler/src/pipeline_registry.hpp | 1 + mlir-compiler/src/pipelines/base_pipeline.cpp | 4 ++-- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 2 +- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 2 +- 6 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index 023a1ef7078..21e0889462e 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -90,6 +90,7 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const auto sink = [&](llvm::StringRef pipeline_name, llvm::ArrayRef prev_pipelines, llvm::ArrayRef next_pipelines, + llvm::ArrayRef jumps, pipeline_funt_t func) { assert(!pipeline_name.empty()); diff --git a/mlir-compiler/src/pipeline_registry.hpp b/mlir-compiler/src/pipeline_registry.hpp index 3b348be8533..1bcd3f46a48 100644 --- a/mlir-compiler/src/pipeline_registry.hpp +++ b/mlir-compiler/src/pipeline_registry.hpp @@ -24,6 +24,7 @@ class PipelineRegistry llvm::StringRef pipeline_name, llvm::ArrayRef prev_pipelines, llvm::ArrayRef next_pipelines, + llvm::ArrayRef jumps, pipeline_funt_t func); using registry_entry_t = std::function)>; diff --git a/mlir-compiler/src/pipelines/base_pipeline.cpp b/mlir-compiler/src/pipelines/base_pipeline.cpp index 2d1405b81c0..7e8b35b4ba4 100644 --- a/mlir-compiler/src/pipelines/base_pipeline.cpp +++ b/mlir-compiler/src/pipelines/base_pipeline.cpp @@ -21,11 +21,11 @@ void register_base_pipeline(PipelineRegistry& registry) { if (0 == i) { - sink(passes[i], {}, {}, dummy_pass_func); + sink(passes[i], {}, {}, {}, dummy_pass_func); } else { - sink(passes[i], {passes[i - 1]}, {}, dummy_pass_func); + sink(passes[i], {passes[i - 1]}, {}, {}, dummy_pass_func); } }); } diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 2b08c82605d..8a34ac1d96f 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -568,7 +568,7 @@ void register_lower_to_llvm_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_lower_lowering_stage(); - sink(lower_to_llvm_pipeline_name(), {stage.begin}, {stage.end}, &populate_lower_to_llvm_pipeline); + sink(lower_to_llvm_pipeline_name(), {stage.begin}, {stage.end}, {}, &populate_lower_to_llvm_pipeline); }); } diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index fc7a62ee8f4..13b233b7d97 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -350,7 +350,7 @@ void register_plier_to_linalg_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_high_lowering_stage(); - sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, &populate_plier_to_linalg_pipeline); + sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, {}, &populate_plier_to_linalg_pipeline); }); } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index a205a43cf88..f45bd731ab7 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1227,7 +1227,7 @@ void register_plier_to_std_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_high_lowering_stage(); - sink(plier_to_std_pipeline_name(), {stage.begin}, {stage.end}, &populate_plier_to_std_pipeline); + sink(plier_to_std_pipeline_name(), {stage.begin}, {stage.end}, {}, &populate_plier_to_std_pipeline); }); } From 0d6e4757b8522a7620eb5af36f64855c27b0fbfe Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 19:37:52 +0300 Subject: [PATCH 170/259] refac --- mlir-compiler/src/pipeline_registry.cpp | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index 21e0889462e..d3790c0c319 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -178,9 +178,17 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const topo_visit(get_pipeline_info(name), iter_func, visit_func); } - for (auto current = first_pipeline; nullptr != current; - current = current->next) - { - current->func(pm); - } + auto iterate_pipelines = [&](auto func) + { + for (auto current = first_pipeline; nullptr != current; + current = current->next) + { + func(current); + } + }; + + iterate_pipelines([&](PipelineInfo* pipeline) + { + pipeline->func(pm); + }); } From e60e4e9762055dd1b5f029efc2c21d29608966f8 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 23:52:22 +0300 Subject: [PATCH 171/259] rework pipeline registry --- mlir-compiler/src/compiler.cpp | 156 +++++++++++++++++-- mlir-compiler/src/pipeline_registry.cpp | 27 +++- mlir-compiler/src/pipeline_registry.hpp | 5 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 14 +- 4 files changed, 176 insertions(+), 26 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 6b0124ac245..6db2636a107 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -9,16 +9,20 @@ #include +#include + #include "utils.hpp" #include "pipeline_registry.hpp" -class CompilerContext::CompilerContextImpl +namespace { -public: - CompilerContextImpl(mlir::MLIRContext& ctx, - const CompilerContext::Settings& settings, - const PipelineRegistry& registry): +struct PassManagerStage +{ + template + PassManagerStage(mlir::MLIRContext& ctx, + const CompilerContext::Settings& settings, + F&& init_func): pm(&ctx) { pm.enableVerifier(settings.verify); @@ -37,10 +41,142 @@ class CompilerContext::CompilerContextImpl pm.enableIRPrinting(); } - registry.populate_pass_manager(pm); + init_func(pm); + } + + void add_jump(mlir::StringAttr name, PassManagerStage* stage) + { + assert(!name.getValue().empty()); + assert(nullptr != stage); + jumps.emplace_back(name, stage); + } + + PassManagerStage* get_jump(mlir::ArrayAttr names) const + { + for (auto& it : jumps) + { + for (auto name : names) + { + if (it.first == name.cast()) + { + return it.second; + } + } + } + return nullptr; + } + + void set_next_stage(PassManagerStage* stage) + { + assert(nullptr == next_stage); + assert(nullptr != stage); + next_stage = stage; } - void run(mlir::ModuleOp& module) + PassManagerStage* get_next_sgate() const + { + return next_stage; + } + + mlir::LogicalResult run(mlir::ModuleOp op) + { + return pm.run(op); + } + +private: + mlir::PassManager pm; + llvm::SmallVector, 1> jumps; + PassManagerStage* next_stage = nullptr; +}; + +struct PassManagerSchedule +{ + PassManagerSchedule(mlir::MLIRContext& ctx, + const CompilerContext::Settings& settings, + const PipelineRegistry& registry) + { + auto func = [&](auto sink) + { + struct StageDesc + { + llvm::StringRef name; + llvm::ArrayRef jumps; + std::unique_ptr stage; + }; + + assert(nullptr == stages); + llvm::SmallVector stages_temp; + std::unordered_map stages_map; + + auto add_stage = [&](llvm::StringRef name, llvm::ArrayRef jumps, auto pm_init_func) + { + assert(!name.empty()); + auto prev_stage = (stages_map.empty() ? nullptr : stages_temp.back().stage.get()); + stages_temp.push_back({name, jumps, std::make_unique(ctx, settings, pm_init_func)}); + assert(stages_map.count(name.data()) == 0); + stages_map.insert({name.data(), stages_temp.back().stage.get()}); + if (nullptr != prev_stage) + { + prev_stage->set_next_stage(stages_temp.back().stage.get()); + } + }; + + sink(add_stage); + + for (auto& stage : stages_temp) + { + for (auto jump : stage.jumps) + { + assert(!jump.empty()); + auto it = stages_map.find(jump.data()); + assert(it != stages_map.end()); + assert(nullptr != it->second); + auto name = mlir::StringAttr::get(jump, &ctx); + stage.stage->add_jump(name, it->second); + } + } + + stages = std::make_unique[]>(stages_temp.size()); + for (auto it : llvm::enumerate(stages_temp)) + { + stages[it.index()] = std::move(it.value().stage); + } + }; + registry.populate_pass_manager(func); + } + + mlir::LogicalResult run(mlir::ModuleOp module) + { + assert(nullptr != stages); + auto current = stages[0].get(); + do + { + assert(nullptr != current); + if (mlir::failed(current->run(module))) + { + return mlir::failure(); + } + // TODO: jumps + current = current->get_next_sgate(); + } + while (nullptr != current); + return mlir::success(); + } + +private: + std::unique_ptr[]> stages; +}; +} + +class CompilerContext::CompilerContextImpl +{ +public: + CompilerContextImpl(mlir::MLIRContext& ctx, + const CompilerContext::Settings& settings, + const PipelineRegistry& registry): + schedule(ctx, settings, registry) {} + + void run(mlir::ModuleOp module) { std::string err; llvm::raw_string_ostream err_stream(err); @@ -52,9 +188,9 @@ class CompilerContext::CompilerContextImpl } }; - scoped_diag_handler(*pm.getContext(), diag_handler, [&]() + scoped_diag_handler(*module.getContext(), diag_handler, [&]() { - if (mlir::failed(pm.run(module))) + if (mlir::failed(schedule.run(module))) { err_stream << "\n"; module.print(err_stream); @@ -64,7 +200,7 @@ class CompilerContext::CompilerContextImpl }); } private: - mlir::PassManager pm; + PassManagerSchedule schedule; }; CompilerContext::CompilerContext(mlir::MLIRContext& ctx, diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index d3790c0c319..c7ce65fa25f 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -34,7 +34,7 @@ void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) } } -void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const +void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink) const { llvm::BumpPtrAllocator allocator; llvm::UniqueStringSaver string_set(allocator); @@ -81,8 +81,10 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const PipelineSet next_pipelines; pipeline_funt_t func = nullptr; PipelineInfo* next = nullptr; + llvm::ArrayRef jumps; bool visited = false; bool iterating = false; + bool jump_target = false; }; std::unordered_map pipelines_map; @@ -106,6 +108,16 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const info.func = func; llvm::transform(prev_pipelines, std::back_inserter(info.prev_pipelines), get_pipeline); llvm::transform(next_pipelines, std::back_inserter(info.next_pipelines), get_pipeline); + if (!jumps.empty()) + { + auto data = allocator.Allocate(jumps.size()); + llvm::transform(jumps, data, [&](llvm::StringRef str) + { + assert(!str.empty()); + return string_set.save(str); + }); + info.jumps = { data, jumps.size() }; + } }; for (auto& p : pipelines) @@ -189,6 +201,17 @@ void PipelineRegistry::populate_pass_manager(mlir::OpPassManager& pm) const iterate_pipelines([&](PipelineInfo* pipeline) { - pipeline->func(pm); + for (auto jump : pipeline->jumps) + { + get_pipeline_info(jump).jump_target = true; + } + }); + + result_sink([&](auto add_stage) + { + iterate_pipelines([&](PipelineInfo* pipeline) + { + add_stage(pipeline->name, pipeline->jumps, pipeline->func); + }); }); } diff --git a/mlir-compiler/src/pipeline_registry.hpp b/mlir-compiler/src/pipeline_registry.hpp index 1bcd3f46a48..e742e3b7360 100644 --- a/mlir-compiler/src/pipeline_registry.hpp +++ b/mlir-compiler/src/pipeline_registry.hpp @@ -30,7 +30,10 @@ class PipelineRegistry void register_pipeline(registry_entry_t func); - void populate_pass_manager(mlir::OpPassManager& pm) const; + using fill_stage_sink_t = llvm::function_ref jumps, llvm::function_ref)>; + using populate_pass_manager_sink_t = llvm::function_ref; + using populate_pass_manager_t = llvm::function_ref; + void populate_pass_manager(populate_pass_manager_t result_sink) const; private: std::vector pipelines; diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index f45bd731ab7..0050c1d013d 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1134,24 +1134,12 @@ struct PlierToStdPass : { registry.insert(); registry.insert(); + registry.insert(); } void runOnOperation() override; }; -template -mlir::Value cast_materializer( - mlir::OpBuilder& builder, T type, mlir::ValueRange inputs, - mlir::Location loc) -{ - assert(inputs.size() == 1); - if (type == inputs[0].getType()) - { - return inputs[0]; - } - return builder.create(loc, type, inputs[0]); -} - void PlierToStdPass::runOnOperation() { mlir::TypeConverter type_converter; From bff7091089ef9a155fefbd9482c8fbf63e6dde2d Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 21 Nov 2020 00:14:43 +0300 Subject: [PATCH 172/259] merge pippeline stages --- mlir-compiler/src/pipeline_registry.cpp | 40 +++++++++++++++++++++---- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index c7ce65fa25f..1f8c164f63f 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -195,23 +195,53 @@ void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink for (auto current = first_pipeline; nullptr != current; current = current->next) { - func(current); + func(*current); } }; - iterate_pipelines([&](PipelineInfo* pipeline) + iterate_pipelines([&](PipelineInfo& pipeline) { - for (auto jump : pipeline->jumps) + for (auto jump : pipeline.jumps) { get_pipeline_info(jump).jump_target = true; } }); + llvm::SmallVector funcs; + llvm::StringRef current_name; + llvm::ArrayRef current_jumps; result_sink([&](auto add_stage) { - iterate_pipelines([&](PipelineInfo* pipeline) + auto flush_stages = [&]() { - add_stage(pipeline->name, pipeline->jumps, pipeline->func); + if (!funcs.empty()) + { + assert(!current_name.empty()); + auto flusher = [&](mlir::OpPassManager& pm) + { + for (auto f : funcs) + { + f(pm); + } + }; + add_stage(current_name, current_jumps, flusher); + funcs.clear(); + current_name = {}; + current_jumps = {}; + } + assert(current_name.empty()); + assert(current_jumps.empty()); + }; + iterate_pipelines([&](PipelineInfo& pipeline) + { + if (&pipeline == first_pipeline || pipeline.jump_target || !pipeline.jumps.empty()) + { + flush_stages(); + current_name = pipeline.name; + current_jumps = pipeline.jumps; + } + funcs.emplace_back(pipeline.func); }); + flush_stages(); }); } From 33d71f47087fd9ccd8b94a3c113988f2deba17dc Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 21 Nov 2020 00:54:51 +0300 Subject: [PATCH 173/259] module pipeline jump markers --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/compiler.cpp | 26 +++++-- .../src/transforms/pipeline_utils.cpp | 72 +++++++++++++++++++ .../src/transforms/pipeline_utils.hpp | 14 ++++ 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 mlir-compiler/src/transforms/pipeline_utils.cpp create mode 100644 mlir-compiler/src/transforms/pipeline_utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 8a1801ae27e..137af14be3b 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -28,6 +28,7 @@ set(SOURCES_LIST src/rewrites/cast_lowering.cpp src/rewrites/type_conversion.cpp src/transforms/loop_utils.cpp + src/transforms/pipeline_utils.cpp src/compiler.cpp src/dialect.cpp src/lowering.cpp @@ -46,6 +47,7 @@ set(HEADERS_LIST src/rewrites/cast_lowering.hpp src/rewrites/type_conversion.hpp src/transforms/loop_utils.hpp + src/transforms/pipeline_utils.hpp src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index 6db2636a107..ceed240f37d 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -15,6 +15,8 @@ #include "pipeline_registry.hpp" +#include "transforms/pipeline_utils.hpp" + namespace { struct PassManagerStage @@ -51,19 +53,20 @@ struct PassManagerStage jumps.emplace_back(name, stage); } - PassManagerStage* get_jump(mlir::ArrayAttr names) const + std::pair get_jump(mlir::ArrayAttr names) const { for (auto& it : jumps) { for (auto name : names) { - if (it.first == name.cast()) + auto str = name.cast(); + if (it.first == str) { - return it.second; + return {it.second, str}; } } } - return nullptr; + return {nullptr, nullptr}; } void set_next_stage(PassManagerStage* stage) @@ -73,7 +76,7 @@ struct PassManagerStage next_stage = stage; } - PassManagerStage* get_next_sgate() const + PassManagerStage* get_next_stage() const { return next_stage; } @@ -156,8 +159,17 @@ struct PassManagerSchedule { return mlir::failure(); } - // TODO: jumps - current = current->get_next_sgate(); + auto markers = get_pipeline_jump_markers(module); + auto jump_target = current->get_jump(markers); + if (nullptr != jump_target.first) + { + remove_pipeline_jump_marker(module, jump_target.second); + current = jump_target.first; + } + else + { + current = current->get_next_stage(); + } } while (nullptr != current); return mlir::success(); diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp new file mode 100644 index 00000000000..2f4ba9013ff --- /dev/null +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -0,0 +1,72 @@ +#include "transforms/pipeline_utils.hpp" + +#include +#include + +mlir::ModuleOp get_module(mlir::Operation* op) +{ + assert(nullptr != op); + while (!mlir::isa(op)) + { + op = op->getParentOp(); + assert(nullptr != op); + } + return mlir::cast(op); +} + +namespace +{ +const constexpr llvm::StringLiteral jump_marker_name("pipeline_jump_markers"); +} + +mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) +{ + return module.getAttrOfType(jump_marker_name); +} + +void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) +{ + assert(name); + assert(!name.getValue().empty()); + + llvm::SmallVector name_list; + if (auto old_attr = module.getAttrOfType(jump_marker_name)) + { + name_list.assign(old_attr.begin(), old_attr.end()); + } + auto it = llvm::lower_bound(name_list, name, + [](mlir::Attribute lhs, mlir::StringAttr rhs) + { + return lhs.cast().getValue() < rhs.getValue(); + }); + if (it == name_list.end()) + { + name_list.emplace_back(name); + } + else if (*it != name) + { + name_list.insert(it, name); + } + module.setAttr(jump_marker_name, mlir::ArrayAttr::get(name_list, module.getContext())); +} + + +void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) +{ + assert(name); + assert(!name.getValue().empty()); + + llvm::SmallVector name_list; + if (auto old_attr = module.getAttrOfType(jump_marker_name)) + { + name_list.assign(old_attr.begin(), old_attr.end()); + } + auto it = llvm::lower_bound(name_list, name, + [](mlir::Attribute lhs, mlir::StringAttr rhs) + { + return lhs.cast().getValue() < rhs.getValue(); + }); + assert(it != name_list.end()); + name_list.erase(it); + module.setAttr(jump_marker_name, mlir::ArrayAttr::get(name_list, module.getContext())); +} diff --git a/mlir-compiler/src/transforms/pipeline_utils.hpp b/mlir-compiler/src/transforms/pipeline_utils.hpp new file mode 100644 index 00000000000..00fae08f5c1 --- /dev/null +++ b/mlir-compiler/src/transforms/pipeline_utils.hpp @@ -0,0 +1,14 @@ +#pragma once + +namespace mlir +{ +class ArrayAttr; +class Operation; +class ModuleOp; +class StringAttr; +} + +mlir::ModuleOp get_module(mlir::Operation* op); +mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module); +void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); +void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); From 296070bee82a69231158390b14bad39f8b6b6eb2 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 18 Nov 2020 14:48:14 +0300 Subject: [PATCH 174/259] array len WIP --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 8 ++++++++ numba/mlir/tests/test_numpy.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 13b233b7d97..753424f6a4c 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -200,6 +200,14 @@ mlir::LogicalResult numpy_rewrite( rewriter.replaceOp(op, res); return mlir::success(); } + if (name == "" && check_numpy_args(args, 1)) + { + auto loc = op.getLoc(); + mlir::Value dim = rewriter.create(loc, args[0], 0); + mlir::Value res = rewriter.create(loc, op.getType(), dim); + rewriter.replaceOp(op, res); + return mlir::success(); + } return mlir::failure(); } diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 760093f7f82..1be51f517e9 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -24,6 +24,14 @@ def py_func(a, b): for i in range(3): assert_equal(py_func(arr, i), jit_func(arr, i)) + def test_array_len(self): + def py_func(a): + return len(a) + + jit_func = njit(py_func) + arr = np.asarray([5,6,7]) + assert_equal(py_func(arr), jit_func(arr)) + def test_sum(self): def py_func(a): return a.sum() From 15df66f7ae9feb17e6e2033f3a3d96af7f6d7050 Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 21 Nov 2020 01:47:27 +0300 Subject: [PATCH 175/259] pipeline control flow usage and fixes --- mlir-compiler/src/compiler.cpp | 13 ++++++++----- mlir-compiler/src/pipeline_registry.cpp | 19 ++++++++++++++----- .../src/pipelines/plier_to_linalg.cpp | 11 ++++++++++- .../src/transforms/pipeline_utils.cpp | 2 +- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index ceed240f37d..eec1b4eafc9 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -55,14 +55,17 @@ struct PassManagerStage std::pair get_jump(mlir::ArrayAttr names) const { - for (auto& it : jumps) + if (names) { - for (auto name : names) + for (auto& it : jumps) { - auto str = name.cast(); - if (it.first == str) + for (auto name : names) { - return {it.second, str}; + auto str = name.cast(); + if (it.first == str) + { + return {it.second, str}; + } } } } diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/src/pipeline_registry.cpp index 1f8c164f63f..e62ae0c1a77 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/src/pipeline_registry.cpp @@ -190,6 +190,8 @@ void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink topo_visit(get_pipeline_info(name), iter_func, visit_func); } + assert(nullptr != first_pipeline); + auto iterate_pipelines = [&](auto func) { for (auto current = first_pipeline; nullptr != current; @@ -201,14 +203,21 @@ void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink iterate_pipelines([&](PipelineInfo& pipeline) { - for (auto jump : pipeline.jumps) + if (!pipeline.jumps.empty()) { - get_pipeline_info(jump).jump_target = true; + for (auto jump : pipeline.jumps) + { + get_pipeline_info(jump).jump_target = true; + } + if (nullptr != pipeline.next) + { + pipeline.next->jump_target = true; + } } }); llvm::SmallVector funcs; - llvm::StringRef current_name; + llvm::StringRef current_name = first_pipeline->name; llvm::ArrayRef current_jumps; result_sink([&](auto add_stage) { @@ -234,13 +243,13 @@ void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink }; iterate_pipelines([&](PipelineInfo& pipeline) { - if (&pipeline == first_pipeline || pipeline.jump_target || !pipeline.jumps.empty()) + if (pipeline.jump_target) { flush_stages(); current_name = pipeline.name; - current_jumps = pipeline.jumps; } funcs.emplace_back(pipeline.func); + current_jumps = pipeline.jumps; }); flush_stages(); }); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 753424f6a4c..8b753491c42 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -18,6 +18,7 @@ #include "plier/dialect.hpp" #include "pipelines/plier_to_std.hpp" +#include "transforms/pipeline_utils.hpp" #include "rewrites/call_lowering.hpp" #include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" @@ -131,6 +132,13 @@ mlir::Type get_elem_type(mlir::Type type) llvm_unreachable("get_elem_type: unknown type"); } +void rerun_std_pipeline(mlir::Operation* op) +{ + assert(nullptr != op); + auto marker = mlir::StringAttr::get(plier_to_std_pipeline_name(), op->getContext()); + add_pipeline_jump_marker(get_module(op), marker); +} + mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) @@ -205,6 +213,7 @@ mlir::LogicalResult numpy_rewrite( auto loc = op.getLoc(); mlir::Value dim = rewriter.create(loc, args[0], 0); mlir::Value res = rewriter.create(loc, op.getType(), dim); + rerun_std_pipeline(op); rewriter.replaceOp(op, res); return mlir::success(); } @@ -358,7 +367,7 @@ void register_plier_to_linalg_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_high_lowering_stage(); - sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, {}, &populate_plier_to_linalg_pipeline); + sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, {plier_to_std_pipeline_name()}, &populate_plier_to_linalg_pipeline); }); } diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index 2f4ba9013ff..d499f9bdd0b 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -16,7 +16,7 @@ mlir::ModuleOp get_module(mlir::Operation* op) namespace { -const constexpr llvm::StringLiteral jump_marker_name("pipeline_jump_markers"); +const constexpr llvm::StringLiteral jump_marker_name("#plier.pipeline_jump_markers"); } mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) From 591aa055357d4197687353fae9fe9d0d13065a20 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 18 Nov 2020 21:27:40 +0300 Subject: [PATCH 176/259] test --- numba/mlir/tests/test_numpy.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 1be51f517e9..178d7f88fd8 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -60,5 +60,14 @@ def py_func(a, b, c): arr3 = np.asarray([7,8,9]) assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) + def test_setitem(self): + def py_func(a, b): + a[b] = 42 + return a.sum() + + jit_func = njit(py_func) + arr = np.asarray([1,2,3]) + assert_equal(py_func(arr, 1), jit_func(arr, 1)) + if __name__ == '__main__': unittest.main() From 6c50632eb672d12c05511850d7645a02e2cec823 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 15:21:11 +0300 Subject: [PATCH 177/259] work on setitem --- .../src/pipelines/plier_to_linalg.cpp | 109 +++++++++++++++++- numba/mlir/tests/test_numpy.py | 2 +- 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 8b753491c42..d5ad763920a 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -266,6 +266,112 @@ struct GetitemOpLowering : public mlir::OpRewritePattern } }; +bool can_replace_ssa(mlir::Operation* op) +{ + assert(nullptr != op); + if (op->getParentRegion()->getBlocks().size() != 1) + { + return false; + } + auto parent = op->getParentOp(); + if (mlir::isa(parent)) + { + return true; + } + return false; +// return can_replace_ssa(parent); +} + +bool replace_ssa_in_block(mlir::Value value, mlir::Value new_value, mlir::PatternRewriter &rewriter) +{ + auto new_op = new_value.getDefiningOp(); + assert(nullptr != new_op); + auto block = new_op->getBlock(); + bool changed = false; + for (auto user : llvm::make_early_inc_range(value.getUsers())) + { + if (auto op = block->findAncestorOpInBlock(*user)) + { + if (op != new_op && new_op->isBeforeInBlock(op)) + { + rewriter.updateRootInPlace(user, [&]() + { + for (auto it2 : llvm::enumerate(user->getOperands())) + { + if (it2.value() == value) + { + user->setOperand(static_cast(it2.index()), new_value); + break; + } + } + }); + changed = true; + } + } + } + return changed; +} + +bool replace_ssa_value(mlir::Value value, mlir::Value new_value, mlir::PatternRewriter &rewriter) +{ + bool changed = replace_ssa_in_block(value, new_value, rewriter); + auto parent = new_value.getDefiningOp()->getParentOp(); + if (auto func = mlir::dyn_cast(parent)) + { + // TODO update return + return changed; + } + llvm_unreachable("Unhandled parent op"); +} + +mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder) +{ + if (!value.getType().isa()) + { + return builder.create(loc, mlir::IndexType::get(value.getContext()), value); + } + return value; +} + +template +struct SetitemOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + T op, mlir::PatternRewriter &rewriter) const override + { + if (!can_replace_ssa(op)) + { + return mlir::failure(); + } + auto target = op.getOperand(0); + auto index = op.getOperand(1); + auto value = op.getOperand(2); + auto target_type = target.getType().template dyn_cast(); + if (!target_type) + { + return mlir::failure(); + } + auto elem_type = target_type.getElementType(); + auto loc = op.getLoc(); + if (value.getType() != elem_type) + { + // TODO + rewriter.create(loc, elem_type, value); +// return mlir::failure(); + } + + auto new_tensor = rewriter.create(loc, value); + auto new_index = index_cast(index, loc, rewriter); + mlir::Value one = rewriter.create(loc, 1); + auto new_value = rewriter.create(loc, new_tensor, target, new_index, one, one); + replace_ssa_value(target, new_value, rewriter); + rewriter.eraseOp(op); + return mlir::success(); + } +}; + struct PlierToLinalgPass : public mlir::PassWrapper> { @@ -308,7 +414,8 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< GetitemOpLowering, - GetitemOpLowering + GetitemOpLowering, + SetitemOpLowering >(&getContext()); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 178d7f88fd8..1263a9291a4 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -63,7 +63,7 @@ def py_func(a, b, c): def test_setitem(self): def py_func(a, b): a[b] = 42 - return a.sum() + return a[b] jit_func = njit(py_func) arr = np.asarray([1,2,3]) From a2f69aadfb1650251646f83d78682fe0b367afd3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 20 Nov 2020 17:55:34 +0300 Subject: [PATCH 178/259] work on setitem --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index d5ad763920a..e9cebce113a 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -328,7 +328,8 @@ mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& b { if (!value.getType().isa()) { - return builder.create(loc, mlir::IndexType::get(value.getContext()), value); + auto index_type = mlir::IndexType::get(value.getContext()); + return builder.create(loc, value, index_type); } return value; } @@ -358,7 +359,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern if (value.getType() != elem_type) { // TODO - rewriter.create(loc, elem_type, value); + value = rewriter.create(loc, elem_type, value); // return mlir::failure(); } From acc8f81ce4b8e8a4b908e4003943b50d92a12b9e Mon Sep 17 00:00:00 2001 From: Butygin Date: Sat, 21 Nov 2020 01:55:52 +0300 Subject: [PATCH 179/259] rerun_std_pipeline --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index e9cebce113a..02a1adb1b4a 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -360,6 +360,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern { // TODO value = rewriter.create(loc, elem_type, value); + rerun_std_pipeline(op); // return mlir::failure(); } From 322d63eea6f2d886e55edcf7ee1b158ad92e11a0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Tue, 24 Nov 2020 21:32:14 +0300 Subject: [PATCH 180/259] setitem lowering --- .../src/pipelines/plier_to_linalg.cpp | 80 ++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 02a1adb1b4a..cda552958c6 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -335,7 +335,7 @@ mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& b } template -struct SetitemOpLowering : public mlir::OpRewritePattern +struct SetitemOpLoweringSSA : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -388,6 +388,82 @@ struct PlierToLinalgPass : void runOnOperation() override; }; +template +struct SetitemOpLowering : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + T op, mlir::PatternRewriter &rewriter) const override + { + auto rewrite_memref = [&]() + { + auto target = op.getOperand(0); + auto index = op.getOperand(1); + auto value = op.getOperand(2); + auto loc = op.getLoc(); + auto ind = index_cast(index, loc, rewriter); + auto elem_type = target.getType().template cast().getElementType(); + if (value.getType() != elem_type) + { + // TODO + value = rewriter.create(loc, elem_type, value); + rerun_std_pipeline(op); + } + auto store = rewriter.create(loc, value, target, ind); + rewriter.eraseOp(op); + }; + + auto get_target_type = [&]() + { + return op.getOperand(0).getType(); + }; + + if (auto target_type = get_target_type().template dyn_cast()) + { + auto target = op.getOperand(0); + mlir::OpBuilder::InsertionGuard g(rewriter); + if (auto parent_op = target.getDefiningOp()) + { + rewriter.setInsertionPoint(parent_op); + } + else + { + rewriter.setInsertionPointToStart(target.getParentBlock()); + } + auto memref_type = mlir::MemRefType::get(target_type.getShape(), target_type.getElementType()); + auto memref = rewriter.create(target.getLoc(), memref_type, target); + for (auto& use : llvm::make_early_inc_range(target.getUses())) + { + auto use_op = use.getOwner(); + assert(nullptr != use_op); + if (use_op != memref) + { + if (mlir::isa(use_op)) + { + use_op->setOperand(use.getOperandNumber(), memref); + } + else + { + rewriter.setInsertionPoint(use_op); + auto new_val = rewriter.create(use_op->getLoc(), memref); + rewriter.updateRootInPlace(use_op, [&]() + { + use_op->setOperand(use.getOperandNumber(), new_val); + }); + } + } + } + rewrite_memref(); + } + else if (auto target_type = get_target_type().template dyn_cast()) + { + rewrite_memref(); + } + return mlir::success(); + } +}; + void PlierToLinalgPass::runOnOperation() { mlir::TypeConverter type_converter; @@ -464,8 +540,8 @@ void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(mlir::createPromoteBuffersToStackPass(1024)); pm.addNestedPass(mlir::createBufferHoistingPass()); pm.addNestedPass(mlir::createBufferLoopHoistingPass()); - pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addNestedPass(mlir::createBufferDeallocationPass()); + pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); } From 1b8f9a4c90d4f78a6890c51eb5c9b848dd5c281a Mon Sep 17 00:00:00 2001 From: Butygin Date: Tue, 24 Nov 2020 21:56:06 +0300 Subject: [PATCH 181/259] fix --- .../src/pipelines/plier_to_linalg.cpp | 46 ++++++++++--------- numba/mlir/tests/test_numpy.py | 10 ++++ 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index cda552958c6..e764a9efdf1 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -261,6 +261,7 @@ struct GetitemOpLowering : public mlir::OpRewritePattern { llvm_unreachable("Invalid getitem"); } + rerun_std_pipeline(op); rewriter.replaceOp(op, res); return mlir::success(); } @@ -329,7 +330,9 @@ mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& b if (!value.getType().isa()) { auto index_type = mlir::IndexType::get(value.getContext()); - return builder.create(loc, value, index_type); + auto res = builder.create(loc, index_type, value); + rerun_std_pipeline(res); + return res; } return value; } @@ -396,24 +399,6 @@ struct SetitemOpLowering : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( T op, mlir::PatternRewriter &rewriter) const override { - auto rewrite_memref = [&]() - { - auto target = op.getOperand(0); - auto index = op.getOperand(1); - auto value = op.getOperand(2); - auto loc = op.getLoc(); - auto ind = index_cast(index, loc, rewriter); - auto elem_type = target.getType().template cast().getElementType(); - if (value.getType() != elem_type) - { - // TODO - value = rewriter.create(loc, elem_type, value); - rerun_std_pipeline(op); - } - auto store = rewriter.create(loc, value, target, ind); - rewriter.eraseOp(op); - }; - auto get_target_type = [&]() { return op.getOperand(0).getType(); @@ -454,12 +439,29 @@ struct SetitemOpLowering : public mlir::OpRewritePattern } } } - rewrite_memref(); } - else if (auto target_type = get_target_type().template dyn_cast()) + else if (get_target_type().template isa()) + { + // nothing + } + else + { + return mlir::failure(); + } + auto target = op.getOperand(0); + auto index = op.getOperand(1); + auto value = op.getOperand(2); + auto loc = op.getLoc(); + auto ind = index_cast(index, loc, rewriter); + auto elem_type = target.getType().template cast().getElementType(); + if (value.getType() != elem_type) { - rewrite_memref(); + // TODO + value = rewriter.create(loc, elem_type, value); + rerun_std_pipeline(op); } + auto store = rewriter.create(loc, value, target, ind); + rewriter.eraseOp(op); return mlir::success(); } }; diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 1263a9291a4..a90d0ac7c69 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -69,5 +69,15 @@ def py_func(a, b): arr = np.asarray([1,2,3]) assert_equal(py_func(arr, 1), jit_func(arr, 1)) + def test_setitem_loop(self): + def py_func(a): + for i in range(len(a)): + a[i] = a[i] + i + return a.sum() + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + if __name__ == '__main__': unittest.main() From 8e2023df9d0cfb0ff37168f45ccbaa9a1bbaa115 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 25 Nov 2020 14:58:59 +0300 Subject: [PATCH 182/259] refac --- mlir-compiler/src/rewrites/call_lowering.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/rewrites/call_lowering.cpp b/mlir-compiler/src/rewrites/call_lowering.cpp index 65f308fee0b..8b92898c64a 100644 --- a/mlir-compiler/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/src/rewrites/call_lowering.cpp @@ -8,6 +8,17 @@ llvm::StringRef extract_bound_func_name(llvm::StringRef name) auto len = name.find(' '); return name.substr(0, len); } + +bool check_class_name(llvm::StringRef& str, llvm::StringRef prefix) +{ + llvm::StringRef temp = str; + if (temp.consume_front(prefix) && temp.consume_front("(") && temp.consume_back(")")) + { + str = temp; + return true; + } + return false; +} } CallOpLowering::CallOpLowering( @@ -30,13 +41,13 @@ mlir::LogicalResult CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::Pa auto name = func_type.cast().getName(); llvm::SmallVector arg_types; llvm::SmallVector args; - if (name.consume_front("Function(") && name.consume_back(")")) + if (check_class_name(name, "Function")) { llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); // TODO kwargs } - else if (name.consume_front("BoundFunction(") && name.consume_back(")")) + else if (check_class_name(name, "BoundFunction")) { auto getattr = mlir::dyn_cast(operands[0].getDefiningOp()); if (!getattr) From 9184fcc8507d22c2206f065fc01720ac9a6120ff Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 26 Nov 2020 16:51:21 +0300 Subject: [PATCH 183/259] fix --- mlir-compiler/src/lowering.cpp | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 81172016e8f..70e2cb45d6f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -35,22 +35,11 @@ std::string serialize_mod(const llvm::Module& mod) { std::string ret; llvm::raw_string_ostream stream(ret); -// mod.print(stream, nullptr); llvm::WriteBitcodeToFile(mod, stream); stream.flush(); return ret; } -//template -//std::string to_str(T& obj) -//{ -// std::string ret; -// llvm::raw_string_ostream stream(ret); -// obj.print(stream); -// stream.flush(); -// return ret; -//} - std::vector> get_blocks(const py::object& func) { std::vector> ret; @@ -441,7 +430,7 @@ struct plier_lowerer { auto val = inst.attr("value"); auto value = loadvar(val); - auto name = val.attr("name").cast(); + auto name = inst.attr("attr").cast(); return builder.create(get_current_loc(), value, name); } From aebda2c5d53210c65f8c4cacfb08a4e2cd55056a Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 26 Nov 2020 18:01:25 +0300 Subject: [PATCH 184/259] refac --- mlir-compiler/include/plier/PlierOps.td | 3 ++- mlir-compiler/src/dialect.cpp | 4 ++-- mlir-compiler/src/lowering.cpp | 11 ++++++++--- numba/core/typed_passes.py | 16 ++++++++++++++++ 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 607b0025289..cf9a6a4ff26 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -89,6 +89,7 @@ def PyCallOp : Plier_Op<"call", []> { let arguments = (ins AnyType:$func, Variadic:$args, + StrAttr:$func_name, UI32Attr:$kw_start, ArrayAttr:$kw_names); @@ -96,7 +97,7 @@ def PyCallOp : Plier_Op<"call", []> { let builders = [ OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value func, " - "::mlir::ValueRange args, " + "::mlir::StringRef func_name, ::mlir::ValueRange args, " "::mlir::ArrayRef> kwargs"> ]; } diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index c0db90800e9..82b6610b1ad 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -131,7 +131,7 @@ mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) } void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value func, - mlir::ValueRange args, + llvm::StringRef func_name, mlir::ValueRange args, mlir::ArrayRef> kwargs) { auto ctx = builder.getContext(); mlir::SmallVector all_args; @@ -146,7 +146,7 @@ void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir all_args.push_back(a.second); } PyCallOp::build(builder, state, PyType::getUndefined(state.getContext()), - func, all_args, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); + func, all_args, func_name, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); } void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 70e2cb45d6f..2a6064d68f7 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -123,7 +123,7 @@ struct inst_handles std::array ops_handles; }; -struct plier_lowerer +struct plier_lowerer final { plier_lowerer(mlir::MLIRContext& context): ctx(context), @@ -137,6 +137,7 @@ struct plier_lowerer { auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); typemap = compilation_context["typemap"]; + func_name_resolver = compilation_context["resolve_func"]; auto name = compilation_context["fnname"]().cast(); auto typ = get_func_type(compilation_context["fnargs"], compilation_context["restype"]); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); @@ -164,6 +165,7 @@ struct plier_lowerer }; py::handle current_instr; py::handle typemap; + py::handle func_name_resolver; std::unordered_map block_infos; @@ -370,7 +372,8 @@ struct plier_lowerer mlir::Value lower_call(const py::handle& expr) { - auto func = loadvar(expr.attr("func")); + auto py_func = expr.attr("func"); + auto func = loadvar(py_func); auto args = expr.attr("args").cast(); auto kws = expr.attr("kws").cast(); auto vararg = expr.attr("vararg"); @@ -389,7 +392,9 @@ struct plier_lowerer kwargs_list.push_back({name.cast(), loadvar(val_name)}); } - return builder.create(get_current_loc(), func, + auto func_name = func_name_resolver(typemap(py_func)).cast(); + + return builder.create(get_current_loc(), func, func_name, args_list, kwargs_list); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 5a5b077a091..92c99ab4f47 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -467,6 +467,13 @@ def run_pass(self, state): ) return True +# TODO +import numpy +_mlir_func_names = { + id(range) : 'range', + id(numpy.add) : 'numpy.add' + } + @register_pass(mutates_CFG=True, analysis_only=False) class MlirBackend(LoweringPass): @@ -503,11 +510,20 @@ def run_pass(self, state): ctx['fnargs'] = lambda: state.args ctx['restype'] = lambda: state.return_type ctx['fnname'] = lambda: fn_name + ctx['resolve_func'] = self._resolve_func_name import mlir_compiler mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) setattr(state, 'mlir_blob', mod) return True + def _resolve_func_name(self, obj): + if isinstance(obj, types.Function): + func = obj.typing_key + return _mlir_func_names.get(id(func), None) + if isinstance(obj, types.BoundFunction): + return str(obj.typing_key) + return None + @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): From 2ec2471283f2afd7cd6ccc700d6435f049314ad9 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 26 Nov 2020 18:12:23 +0300 Subject: [PATCH 185/259] refactor function names --- mlir-compiler/src/lowering.cpp | 8 +++++++- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 4 ++-- mlir-compiler/src/pipelines/plier_to_std.cpp | 4 ++-- mlir-compiler/src/rewrites/call_lowering.cpp | 2 +- numba/core/typed_passes.py | 2 ++ 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 2a6064d68f7..21f7399f9a4 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -392,7 +392,13 @@ struct plier_lowerer final kwargs_list.push_back({name.cast(), loadvar(val_name)}); } - auto func_name = func_name_resolver(typemap(py_func)).cast(); + auto py_func_name = func_name_resolver(typemap(py_func)); + if (py_func_name.is_none()) + { + report_error(llvm::Twine("Can't resolve function: ") + py::str(typemap(py_func)).cast()); + } + + auto func_name = py_func_name.cast(); return builder.create(get_current_loc(), func, func_name, args_list, kwargs_list); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index e764a9efdf1..c0c5fc991f2 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -143,7 +143,7 @@ mlir::LogicalResult numpy_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - if (name == "" && check_numpy_args(args, 2)) + if (name == "numpy.add" && check_numpy_args(args, 2)) { auto loc = op.getLoc(); mlir::Value inputs[] = { args[0], args[1] }; @@ -208,7 +208,7 @@ mlir::LogicalResult numpy_rewrite( rewriter.replaceOp(op, res); return mlir::success(); } - if (name == "" && check_numpy_args(args, 1)) + if (name == "len" && check_numpy_args(args, 1)) { auto loc = op.getLoc(); mlir::Value dim = rewriter.create(loc, args[0], 0); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 0050c1d013d..a7d90745c6a 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1103,8 +1103,8 @@ mlir::LogicalResult basic_rewrite( { using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); std::pair handlers[] = { - {"", lower_bool_cast}, - {"", lower_range}, + {"bool", lower_bool_cast}, + {"range", lower_range}, }; for (auto& handler : handlers) { diff --git a/mlir-compiler/src/rewrites/call_lowering.cpp b/mlir-compiler/src/rewrites/call_lowering.cpp index 8b92898c64a..d107845d5dc 100644 --- a/mlir-compiler/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/src/rewrites/call_lowering.cpp @@ -66,5 +66,5 @@ mlir::LogicalResult CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::Pa return mlir::failure(); } - return resolver(op, name, args, rewriter); + return resolver(op, op.func_name(), args, rewriter); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 92c99ab4f47..d0b644a4955 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -471,6 +471,8 @@ def run_pass(self, state): import numpy _mlir_func_names = { id(range) : 'range', + id(len) : 'len', + id(bool) : 'bool', id(numpy.add) : 'numpy.add' } From 9e607fc74d218cecb297daf210caff961d4278c7 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 02:42:46 +0300 Subject: [PATCH 186/259] remove unused --- mlir-compiler/src/pipelines/plier_to_std.cpp | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index a7d90745c6a..4083ebd6136 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1116,16 +1116,6 @@ mlir::LogicalResult basic_rewrite( return mlir::failure(); } -mlir::Operation* change_op_ret_type(mlir::Operation* op, - mlir::PatternRewriter& rewriter, - llvm::ArrayRef types) -{ - assert(nullptr != op); - mlir::OperationState state(op->getLoc(), op->getName().getStringRef(), - op->getOperands(), types, op->getAttrs()); - return rewriter.createOperation(state); -} - struct PlierToStdPass : public mlir::PassWrapper> { From 0beedcd7617f6daa248c9a42fb904a36de892e6f Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 02:53:20 +0300 Subject: [PATCH 187/259] func names registry --- numba/core/typed_passes.py | 13 +++---------- numba/mlir/__init__.py | 3 +++ numba/mlir/builtin_funcs.py | 5 +++++ numba/mlir/func_registry.py | 15 +++++++++++++++ numba/mlir/numpy_funcs.py | 5 +++++ 5 files changed, 31 insertions(+), 10 deletions(-) create mode 100644 numba/mlir/builtin_funcs.py create mode 100644 numba/mlir/func_registry.py create mode 100644 numba/mlir/numpy_funcs.py diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index d0b644a4955..bb05330c780 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -467,14 +467,6 @@ def run_pass(self, state): ) return True -# TODO -import numpy -_mlir_func_names = { - id(range) : 'range', - id(len) : 'len', - id(bool) : 'bool', - id(numpy.add) : 'numpy.add' - } @register_pass(mutates_CFG=True, analysis_only=False) class MlirBackend(LoweringPass): @@ -483,7 +475,8 @@ class MlirBackend(LoweringPass): def __init__(self): # LoweringPass.__init__(self) - pass + import numba.mlir.func_registry + self._get_func_name = numba.mlir.func_registry.get_func_name def run_pass(self, state): targetctx = state.targetctx @@ -521,7 +514,7 @@ def run_pass(self, state): def _resolve_func_name(self, obj): if isinstance(obj, types.Function): func = obj.typing_key - return _mlir_func_names.get(id(func), None) + return self._get_func_name(func) if isinstance(obj, types.BoundFunction): return str(obj.typing_key) return None diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py index d4ef12a664e..5764159847c 100644 --- a/numba/mlir/__init__.py +++ b/numba/mlir/__init__.py @@ -1,4 +1,7 @@ from numba import runtests +import numba.mlir.builtin_funcs +import numba.mlir.numpy_funcs + def test(*args, **kwargs): return runtests.main("numba.mlir.tests", *args, **kwargs) diff --git a/numba/mlir/builtin_funcs.py b/numba/mlir/builtin_funcs.py new file mode 100644 index 00000000000..bac1fabc921 --- /dev/null +++ b/numba/mlir/builtin_funcs.py @@ -0,0 +1,5 @@ +from numba.mlir.func_registry import add_func + +add_func(range, 'range') +add_func(len, 'len') +add_func(bool, 'bool') diff --git a/numba/mlir/func_registry.py b/numba/mlir/func_registry.py new file mode 100644 index 00000000000..926906b56d7 --- /dev/null +++ b/numba/mlir/func_registry.py @@ -0,0 +1,15 @@ + +_mlir_func_names = {} + # id(range) : 'range', + # id(len) : 'len', + # id(bool) : 'bool', + # id(numpy.add) : 'numpy.add' + # } + +def add_func(func, name): + key = id(func) + assert not key in _mlir_func_names + _mlir_func_names[key] = name + +def get_func_name(func): + return _mlir_func_names.get(id(func), None) diff --git a/numba/mlir/numpy_funcs.py b/numba/mlir/numpy_funcs.py new file mode 100644 index 00000000000..ed965917a1a --- /dev/null +++ b/numba/mlir/numpy_funcs.py @@ -0,0 +1,5 @@ +from numba.mlir.func_registry import add_func + +import numpy + +add_func(numpy.add, 'numpy.add') From 35ed5be0ae5f02c4a3a352956b2129aa8074744d Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 03:49:05 +0300 Subject: [PATCH 188/259] some math funcs support --- mlir-compiler/src/pipelines/plier_to_std.cpp | 56 ++++++++++++++++++++ numba/mlir/__init__.py | 1 + numba/mlir/math_funcs.py | 11 ++++ 3 files changed, 68 insertions(+) create mode 100644 numba/mlir/math_funcs.py diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 4083ebd6136..58fac3f6a33 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1097,10 +1097,66 @@ mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef(mod.lookupSymbol(name))) + { + assert(op.getType() == type); + return op; + } + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(mod.getBody(), + std::prev(mod.getBody()->end())); + auto func = rewriter.create(rewriter.getUnknownLoc(), name, type); + func.setPrivate(); + return func; +} + +mlir::LogicalResult lower_math_func( + plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, + mlir::PatternRewriter& rewriter) +{ + auto ret_type = map_plier_type(op.getType()); + auto valid_type = [&](mlir::Type type) + { + return ret_type == type && type.isa(); + }; + if (ret_type && name.consume_front("math.") && args.size() == 1 && + valid_type(args[0].getType())) + { + auto is_float = ret_type.isa(); + auto func_type = mlir::FunctionType::get(args[0].getType(), ret_type, op.getContext()); + auto module = op.getParentOfType(); + mlir::FuncOp func; + if (is_float) + { + func = get_lib_symbol(module, name.str() + "f", func_type, rewriter); + } + else // double + { + func = get_lib_symbol(module, name, func_type, rewriter); + } + auto call = rewriter.create(op.getLoc(), func, args); + rewriter.replaceOp(op, call.getResults()); + return mlir::success(); + } + + return mlir::failure(); +} + mlir::LogicalResult basic_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { + if (mlir::succeeded(lower_math_func(op, name, args, rewriter))) + { + return mlir::success(); + } using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); std::pair handlers[] = { {"bool", lower_bool_cast}, diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py index 5764159847c..7100473c946 100644 --- a/numba/mlir/__init__.py +++ b/numba/mlir/__init__.py @@ -2,6 +2,7 @@ import numba.mlir.builtin_funcs import numba.mlir.numpy_funcs +import numba.mlir.math_funcs def test(*args, **kwargs): return runtests.main("numba.mlir.tests", *args, **kwargs) diff --git a/numba/mlir/math_funcs.py b/numba/mlir/math_funcs.py new file mode 100644 index 00000000000..7811935d431 --- /dev/null +++ b/numba/mlir/math_funcs.py @@ -0,0 +1,11 @@ +from numba.mlir.func_registry import add_func + +import math + +_funcs = ['log', 'sqrt', 'exp', 'erf'] + +for f in _funcs: + fname = 'math.' + f + add_func(eval(fname), fname) + + From 44a317c6251a25162b54f44447b469f06619403f Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 04:07:37 +0300 Subject: [PATCH 189/259] remove get_module --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- mlir-compiler/src/transforms/pipeline_utils.cpp | 11 ----------- mlir-compiler/src/transforms/pipeline_utils.hpp | 2 -- 3 files changed, 1 insertion(+), 14 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index c0c5fc991f2..0001fd50141 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -136,7 +136,7 @@ void rerun_std_pipeline(mlir::Operation* op) { assert(nullptr != op); auto marker = mlir::StringAttr::get(plier_to_std_pipeline_name(), op->getContext()); - add_pipeline_jump_marker(get_module(op), marker); + add_pipeline_jump_marker(op->getParentOfType(), marker); } mlir::LogicalResult numpy_rewrite( diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index d499f9bdd0b..4d63b75569d 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -3,17 +3,6 @@ #include #include -mlir::ModuleOp get_module(mlir::Operation* op) -{ - assert(nullptr != op); - while (!mlir::isa(op)) - { - op = op->getParentOp(); - assert(nullptr != op); - } - return mlir::cast(op); -} - namespace { const constexpr llvm::StringLiteral jump_marker_name("#plier.pipeline_jump_markers"); diff --git a/mlir-compiler/src/transforms/pipeline_utils.hpp b/mlir-compiler/src/transforms/pipeline_utils.hpp index 00fae08f5c1..7b53b6acd8a 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.hpp +++ b/mlir-compiler/src/transforms/pipeline_utils.hpp @@ -3,12 +3,10 @@ namespace mlir { class ArrayAttr; -class Operation; class ModuleOp; class StringAttr; } -mlir::ModuleOp get_module(mlir::Operation* op); mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module); void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); From 57e92b44a6b20eb212bbb9cf64049e38085e30e1 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 04:27:44 +0300 Subject: [PATCH 190/259] make memref conversion funcs private --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 8a34ac1d96f..2bed4e8269a 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -184,8 +184,6 @@ std::string gen_conversion_func_name(mlir::MemRefType memref_type) return ret; } -const constexpr llvm::StringRef linkage_attr = "numba_linkage"; - struct MemRefConversionCache { mlir::FuncOp get_conversion_func( @@ -207,7 +205,7 @@ struct MemRefConversionCache auto func_type = mlir::FunctionType::get(src_type, dst_type, builder.getContext()); auto loc = builder.getUnknownLoc(); auto new_func = mlir::FuncOp::create(loc, func_name, func_type); - new_func.setAttr(linkage_attr, mlir::StringAttr::get("internal", builder.getContext())); + new_func.setPrivate(); module.push_back(new_func); cache.insert({memref_type, new_func}); mlir::OpBuilder::InsertionGuard guard(builder); @@ -244,19 +242,9 @@ struct MemRefConversionCache llvm::DenseMap cache; }; -llvm::StringRef get_linkage(mlir::Operation* op) -{ - assert(nullptr != op); - if (auto attr = op->getAttr(linkage_attr).dyn_cast_or_null()) - { - return attr.getValue(); - } - return {}; -} - void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { - if (get_linkage(func) == "internal") + if (func.isPrivate()) { return; } @@ -420,7 +408,7 @@ class CheckForPlierTypes : if (llvm::any_of(op->getResultTypes(), check_type) || llvm::any_of(op->getOperandTypes(), check_type)) { - op->emitOpError(": not all plier types were translated\n"); + op->emitOpError(": plier types weren't translated\n"); signalPassFailure(); } }); From f44d52b9bb8e7b0dbb847c54a8487192e7e611e3 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 04:39:51 +0300 Subject: [PATCH 191/259] plier.fastmath --- mlir-compiler/src/lowering.cpp | 4 ++++ numba/core/typed_passes.py | 1 + 2 files changed, 5 insertions(+) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 21f7399f9a4..375880bb0d3 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -141,6 +141,10 @@ struct plier_lowerer final auto name = compilation_context["fnname"]().cast(); auto typ = get_func_type(compilation_context["fnargs"], compilation_context["restype"]); func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); + if (compilation_context["fastmath"]().cast()) + { + func.setAttr("#plier.fastmath", mlir::UnitAttr::get(&ctx)); + } lower_func_body(func_ir); mod.push_back(func); return mod; diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index bb05330c780..37013b7806e 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -506,6 +506,7 @@ def run_pass(self, state): ctx['restype'] = lambda: state.return_type ctx['fnname'] = lambda: fn_name ctx['resolve_func'] = self._resolve_func_name + ctx['fastmath'] = lambda: state.targetctx.fastmath import mlir_compiler mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) setattr(state, 'mlir_blob', mod) From 4cdf53dbcea021f9f53484b4a3dfc5218b034e3e Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 04:50:41 +0300 Subject: [PATCH 192/259] refactor attributes --- mlir-compiler/include/plier/dialect.hpp | 5 +++++ mlir-compiler/src/lowering.cpp | 2 +- mlir-compiler/src/transforms/pipeline_utils.cpp | 17 ++++++++--------- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 1d57f496729..7fe8d2e4a40 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -13,6 +13,11 @@ namespace plier { +namespace attributes +{ +const constexpr llvm::StringLiteral fastmath("#plier.fastmath"); +const constexpr llvm::StringLiteral jump_markers("#plier.pipeline_jump_markers"); +} namespace detail { diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 375880bb0d3..68bb27c003b 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -143,7 +143,7 @@ struct plier_lowerer final func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); if (compilation_context["fastmath"]().cast()) { - func.setAttr("#plier.fastmath", mlir::UnitAttr::get(&ctx)); + func.setAttr(plier::attributes::fastmath, mlir::UnitAttr::get(&ctx)); } lower_func_body(func_ir); mod.push_back(func); diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index 4d63b75569d..5a280ffe9b4 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -3,14 +3,11 @@ #include #include -namespace -{ -const constexpr llvm::StringLiteral jump_marker_name("#plier.pipeline_jump_markers"); -} +#include "plier/dialect.hpp" mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) { - return module.getAttrOfType(jump_marker_name); + return module.getAttrOfType(plier::attributes::jump_markers); } void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) @@ -18,8 +15,9 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) assert(name); assert(!name.getValue().empty()); + auto jump_markers = plier::attributes::jump_markers; llvm::SmallVector name_list; - if (auto old_attr = module.getAttrOfType(jump_marker_name)) + if (auto old_attr = module.getAttrOfType(jump_markers)) { name_list.assign(old_attr.begin(), old_attr.end()); } @@ -36,7 +34,7 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) { name_list.insert(it, name); } - module.setAttr(jump_marker_name, mlir::ArrayAttr::get(name_list, module.getContext())); + module.setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); } @@ -45,8 +43,9 @@ void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) assert(name); assert(!name.getValue().empty()); + auto jump_markers = plier::attributes::jump_markers; llvm::SmallVector name_list; - if (auto old_attr = module.getAttrOfType(jump_marker_name)) + if (auto old_attr = module.getAttrOfType(jump_markers)) { name_list.assign(old_attr.begin(), old_attr.end()); } @@ -57,5 +56,5 @@ void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) }); assert(it != name_list.end()); name_list.erase(it); - module.setAttr(jump_marker_name, mlir::ArrayAttr::get(name_list, module.getContext())); + module.setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); } From f2c3b338e92afe00ffaa1b5adcf3cf3e237d2a8c Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 04:57:01 +0300 Subject: [PATCH 193/259] alwaysinline --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 2bed4e8269a..e527ff54c15 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -206,6 +206,8 @@ struct MemRefConversionCache auto loc = builder.getUnknownLoc(); auto new_func = mlir::FuncOp::create(loc, func_name, func_type); new_func.setPrivate(); + auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); + new_func.setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); module.push_back(new_func); cache.insert({memref_type, new_func}); mlir::OpBuilder::InsertionGuard guard(builder); From 4f2c5bcc16ba1c305c7a76eda97bd5c68b7fd2e0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 05:11:50 +0300 Subject: [PATCH 194/259] fastmath func flags --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index e527ff54c15..6a3d5782acb 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -244,12 +244,37 @@ struct MemRefConversionCache llvm::DenseMap cache; }; +mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) +{ + auto add_pair = [&](auto name, auto val) + { + const mlir::Attribute attrs[] = { + mlir::StringAttr::get(name, &ctx), + mlir::StringAttr::get(val, &ctx) + }; + return mlir::ArrayAttr::get(attrs, &ctx); + }; + const mlir::Attribute attrs[] = { + add_pair("denormal-fp-math", "preserve-sign,preserve-sign"), + add_pair("denormal-fp-math-f32", "ieee,ieee"), + add_pair("no-infs-fp-math", "true"), + add_pair("no-nans-fp-math", "true"), + add_pair("no-signed-zeros-fp-math", "true"), + add_pair("unsafe-fp-math", "true"), + }; + return mlir::ArrayAttr::get(attrs, &ctx); +} + void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { if (func.isPrivate()) { return; } + if (func.getAttr(plier::attributes::fastmath)) + { + func.setAttr("passthrough", get_fastmath_attrs(*func.getContext())); + } auto old_type = func.getType(); assert(old_type.getNumResults() <= 1); auto& ctx = *old_type.getContext(); From ded3bc3941128e40fdbfa6d0fc5f8ede1254b3e0 Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 18:15:31 +0300 Subject: [PATCH 195/259] update llvm --- mlir-compiler/llvm-sha.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index a4dda885053..7006f85b09a 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -0caa82e2ac53b2ff475531086dfe648fb2d6158a +969918e177adcfd526da7d8e21e5d76860e09c9e From 55992473c06ec2f695ee735a20c2e737e2cd264f Mon Sep 17 00:00:00 2001 From: Butygin Date: Fri, 27 Nov 2020 18:25:39 +0300 Subject: [PATCH 196/259] split linalg lowering into 2 phases --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 17 +++++++++++++---- mlir-compiler/src/pipelines/plier_to_linalg.hpp | 3 ++- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 0001fd50141..96de5001c4f 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -529,10 +529,13 @@ void LowerLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -void populate_plier_to_linalg_pipeline(mlir::OpPassManager& pm) +void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); +} +void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) +{ pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); pm.addNestedPass(mlir::createLinalgBufferizePass()); @@ -554,11 +557,17 @@ void register_plier_to_linalg_pipeline(PipelineRegistry& registry) registry.register_pipeline([](auto sink) { auto stage = get_high_lowering_stage(); - sink(plier_to_linalg_pipeline_name(), {plier_to_std_pipeline_name()}, {stage.end}, {plier_to_std_pipeline_name()}, &populate_plier_to_linalg_pipeline); + sink(plier_to_linalg_gen_pipeline_name(), {plier_to_std_pipeline_name()}, {plier_to_linalg_opt_pipeline_name()}, {plier_to_std_pipeline_name()}, &populate_plier_to_linalg_gen_pipeline); + sink(plier_to_linalg_opt_pipeline_name(), {plier_to_linalg_gen_pipeline_name()}, {stage.end}, {}, &populate_plier_to_linalg_opt_pipeline); }); } -llvm::StringRef plier_to_linalg_pipeline_name() +llvm::StringRef plier_to_linalg_gen_pipeline_name() +{ + return "plier_to_linalg_gen"; +} + +llvm::StringRef plier_to_linalg_opt_pipeline_name() { - return "plier_to_linalg"; + return "plier_to_linalg_opt"; } diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.hpp b/mlir-compiler/src/pipelines/plier_to_linalg.hpp index a25cd7352ae..f18fa229470 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.hpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.hpp @@ -9,4 +9,5 @@ class StringRef; void register_plier_to_linalg_pipeline(PipelineRegistry& registry); -llvm::StringRef plier_to_linalg_pipeline_name(); +llvm::StringRef plier_to_linalg_gen_pipeline_name(); +llvm::StringRef plier_to_linalg_opt_pipeline_name(); From 011e18a15f99e1aef7cf505ccb54f8aa5330b513 Mon Sep 17 00:00:00 2001 From: Butygin Date: Thu, 3 Dec 2020 14:16:32 +0300 Subject: [PATCH 197/259] update llvm --- mlir-compiler/llvm-sha.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 7006f85b09a..8d7f758355a 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -969918e177adcfd526da7d8e21e5d76860e09c9e +c7dbaec396ef98b8bc6acb7631d2919449986add From e98541142826f02b8865748ed38efafb52c47bfe Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 9 Dec 2020 14:21:04 +0300 Subject: [PATCH 198/259] move MLIR settings to separate file, add soome instructions --- mlir-compiler/readme.md | 18 ++++++++++++++++++ numba/core/lowering.py | 3 ++- numba/core/typed_passes.py | 3 ++- numba/mlir/settings.py | 15 +++++++++++++++ 4 files changed, 37 insertions(+), 2 deletions(-) create mode 100644 mlir-compiler/readme.md create mode 100644 numba/mlir/settings.py diff --git a/mlir-compiler/readme.md b/mlir-compiler/readme.md new file mode 100644 index 00000000000..a84d5fbb4c1 --- /dev/null +++ b/mlir-compiler/readme.md @@ -0,0 +1,18 @@ +# Building MLIR backend + +MLIR backend is not yet integrated into Numba build process + +1. Follow usual numba build instructions (using release llvm) +2. Install pybind11 +3. Build llvm from specific commit required for the backend (numba/mlir-compiler/llvm-sha.txt) +4. Build backend using cmake (numba/mlir-compiler/CMakeLists.txt) using compiled llvm +5. Add dir with compiled backend to PYTHONPATH + +# Running MLIR backend tests + +`python runtests.py numba.mlir.tests` + +# Useful env variables + +`NUMBA_MLIR_ENABLE=1` - enable/diasable MLIR backed (default - 1) +`NUMBA_MLIR_PRINT_IR=1` - dump MLIR IR to stdout before and after each pass (default - 0) diff --git a/numba/core/lowering.py b/numba/core/lowering.py index 3bfaa383eb5..f2f1680d653 100644 --- a/numba/core/lowering.py +++ b/numba/core/lowering.py @@ -11,7 +11,8 @@ from numba.core.funcdesc import default_mangler from numba.core.environment import Environment -_use_mlir = True +import numba.mlir.settings +_use_mlir = numba.mlir.settings.USE_MLIR _VarArgItem = namedtuple("_VarArgItem", ("vararg", "index")) diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 37013b7806e..4dd79c922ee 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -467,6 +467,7 @@ def run_pass(self, state): ) return True +import numba.mlir.settings @register_pass(mutates_CFG=True, analysis_only=False) class MlirBackend(LoweringPass): @@ -500,7 +501,7 @@ def run_pass(self, state): fn_name = fndesc.mangled_name ctx = {} - ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': False} + ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} ctx['typemap'] = lambda op: state.typemap[op.name] ctx['fnargs'] = lambda: state.args ctx['restype'] = lambda: state.return_type diff --git a/numba/mlir/settings.py b/numba/mlir/settings.py new file mode 100644 index 00000000000..ec027eaaaac --- /dev/null +++ b/numba/mlir/settings.py @@ -0,0 +1,15 @@ +from os import environ + +def _readenv(name, ctor, default): + value = environ.get(name) + if value is None: + return default() if callable(default) else default + try: + return ctor(value) + except Exception: + warnings.warn("environ %s defined but failed to parse '%s'" % + (name, value), RuntimeWarning) + return default + +USE_MLIR = _readenv('NUMBA_MLIR_ENABLE', int, 1) +PRINT_IR = _readenv('NUMBA_MLIR_PRINT_IR', int, 0) From 29c638f982bd53f382c9f6cd2106266acb14af38 Mon Sep 17 00:00:00 2001 From: Butygin Date: Wed, 9 Dec 2020 14:25:44 +0300 Subject: [PATCH 199/259] fix --- mlir-compiler/readme.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/readme.md b/mlir-compiler/readme.md index a84d5fbb4c1..df49d6887ce 100644 --- a/mlir-compiler/readme.md +++ b/mlir-compiler/readme.md @@ -14,5 +14,5 @@ MLIR backend is not yet integrated into Numba build process # Useful env variables -`NUMBA_MLIR_ENABLE=1` - enable/diasable MLIR backed (default - 1) -`NUMBA_MLIR_PRINT_IR=1` - dump MLIR IR to stdout before and after each pass (default - 0) +* `NUMBA_MLIR_ENABLE=1` - enable/diasable MLIR backed (default - 1) +* `NUMBA_MLIR_PRINT_IR=1` - dump MLIR IR to stdout before and after each pass (default - 0) From 7cb9155a1f897ba65da4a606ef7ecf2fe2df75bd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 13 Dec 2020 03:55:59 +0300 Subject: [PATCH 200/259] update MLIR (#141) --- mlir-compiler/CMakeLists.txt | 2 ++ mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 7 +++++-- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 137af14be3b..a5ba1b096f8 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -74,6 +74,8 @@ target_link_libraries(${PROJECT_NAME} PRIVATE MLIRStandardToLLVM MLIRLinalgTransforms MLIRSCFToStandard + MLIRTensor + MLIRTensorTransforms ) target_include_directories(${PROJECT_NAME} PRIVATE diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 8d7f758355a..876877dd89f 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -c7dbaec396ef98b8bc6acb7631d2919449986add +d716eab197abec0b9aab4a76cd1a52b248b8c3b1 diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 96de5001c4f..dd7bb308c8f 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include #include @@ -204,7 +206,7 @@ mlir::LogicalResult numpy_rewrite( llvm::makeArrayRef(iterators), body).getResult(0); mlir::Value index = rewriter.create(loc, 0); - mlir::Value res = rewriter.create(loc, val, index); + mlir::Value res = rewriter.create(loc, val, index); rewriter.replaceOp(op, res); return mlir::success(); } @@ -255,7 +257,7 @@ struct GetitemOpLowering : public mlir::OpRewritePattern } else if (is_tensor) { - res = rewriter.create(loc, val, index); + res = rewriter.create(loc, val, index); } else { @@ -538,6 +540,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); + pm.addNestedPass(mlir::createTensorBufferizePass()); pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); From cdda4614e68603bed1711fea5116f2611892f523 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 30 Dec 2020 23:36:15 +0300 Subject: [PATCH 201/259] update llvm (#144) --- mlir-compiler/include/plier/PlierOps.td | 102 +++++++++--------- mlir-compiler/include/plier/dialect.hpp | 1 - mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/compiler.cpp | 3 +- mlir-compiler/src/dialect.cpp | 3 +- mlir-compiler/src/lowering.cpp | 8 +- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 14 +-- .../src/pipelines/plier_to_linalg.cpp | 6 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 24 ++--- .../src/rewrites/type_conversion.cpp | 4 +- .../src/rewrites/type_conversion.hpp | 2 +- .../src/transforms/pipeline_utils.cpp | 2 +- 12 files changed, 83 insertions(+), 88 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index cf9a6a4ff26..1c625a6a0de 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -26,7 +26,7 @@ def ArgOp : Plier_Op<"arg", [NoSideEffect]> { let hasFolder = 1; let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, unsigned index, ::mlir::StringRef name"> + OpBuilderDAG<(ins "unsigned":$index, "::mlir::StringRef":$name)> ]; } @@ -37,44 +37,44 @@ def ConstOp : Plier_Op<"const", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Attribute val"> + OpBuilderDAG<(ins "::mlir::Attribute":$val)> ]; } def GlobalOp : Plier_Op<"global", [NoSideEffect]> { - let arguments = (ins - StrAttr:$name); + let arguments = (ins + StrAttr:$name); - let results = (outs AnyType); + let results = (outs AnyType); - let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::StringRef name"> - ]; + let builders = [ + OpBuilderDAG<(ins "::mlir::StringRef":$name)> + ]; } def BinOp : Plier_Op<"binop", []> { - let arguments = (ins - AnyType:$rhs, - AnyType:$lhs, - StrAttr:$op); + let arguments = (ins + AnyType:$rhs, + AnyType:$lhs, + StrAttr:$op); - let results = (outs AnyType); + let results = (outs AnyType); - let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value rhs, ::mlir::Value lhs, ::mlir::StringRef op"> - ]; + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$rhs, "::mlir::Value":$lhs, "::mlir::StringRef ":$op)> + ]; } def UnaryOp : Plier_Op<"unary", []> { - let arguments = (ins - AnyType:$value, - StrAttr:$op); + let arguments = (ins + AnyType:$value, + StrAttr:$op); - let results = (outs AnyType); + let results = (outs AnyType); - let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, ::mlir::StringRef op"> - ]; + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::StringRef ":$op)> + ]; } def CastOp : Plier_Op<"cast", []> { @@ -86,20 +86,20 @@ def CastOp : Plier_Op<"cast", []> { } def PyCallOp : Plier_Op<"call", []> { - let arguments = (ins - AnyType:$func, - Variadic:$args, - StrAttr:$func_name, - UI32Attr:$kw_start, - ArrayAttr:$kw_names); + let arguments = (ins + AnyType:$func, + Variadic:$args, + StrAttr:$func_name, + UI32Attr:$kw_start, + ArrayAttr:$kw_names); - let results = (outs AnyType); + let results = (outs AnyType); - let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value func, " - "::mlir::StringRef func_name, ::mlir::ValueRange args, " - "::mlir::ArrayRef> kwargs"> - ]; + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$func, "::mlir::StringRef":$func_name, + "::mlir::ValueRange":$args, + "::mlir::ArrayRef>":$kwargs)> + ]; } def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { @@ -109,8 +109,8 @@ def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::ValueRange args"> -]; + OpBuilderDAG<(ins "::mlir::ValueRange":$args)> + ]; } def GetItemOp : Plier_Op<"getitem", []> { @@ -121,9 +121,8 @@ def GetItemOp : Plier_Op<"getitem", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, " - "::mlir::Value index"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index)> + ]; } def StaticGetItemOp : Plier_Op<"static_getitem", []> { @@ -135,9 +134,8 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, " - "::mlir::Value index_var, unsigned index"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index_var, "unsigned":$index)> + ]; } def SetItemOp : Plier_Op<"setitem", []> { @@ -156,8 +154,8 @@ def GetiterOp : Plier_Op<"getiter", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value)> + ]; } def IternextOp : Plier_Op<"iternext", []> { @@ -167,8 +165,8 @@ def IternextOp : Plier_Op<"iternext", []> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value)> + ]; } def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { @@ -178,8 +176,8 @@ def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value)> + ]; } def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { @@ -189,8 +187,8 @@ def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value"> -]; + OpBuilderDAG<(ins "::mlir::Value":$value)> + ]; } def DelOp : Plier_Op<"del", []> { @@ -206,7 +204,7 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilder<"::mlir::OpBuilder &b, ::mlir::OperationState &state, ::mlir::Value value, ::mlir::StringRef name"> + OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::StringRef":$name)> ]; } diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 7fe8d2e4a40..c710995145e 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -3,7 +3,6 @@ #include #include #include -#include #include #include "plier/PlierOpsEnums.h.inc" diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 876877dd89f..40226def5fd 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -d716eab197abec0b9aab4a76cd1a52b248b8c3b1 +5abfeccf10bcbc0d673ece21ddd8d4ac4a0e7594 diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/src/compiler.cpp index eec1b4eafc9..9aa5dea96c1 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/src/compiler.cpp @@ -1,7 +1,6 @@ #include "compiler.hpp" -#include -#include +#include #include #include diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 82b6610b1ad..df208cd1e7c 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -1,7 +1,8 @@ #include "plier/dialect.hpp" #include -#include +#include +#include #include #include diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 68bb27c003b..d7af490d42f 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -7,9 +7,9 @@ #include -#include +#include #include -#include +#include #include #include @@ -491,7 +491,7 @@ struct plier_lowerer final auto c = loadvar(cond); auto tr_block = blocks_map.find(tr.cast())->second; auto fl_block = blocks_map.find(fl.cast())->second; - auto cond_val = builder.create(get_current_loc(), mlir::IntegerType::get(1, &ctx), c); + auto cond_val = builder.create(get_current_loc(), mlir::IntegerType::get(&ctx, 1), c); builder.create(get_current_loc(), cond_val, tr_block, fl_block); } @@ -530,7 +530,7 @@ struct plier_lowerer final { args.push_back(get_obj_type(arg)); } - return mlir::FunctionType::get(args, {ret}, &ctx); + return mlir::FunctionType::get(&ctx, args, {ret}); } mlir::Location get_current_loc() diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 6a3d5782acb..4e5b613bf13 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -96,8 +96,8 @@ mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir:: { assert(type); auto ctx = type.getContext(); - auto i8p = mlir::LLVM::LLVMType::getInt8Ty(ctx).getPointerTo(); - auto i64 = mlir::LLVM::LLVMType::getIntNTy(ctx, 64); + auto i8p = mlir::LLVM::LLVMPointerType::get(mlir::LLVM::LLVMIntegerType::get(ctx, 8)); + auto i64 = mlir::LLVM::LLVMIntegerType::get(ctx, 64); auto data_type = converter.convertType(type.getElementType()).cast(); assert(data_type); auto shape_type = mlir::LLVM::LLVMArrayType::get(i64, static_cast(type.getRank())); @@ -106,7 +106,7 @@ mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir:: i8p, // 1, parent i64, // 2, nitems i64, // 3, itemsize - data_type.getPointerTo(), // 4, data + mlir::LLVM::LLVMPointerType::get(data_type), // 4, data shape_type, // 5, shape shape_type, // 6, strides }; @@ -202,7 +202,7 @@ struct MemRefConversionCache return func; } auto func_name = gen_conversion_func_name(memref_type); - auto func_type = mlir::FunctionType::get(src_type, dst_type, builder.getContext()); + auto func_type = mlir::FunctionType::get(builder.getContext(),src_type, dst_type); auto loc = builder.getUnknownLoc(); auto new_func = mlir::FuncOp::create(loc, func_name, func_type); new_func.setPrivate(); @@ -344,8 +344,8 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { process_arg(arg); } - auto ret_type = mlir::IntegerType::get(32, &ctx); - func.setType(mlir::FunctionType::get(args, ret_type, &ctx)); + auto ret_type = mlir::IntegerType::get(&ctx, 32); + func.setType(mlir::FunctionType::get(&ctx, args, ret_type)); } struct ReturnOpLowering : public mlir::OpRewritePattern @@ -359,7 +359,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern auto insert_ret = [&]() { auto ctx = op.getContext(); - auto ret_type = mlir::IntegerType::get(32, ctx); + auto ret_type = mlir::IntegerType::get(ctx, 32); auto ll_ret_type = mlir::LLVM::LLVMIntegerType::get(ctx, 32); mlir::Value ret = rewriter.create(op.getLoc(), ll_ret_type, mlir::IntegerAttr::get(ret_type, 0)); rewriter.replaceOpWithNewOp(op, ret); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index dd7bb308c8f..71547ef5287 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -169,7 +169,6 @@ mlir::LogicalResult numpy_rewrite( mlir::TypeRange(res_type), mlir::ValueRange(inputs), mlir::ValueRange(), // outputs - mlir::ValueRange(), // init llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), body).getResult(0); @@ -180,7 +179,7 @@ mlir::LogicalResult numpy_rewrite( { auto loc = op.getLoc(); mlir::Value inputs[] = { args[0] }; - auto elem_type = mlir::IntegerType::get(64, op.getContext()); + auto elem_type = mlir::IntegerType::get(op.getContext(), 64); auto res_type = mlir::RankedTensorType::get(1, elem_type); mlir::Value zero = rewriter.create(loc, get_zero(elem_type)); mlir::Value init = rewriter.create(loc, zero); @@ -200,8 +199,7 @@ mlir::LogicalResult numpy_rewrite( loc, mlir::TypeRange(res_type), mlir::ValueRange(inputs), - mlir::ValueRange(), // outputs - mlir::ValueRange(init), + mlir::ValueRange(init), // outputs llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), body).getResult(0); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 58fac3f6a33..a43d57edfc7 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -31,7 +31,7 @@ mlir::Type map_int_type(mlir::MLIRContext& ctx, llvm::StringRef& name) if (name.consume_front("int") && !name.consumeInteger(10, num_bits)) { - return mlir::IntegerType::get(num_bits, &ctx); + return mlir::IntegerType::get(&ctx, num_bits); } return nullptr; } @@ -42,7 +42,7 @@ mlir::Type map_int_literal_type(mlir::MLIRContext& ctx, llvm::StringRef& name) if (name.consume_front("Literal[int](") && !name.consumeInteger(10, dummy) && name.consume_front(")")) { - return mlir::IntegerType::get(64, &ctx); // TODO + return mlir::IntegerType::get(&ctx, 64); // TODO } return nullptr; } @@ -51,7 +51,7 @@ mlir::Type map_bool_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { if (name.consume_front("bool")) { - return mlir::IntegerType::get(1, &ctx); + return mlir::IntegerType::get(&ctx, 1); } return nullptr; } @@ -94,7 +94,7 @@ mlir::Type map_pair_type(mlir::MLIRContext& ctx, llvm::StringRef& name) map_type_helper(ctx, name, second) && name.consume_front(">")) { - return mlir::TupleType::get({first, second}, &ctx); + return mlir::TupleType::get(&ctx, {first, second}); } return nullptr; } @@ -110,7 +110,7 @@ mlir::Type map_unituple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) name.consume_front(")")) { llvm::SmallVector types(count, type); - return mlir::TupleType::get(types, &ctx); + return mlir::TupleType::get(&ctx, types); } return nullptr; } @@ -136,7 +136,7 @@ mlir::Type map_tuple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) types.push_back(type); (void)name.consume_front(", "); } - return mlir::TupleType::get(types, &ctx); + return mlir::TupleType::get(&ctx, types); } mlir::Type map_func_type(mlir::MLIRContext& ctx, llvm::StringRef& name) @@ -145,7 +145,7 @@ mlir::Type map_func_type(mlir::MLIRContext& ctx, llvm::StringRef& name) name.consume_front("") && // TODO unhardcode; name.consume_front(")")) { - return mlir::FunctionType::get({}, {}, &ctx); + return mlir::FunctionType::get(&ctx, {}, {}); } return nullptr; } @@ -476,8 +476,8 @@ template void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) { assert(nullptr != op); - auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(64, op->getContext()), Pred); - mlir::Type new_type = mlir::IntegerType::get(1, op->getContext()); + auto pred_attr = mlir::IntegerAttr::get(mlir::IntegerType::get(op->getContext(), 64), Pred); + mlir::Type new_type = mlir::IntegerType::get(op->getContext(), 1); rewriter.replaceOpWithNewOp(op, new_type, pred_attr, operands[0], operands[1]); } @@ -702,7 +702,7 @@ struct ScfIfRewrite : public mlir::OpRewritePattern mlir::Value cond = op.condition(); if (reverse) { - auto i1 = mlir::IntegerType::get(1, op.getContext()); + auto i1 = mlir::IntegerType::get(op.getContext(), 1); auto one = rewriter.create(loc, mlir::IntegerAttr::get(i1, 1)); cond = rewriter.create(loc, cond, one); } @@ -1091,7 +1091,7 @@ mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef(src_type) .Case([&](auto) { replace_op(do_cast(dst_type, val, rewriter)); }); return mlir::success(success); @@ -1130,7 +1130,7 @@ mlir::LogicalResult lower_math_func( valid_type(args[0].getType())) { auto is_float = ret_type.isa(); - auto func_type = mlir::FunctionType::get(args[0].getType(), ret_type, op.getContext()); + auto func_type = mlir::FunctionType::get(op.getContext(), args[0].getType(), ret_type); auto module = op.getParentOfType(); mlir::FuncOp func; if (is_float) diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 4419a79691e..9810cbafc6f 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -127,8 +127,8 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { - funcOp.setType(mlir::FunctionType::get(result.getConvertedTypes(), newResults, - funcOp.getContext())); + funcOp.setType(mlir::FunctionType::get( + funcOp.getContext(), result.getConvertedTypes(), newResults)); auto res = convertRegionTypes(&funcOp.getBody(), converter, true); assert(mlir::succeeded(res)); }); diff --git a/mlir-compiler/src/rewrites/type_conversion.hpp b/mlir-compiler/src/rewrites/type_conversion.hpp index 63b9b585bb7..638cb48b5dc 100644 --- a/mlir-compiler/src/rewrites/type_conversion.hpp +++ b/mlir-compiler/src/rewrites/type_conversion.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include #include namespace mlir diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index 5a280ffe9b4..19123a07caf 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -1,6 +1,6 @@ #include "transforms/pipeline_utils.hpp" -#include +#include #include #include "plier/dialect.hpp" From 97311c7ebbb71e591bb5da5ec4a32120fee6d852 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 1 Jan 2021 15:12:53 +0300 Subject: [PATCH 202/259] [MLIR] Canonicalize reduction (#145) --- mlir-compiler/CMakeLists.txt | 2 + .../src/pipelines/plier_to_linalg.cpp | 29 +++ mlir-compiler/src/pipelines/plier_to_std.cpp | 1 - .../src/rewrites/canonicalize_reductions.cpp | 196 ++++++++++++++++++ .../src/rewrites/canonicalize_reductions.hpp | 19 ++ 5 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 mlir-compiler/src/rewrites/canonicalize_reductions.cpp create mode 100644 mlir-compiler/src/rewrites/canonicalize_reductions.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index a5ba1b096f8..ea8a01fe1ff 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -25,6 +25,7 @@ set(SOURCES_LIST src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp src/rewrites/call_lowering.cpp + src/rewrites/canonicalize_reductions.cpp src/rewrites/cast_lowering.cpp src/rewrites/type_conversion.cpp src/transforms/loop_utils.cpp @@ -44,6 +45,7 @@ set(HEADERS_LIST src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp src/rewrites/call_lowering.hpp + src/rewrites/canonicalize_reductions.hpp src/rewrites/cast_lowering.hpp src/rewrites/type_conversion.hpp src/transforms/loop_utils.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 71547ef5287..fc8c35bae7d 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -22,6 +22,7 @@ #include "pipelines/plier_to_std.hpp" #include "transforms/pipeline_utils.hpp" #include "rewrites/call_lowering.hpp" +#include "rewrites/canonicalize_reductions.hpp" #include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" @@ -529,6 +530,33 @@ void LowerLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } +struct PostLinalgOptPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void PostLinalgOptPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + patterns.insert< + CanonicalizeReduction + >(&getContext()); + + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); @@ -550,6 +578,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); } } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index a43d57edfc7..817ec6882d3 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -899,7 +899,6 @@ struct ScfWhileRewrite : public mlir::OpRewritePattern auto new_op = builder.clone(op, mapper); for (auto user : op.getUsers()) { - user->isBeforeInBlock(user); if (!is_inside_block(user, before_block)) { for (auto it : llvm::zip(op.getResults(), new_op->getResults())) diff --git a/mlir-compiler/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/src/rewrites/canonicalize_reductions.cpp new file mode 100644 index 00000000000..ffaed84b509 --- /dev/null +++ b/mlir-compiler/src/rewrites/canonicalize_reductions.cpp @@ -0,0 +1,196 @@ +#include "rewrites/canonicalize_reductions.hpp" + +#include +#include +#include + +namespace +{ +bool checkMemrefType(mlir::Value value) +{ + if (auto type = value.getType().dyn_cast()) + { + auto shape = type.getShape(); + return shape.empty() || (1 == shape.size() && 1 == shape[0]); + } + return false; +} + +bool checkForPotentialAliases(mlir::Value value) +{ + auto def_op = value.getDefiningOp(); + if (nullptr == def_op) + { + return false; + } + if (auto effects = mlir::dyn_cast(def_op)) + { + if (!effects.hasEffect()) + { + return false; + } + } + else + { + return false; + } + for (auto user : value.getUsers()) + { + if (mlir::isa(user)) + { + // TODO: very conservative + return false; + } + } + return true; +} + +bool checkSupportedOps(mlir::Value value, mlir::Operation* parent) +{ + for (auto user : value.getUsers()) + { + if (user->getParentOp() == parent && !mlir::isa(user)) + { + return false; + } + } + return true; +} + +bool checkMemref(mlir::Value value, mlir::Operation* parent) +{ + return checkMemrefType(value) && checkForPotentialAliases(value) && + checkSupportedOps(value, parent); +} + +mlir::Value createScalarLoad( + mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value memref) +{ + auto shape = memref.getType().cast().getShape(); + if (shape.empty()) + { + return builder.create(loc, memref); + } + else if (llvm::all_of(shape, [](auto s) { return s == 1; })) + { + auto index = builder.create(loc, 0); + llvm::SmallVector indices(shape.size(), index); + return builder.create(loc, memref, indices); + } + else + { + llvm_unreachable("Invalid shape"); + } +} + +void createScalarStore( + mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value val, + mlir::Value memref) +{ + auto shape = memref.getType().cast().getShape(); + if (shape.empty()) + { + builder.create(loc, val, memref); + } + else if (llvm::all_of(shape, [](auto s) { return s == 1; })) + { + auto index = builder.create(loc, 0); + llvm::SmallVector indices(shape.size(), index); + builder.create(loc, val, memref, indices); + } + else + { + llvm_unreachable("Invalid shape"); + } +} +} + +mlir::LogicalResult CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +{ + llvm::SmallVector to_process; + op.walk([&](mlir::LoadOp load) + { + auto memref = load.memref(); + if (checkMemref(memref, op)) + { + to_process.emplace_back(memref); + } + }); + + if (!to_process.empty()) + { + auto loc = op.getLoc(); + auto init_args = llvm::to_vector<8>(op.initArgs()); + for (auto val : to_process) + { + init_args.emplace_back(createScalarLoad(rewriter, loc, val)); + } + auto prev_args_offset = op.initArgs().size(); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange iter_vals) + { + auto& old_body = op.getLoopBody().front(); + mlir::BlockAndValueMapping mapping; + mapping.map(old_body.getArguments().front(), iter); + mapping.map(old_body.getArguments().drop_front(), iter_vals); + auto yield_args = llvm::to_vector<8>(iter_vals); + for (auto& body_op : old_body.without_terminator()) + { + auto invalid_index = static_cast(-1); + auto get_iter_index = [&](auto op)->unsigned + { + auto arg = op.memref(); + for (auto it : llvm::enumerate(to_process)) + { + if (arg == it.value()) + { + return static_cast(it.index() + prev_args_offset); + } + } + return invalid_index; + }; + if (auto load = mlir::dyn_cast(body_op)) + { + auto index = get_iter_index(load); + if (index != invalid_index) + { + mapping.map(body_op.getResults().front(), yield_args[index]); + } + else + { + builder.clone(body_op, mapping); + } + } + else if (auto store = mlir::dyn_cast(body_op)) + { + auto index = get_iter_index(store); + if (index != invalid_index) + { + yield_args[index] = mapping.lookup(store.value()); + } + else + { + builder.clone(body_op, mapping); + } + } + else + { + builder.clone(body_op, mapping); + } + } + auto yield = mlir::cast(old_body.getTerminator()); + llvm::copy(yield.results(), yield_args.begin()); + builder.create(loc, yield_args); + }; + auto results = rewriter.create(loc, op.lowerBound(), op.upperBound(), op.step(), init_args, body).results(); + for (auto it : llvm::enumerate(to_process)) + { + auto index = prev_args_offset + it.index(); + auto result = results[static_cast(index)]; + createScalarStore(rewriter, loc, result, it.value()); + } + rewriter.replaceOp(op, results.take_front(prev_args_offset)); + return mlir::success(); + } + + return mlir::failure(); +} diff --git a/mlir-compiler/src/rewrites/canonicalize_reductions.hpp b/mlir-compiler/src/rewrites/canonicalize_reductions.hpp new file mode 100644 index 00000000000..1936b067fb3 --- /dev/null +++ b/mlir-compiler/src/rewrites/canonicalize_reductions.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace mlir +{ +namespace scf +{ +class ForOp; +} +} + +struct CanonicalizeReduction : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; +}; From 6655498e81fd32aabdd6b28b84226abfb6ff72c5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 2 Jan 2021 19:36:18 +0300 Subject: [PATCH 203/259] [MLIR] Promote loops to parallel (#146) --- mlir-compiler/CMakeLists.txt | 2 + .../src/pipelines/plier_to_linalg.cpp | 4 +- .../src/rewrites/promote_to_parallel.cpp | 132 ++++++++++++++++++ .../src/rewrites/promote_to_parallel.hpp | 19 +++ 4 files changed, 156 insertions(+), 1 deletion(-) create mode 100644 mlir-compiler/src/rewrites/promote_to_parallel.cpp create mode 100644 mlir-compiler/src/rewrites/promote_to_parallel.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index ea8a01fe1ff..fcb7256f2c3 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -27,6 +27,7 @@ set(SOURCES_LIST src/rewrites/call_lowering.cpp src/rewrites/canonicalize_reductions.cpp src/rewrites/cast_lowering.cpp + src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp src/transforms/loop_utils.cpp src/transforms/pipeline_utils.cpp @@ -47,6 +48,7 @@ set(HEADERS_LIST src/rewrites/call_lowering.hpp src/rewrites/canonicalize_reductions.hpp src/rewrites/cast_lowering.hpp + src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp src/transforms/loop_utils.hpp src/transforms/pipeline_utils.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index fc8c35bae7d..75c2ce673f5 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -24,6 +24,7 @@ #include "rewrites/call_lowering.hpp" #include "rewrites/canonicalize_reductions.hpp" #include "rewrites/cast_lowering.hpp" +#include "rewrites/promote_to_parallel.hpp" #include "rewrites/type_conversion.hpp" #include "base_pipeline.hpp" @@ -550,7 +551,8 @@ void PostLinalgOptPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert< - CanonicalizeReduction + CanonicalizeReduction, + PromoteToParallel >(&getContext()); diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/src/rewrites/promote_to_parallel.cpp new file mode 100644 index 00000000000..1d5cca0b103 --- /dev/null +++ b/mlir-compiler/src/rewrites/promote_to_parallel.cpp @@ -0,0 +1,132 @@ +#include "rewrites/promote_to_parallel.hpp" + +#include +#include +#include + +namespace +{ +bool hasMemWrites(mlir::Operation *op) +{ + return op->walk([&](mlir::Operation *op) + { + if (auto effects = mlir::dyn_cast(op)) + { + if(effects.hasEffect()) + { + return mlir::WalkResult::interrupt(); + } + } + if (mlir::dyn_cast(op)) + { + return mlir::WalkResult::interrupt(); + } + return mlir::WalkResult::advance(); + }).wasInterrupted(); +} +} + +mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +{ + if (hasMemWrites(op)) + { + return mlir::failure(); + } + + auto& old_body = op.getLoopBody().front(); + auto old_yield = mlir::cast(old_body.getTerminator()); + auto reduce_args = old_body.getArguments().drop_front(); + llvm::SmallVector, 8> reduce_bodies(reduce_args.size()); + llvm::DenseSet reduce_ops; + for (auto it : llvm::enumerate(reduce_args)) + { + auto reduce_arg = it.value(); + auto reduce_index = it.index(); + if (!reduce_arg.hasOneUse()) + { + return mlir::failure(); + } + auto reduce_op = *reduce_arg.user_begin(); + if (reduce_op->getNumOperands() != 2) + { + return mlir::failure(); + } + auto& reduce_body = reduce_bodies[reduce_index]; + while (true) + { + if (!reduce_op->hasOneUse()) + { + return mlir::failure(); + } + reduce_body.push_back(reduce_op); + reduce_ops.insert(reduce_op); + auto next_op = *reduce_op->user_begin(); + if (next_op == old_yield) + { + auto yield_operand = old_yield.getOperand(static_cast(reduce_index)); + if (yield_operand != reduce_op->getResult(0)) + { + return mlir::failure(); + } + break; + } + for (auto operand : next_op->getOperands()) + { + if (operand.getDefiningOp() != reduce_op && + operand.getParentBlock() == &old_body) + { + return mlir::failure(); + } + } + reduce_op = next_op; + } + } + + auto body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange iter_vals, mlir::ValueRange temp) + { + assert(1 == iter_vals.size()); + assert(temp.empty()); + mlir::BlockAndValueMapping mapping; + mapping.map(old_body.getArguments().front(), iter_vals.front()); + for (auto& old_op : old_body.without_terminator()) + { + if (0 == reduce_ops.count(&old_op)) + { + builder.clone(old_op, mapping); + } + } + mlir::BlockAndValueMapping reduce_mapping; + for (auto it : llvm::enumerate(reduce_bodies)) + { + auto& reduce_body = it.value(); + assert(!reduce_body.empty()); + reduce_mapping = mapping; + auto first_op = reduce_body.front(); + assert(first_op->getNumOperands() == 2); + auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value val0, mlir::Value val1) + { + reduce_mapping.map(first_op->getOperand(0), val0); + reduce_mapping.map(first_op->getOperand(1), val1); + mlir::Operation* last_op = nullptr; + for (auto reduce_op : reduce_body) + { + last_op = builder.clone(*reduce_op, reduce_mapping); + assert(1 == last_op->getNumResults()); + } + builder.create(loc, last_op->getResult(0)); + }; + auto reduce_arg = reduce_args[it.index()]; + auto first_op_operands = first_op->getOperands(); + auto reduce_operand = (first_op_operands[0] == reduce_arg ? first_op_operands[1] : first_op_operands[0]); + assert(reduce_operand != reduce_arg); + reduce_operand = mapping.lookupOrNull(reduce_operand); + assert(reduce_operand); + builder.create(loc, reduce_operand, reduce_body_builder); + } + }; + + auto parallel_op = rewriter.create(op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), op.initArgs(), body_builder); + rewriter.replaceOp(op, parallel_op.getResults()); + + return mlir::success(); +} diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.hpp b/mlir-compiler/src/rewrites/promote_to_parallel.hpp new file mode 100644 index 00000000000..f246ce0696d --- /dev/null +++ b/mlir-compiler/src/rewrites/promote_to_parallel.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace mlir +{ +namespace scf +{ +class ForOp; +} +} + +struct PromoteToParallel : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; +}; From 7c9e138c5a1f6aa735745bdee566172c577a9e91 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 2 Jan 2021 20:13:57 +0300 Subject: [PATCH 204/259] LoopInvariantCodeMotion (#147) --- .../src/pipelines/plier_to_linalg.cpp | 24 +++++++++++++++++++ .../src/rewrites/promote_to_parallel.cpp | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 75c2ce673f5..9d71483333f 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -546,12 +547,35 @@ struct PostLinalgOptPass : void runOnOperation() override; }; +struct LoopInvariantCodeMotion : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override + { + auto parentOp = op->getParentOp(); + rewriter.startRootUpdate(parentOp); + auto res = mlir::moveLoopInvariantCode(op); + if (mlir::succeeded(res)) + { + rewriter.finalizeRootUpdate(parentOp); + } + else + { + rewriter.cancelRootUpdate(parentOp); + } + return res; + } +}; + void PostLinalgOptPass::runOnOperation() { mlir::OwningRewritePatternList patterns; patterns.insert< CanonicalizeReduction, + LoopInvariantCodeMotion, PromoteToParallel >(&getContext()); diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/src/rewrites/promote_to_parallel.cpp index 1d5cca0b103..fed041350cd 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/src/rewrites/promote_to_parallel.cpp @@ -73,7 +73,7 @@ mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir for (auto operand : next_op->getOperands()) { if (operand.getDefiningOp() != reduce_op && - operand.getParentBlock() == &old_body) + operand.getParentBlock() == &old_body) { return mlir::failure(); } From d6d0b29d2d0fb59e3675692dc1a87fe89b40e7f6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 2 Jan 2021 21:37:52 +0300 Subject: [PATCH 205/259] [MLIR] Prange func (#148) --- mlir-compiler/include/plier/dialect.hpp | 5 +- mlir-compiler/src/dialect.cpp | 15 ++++ mlir-compiler/src/lowering.cpp | 2 +- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 2 +- .../src/pipelines/plier_to_linalg.cpp | 69 +++++++++++++++++-- mlir-compiler/src/pipelines/plier_to_std.cpp | 4 +- .../src/rewrites/promote_to_parallel.cpp | 5 +- mlir-compiler/src/transforms/loop_utils.cpp | 8 ++- mlir-compiler/src/transforms/loop_utils.hpp | 10 ++- .../src/transforms/pipeline_utils.cpp | 6 +- numba/mlir/builtin_funcs.py | 4 ++ numba/mlir/tests/test_basic.py | 10 +++ 12 files changed, 122 insertions(+), 18 deletions(-) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index c710995145e..cd8be69547b 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -14,8 +14,9 @@ namespace plier { namespace attributes { -const constexpr llvm::StringLiteral fastmath("#plier.fastmath"); -const constexpr llvm::StringLiteral jump_markers("#plier.pipeline_jump_markers"); +llvm::StringRef getFastmathName(); +llvm::StringRef getJumpMarkersName(); +llvm::StringRef getParallelName(); } namespace detail diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index df208cd1e7c..bae439cec15 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -11,6 +11,21 @@ namespace plier { +llvm::StringRef attributes::getFastmathName() +{ + return "#plier.fastmath"; +} + +llvm::StringRef attributes::getJumpMarkersName() +{ + return "#plier.pipeline_jump_markers"; +} + +llvm::StringRef attributes::getParallelName() +{ + return "#plier.parallel"; +} + namespace detail { struct PyTypeStorage : public mlir::TypeStorage diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index d7af490d42f..0ba24307d58 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -143,7 +143,7 @@ struct plier_lowerer final func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); if (compilation_context["fastmath"]().cast()) { - func.setAttr(plier::attributes::fastmath, mlir::UnitAttr::get(&ctx)); + func.setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx)); } lower_func_body(func_ir); mod.push_back(func); diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 4e5b613bf13..459abae51e6 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -271,7 +271,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { return; } - if (func.getAttr(plier::attributes::fastmath)) + if (func.getAttr(plier::attributes::getFastmathName())) { func.setAttr("passthrough", get_fastmath_attrs(*func.getContext())); } diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 9d71483333f..f630cc920ec 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -27,6 +27,7 @@ #include "rewrites/cast_lowering.hpp" #include "rewrites/promote_to_parallel.hpp" #include "rewrites/type_conversion.hpp" +#include "transforms/loop_utils.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" @@ -141,13 +142,73 @@ void rerun_std_pipeline(mlir::Operation* op) { assert(nullptr != op); auto marker = mlir::StringAttr::get(plier_to_std_pipeline_name(), op->getContext()); - add_pipeline_jump_marker(op->getParentOfType(), marker); + auto mod = op->getParentOfType(); + assert(nullptr != mod); + add_pipeline_jump_marker(mod, marker); } -mlir::LogicalResult numpy_rewrite( +bool is_int(mlir::Type type) +{ + assert(type); + return type.isa(); +} + +mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +{ + if ((operands.size() < 1 || operands.size() > 3) || + !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) + { + return mlir::failure(); + } + mlir::Value val = op.getResult(); + if (!val.getUsers().empty()) + { + auto user = mlir::dyn_cast(*val.getUsers().begin()); + auto get_bounds = [&](mlir::OpBuilder& builder, mlir::Location loc) + { + auto lower_bound = (operands.size() >= 2 ? operands[0] : builder.create(loc, 0)); + auto upper_bound = (operands.size() >= 2 ? operands[1] : operands[0]); + auto step = (operands.size() == 3 ? operands[2] : builder.create(loc, 1)); + return std::make_tuple(lower_bound, upper_bound, step); + }; + auto get_index = [](mlir::OpBuilder& builder, mlir::Location loc, mlir::Type dst_type, mlir::Value index) + { + return builder.create(loc, dst_type, index); + }; + auto set_attr = [](mlir::scf::ForOp op) + { + op->setAttr(plier::attributes::getParallelName(), mlir::UnitAttr::get(op->getContext())); + }; + if (!user || mlir::failed(lower_while_to_for(user, rewriter, get_bounds, get_index, set_attr))) + { + return mlir::failure(); + } + } + + rerun_std_pipeline(op); + if (val.getUsers().empty()) + { + rewriter.eraseOp(op); + } + return mlir::success(); +} + +mlir::LogicalResult call_rewrite( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + std::pair handlers[] = { + {"numba.prange", lower_prange}, + }; + for (auto& handler : handlers) + { + if (handler.first == name) + { + return handler.second(op, args, rewriter); + } + } + if (name == "numpy.add" && check_numpy_args(args, 2)) { auto loc = op.getLoc(); @@ -493,7 +554,7 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< CallOpLowering - >(type_converter, &getContext(), &numpy_rewrite); + >(type_converter, &getContext(), &call_rewrite); patterns.insert< GetitemOpLowering, @@ -575,7 +636,7 @@ void PostLinalgOptPass::runOnOperation() patterns.insert< CanonicalizeReduction, - LoopInvariantCodeMotion, +// LoopInvariantCodeMotion, TODO PromoteToParallel >(&getContext()); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 817ec6882d3..ce333ba26ca 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1044,7 +1044,7 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef { return mlir::failure(); } - mlir::Value val(op); + mlir::Value val = op.getResult(); if (!val.getUsers().empty()) { auto user = mlir::dyn_cast(*val.getUsers().begin()); @@ -1059,7 +1059,7 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef { return builder.create(loc, dst_type, index); }; - if (!user || mlir::failed(lower_while_to_for(user,rewriter, get_bounds, get_index))) + if (!user || mlir::failed(lower_while_to_for(user, rewriter, get_bounds, get_index))) { return mlir::failure(); } diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/src/rewrites/promote_to_parallel.cpp index fed041350cd..0d7d10f8629 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/src/rewrites/promote_to_parallel.cpp @@ -4,6 +4,8 @@ #include #include +#include "plier/dialect.hpp" + namespace { bool hasMemWrites(mlir::Operation *op) @@ -28,7 +30,8 @@ bool hasMemWrites(mlir::Operation *op) mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { - if (hasMemWrites(op)) + auto has_parallel_attr = op->hasAttr(plier::attributes::getParallelName()); + if (!has_parallel_attr && hasMemWrites(op)) { return mlir::failure(); } diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp index fd08842f233..cc6fa0ea2b6 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -42,7 +42,8 @@ mlir::Value get_last_iter_value( mlir::LogicalResult lower_while_to_for( plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, - llvm::function_ref get_iter_val) + llvm::function_ref get_iter_val, + llvm::function_ref results) { llvm::SmallVector to_process; for (auto user : getiter.getOperation()->getUsers()) @@ -165,6 +166,11 @@ mlir::LogicalResult lower_while_to_for( assert(while_op.getOperation()->getUsers().empty()); builder.eraseOp(while_op); changed = true; + + if (results) + { + results(loop_op); + } } if (getiter.getOperation()->getUsers().empty()) diff --git a/mlir-compiler/src/transforms/loop_utils.hpp b/mlir-compiler/src/transforms/loop_utils.hpp index 58ae9f635b9..389ab881db4 100644 --- a/mlir-compiler/src/transforms/loop_utils.hpp +++ b/mlir-compiler/src/transforms/loop_utils.hpp @@ -10,6 +10,10 @@ class Value; class Location; class OpBuilder; class Type; +namespace scf +{ +class ForOp; +} } namespace plier @@ -17,7 +21,7 @@ namespace plier class GetiterOp; } -mlir::LogicalResult lower_while_to_for( - plier::GetiterOp getiter, mlir::PatternRewriter& builder, +mlir::LogicalResult lower_while_to_for(plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, - llvm::function_ref get_iter_val); + llvm::function_ref get_iter_val, + llvm::function_ref results = nullptr); diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index 19123a07caf..154de789863 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -7,7 +7,7 @@ mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) { - return module.getAttrOfType(plier::attributes::jump_markers); + return module.getAttrOfType(plier::attributes::getJumpMarkersName()); } void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) @@ -15,7 +15,7 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) assert(name); assert(!name.getValue().empty()); - auto jump_markers = plier::attributes::jump_markers; + auto jump_markers = plier::attributes::getJumpMarkersName(); llvm::SmallVector name_list; if (auto old_attr = module.getAttrOfType(jump_markers)) { @@ -43,7 +43,7 @@ void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) assert(name); assert(!name.getValue().empty()); - auto jump_markers = plier::attributes::jump_markers; + auto jump_markers = plier::attributes::getJumpMarkersName(); llvm::SmallVector name_list; if (auto old_attr = module.getAttrOfType(jump_markers)) { diff --git a/numba/mlir/builtin_funcs.py b/numba/mlir/builtin_funcs.py index bac1fabc921..443f07b387a 100644 --- a/numba/mlir/builtin_funcs.py +++ b/numba/mlir/builtin_funcs.py @@ -1,5 +1,9 @@ from numba.mlir.func_registry import add_func +from numba import prange + add_func(range, 'range') add_func(len, 'len') add_func(bool, 'bool') + +add_func(prange, 'numba.prange') diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 7655d630122..628037bf020 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -216,6 +216,16 @@ def py_func(a, b, c): jit_func = njit(py_func) assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + def test_prange1(self): + def py_func(a): + res = 0 + for i in numba.prange(a): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10), jit_func(10)) + if __name__ == '__main__': unittest.main() From 310229e701e56d3db1b6d5ecb4b3cff50dbcdc06 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 6 Jan 2021 20:19:39 +0300 Subject: [PATCH 206/259] [MLIR] tbb parallel backend (#149) --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/include/plier/PlierOps.td | 29 ++ mlir-compiler/include/plier/dialect.hpp | 15 + mlir-compiler/src/dialect.cpp | 46 +++ mlir-compiler/src/lowering.cpp | 4 + mlir-compiler/src/pipelines/lower_to_llvm.cpp | 264 +++++++++++++++++- .../src/pipelines/parallel_to_tbb.cpp | 207 ++++++++++++++ .../src/pipelines/parallel_to_tbb.hpp | 12 + .../src/rewrites/promote_to_parallel.cpp | 10 +- numba/core/typed_passes.py | 4 + numba/mlir/tests/test_basic.py | 12 +- numba/np/ufunc/parallel.py | 1 + numba/np/ufunc/tbbpool.cpp | 36 +++ 13 files changed, 636 insertions(+), 6 deletions(-) create mode 100644 mlir-compiler/src/pipelines/parallel_to_tbb.cpp create mode 100644 mlir-compiler/src/pipelines/parallel_to_tbb.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index fcb7256f2c3..a7657a21b95 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -22,6 +22,7 @@ add_subdirectory(include/plier) set(SOURCES_LIST src/pipelines/base_pipeline.cpp src/pipelines/lower_to_llvm.cpp + src/pipelines/parallel_to_tbb.cpp src/pipelines/plier_to_linalg.cpp src/pipelines/plier_to_std.cpp src/rewrites/call_lowering.cpp @@ -43,6 +44,7 @@ set(HEADERS_LIST include/plier/PlierOps.td src/pipelines/base_pipeline.hpp src/pipelines/lower_to_llvm.hpp + src/pipelines/parallel_to_tbb.hpp src/pipelines/plier_to_linalg.hpp src/pipelines/plier_to_std.hpp src/rewrites/call_lowering.hpp diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index 1c625a6a0de..bf43e332c40 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -2,6 +2,8 @@ #define PLIER_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" +include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" def Plier_Dialect : Dialect { @@ -208,4 +210,31 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { ]; } +def ParallelOp : Plier_Op<"parallel", + [DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"plier::YieldOp">, + RecursiveSideEffects]> { + + let arguments = (ins Index:$lowerBound, + Index:$upperBound, + Index:$step); + let regions = (region SizedRegion<1>:$region); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$lowerBound, "::mlir::Value":$upperBound, "::mlir::Value":$step, + CArg<"::mlir::function_ref", + "nullptr">)> + ]; +} + +def YieldOp : Plier_Op<"yield", [NoSideEffect, ReturnLike, Terminator, + ParentOneOf<["ParallelOp"]>]> { + let arguments = (ins Variadic:$results); + let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>]; + // Override default verifier (defined in SCF_Op), no custom verification + // needed. + let verifier = ?; +} + #endif // PLIER_OPS diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index cd8be69547b..4f46d010bb9 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -3,8 +3,22 @@ #include #include #include +#include +#include #include +namespace plier +{ +// TODO: needed for LoopLikeInterface +using Value = ::mlir::Value; +using Region = ::mlir::Region; +using LogicalResult = ::mlir::LogicalResult; +using Operation = ::mlir::Operation; + +template +using ArrayRef = ::mlir::ArrayRef; +} + #include "plier/PlierOpsEnums.h.inc" #include "plier/PlierOpsDialect.h.inc" #define GET_OP_CLASSES @@ -17,6 +31,7 @@ namespace attributes llvm::StringRef getFastmathName(); llvm::StringRef getJumpMarkersName(); llvm::StringRef getParallelName(); +llvm::StringRef getMaxConcurrencyName(); } namespace detail diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index bae439cec15..dd7d8612cc4 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -6,6 +6,7 @@ #include #include + #include namespace plier @@ -26,6 +27,12 @@ llvm::StringRef attributes::getParallelName() return "#plier.parallel"; } +llvm::StringRef attributes::getMaxConcurrencyName() +{ + return "#plier.max_concurrency"; +} + + namespace detail { struct PyTypeStorage : public mlir::TypeStorage @@ -266,6 +273,45 @@ void GetattrOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, value, name); } +mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef ops) +{ + for (mlir::Operation *op : ops) + { + op->moveBefore(*this); + } + return mlir::success(); +} + +mlir::Region &ParallelOp::getLoopBody() { return region(); } + +bool ParallelOp::isDefinedOutsideOfLoop(mlir::Value value) +{ + return !region().isAncestor(value.getParentRegion()); +} + +void ParallelOp::build( + mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, + mlir::Value lowerBound, mlir::Value upperBound, mlir::Value step, + mlir::function_ref bodyBuilder) { + odsState.addOperands({lowerBound, upperBound, step}); + auto bodyRegion = odsState.addRegion(); + bodyRegion->push_back(new mlir::Block); + auto& bodyBlock = bodyRegion->front(); + bodyBlock.addArgument(odsBuilder.getIndexType()); // lower bound + bodyBlock.addArgument(odsBuilder.getIndexType()); // upper bound + bodyBlock.addArgument(odsBuilder.getIndexType()); // thread index + + if (bodyBuilder) + { + mlir::OpBuilder::InsertionGuard guard(odsBuilder); + odsBuilder.setInsertionPointToStart(&bodyBlock); + bodyBuilder(odsBuilder, odsState.location, bodyBlock.getArgument(0), + bodyBlock.getArgument(1), bodyBlock.getArgument(2)); + ParallelOp::ensureTerminator(*bodyRegion, odsBuilder, odsState.location); + } +} + } #define GET_OP_CLASSES diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 0ba24307d58..519dd69112d 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -23,6 +23,7 @@ #include "utils.hpp" #include "pipelines/base_pipeline.hpp" +#include "pipelines/parallel_to_tbb.hpp" #include "pipelines/plier_to_std.hpp" #include "pipelines/plier_to_linalg.hpp" #include "pipelines/lower_to_llvm.hpp" @@ -145,6 +146,8 @@ struct plier_lowerer final { func.setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx)); } + auto max_concurrency = builder.getI64IntegerAttr(compilation_context["max_concurrency"]().cast()); + mod.setAttr(plier::attributes::getMaxConcurrencyName(), max_concurrency); lower_func_body(func_ir); mod.push_back(func); return mod; @@ -645,6 +648,7 @@ void create_pipeline(PipelineRegistry& registry) register_lower_to_llvm_pipeline(registry); register_plier_to_std_pipeline(registry); register_plier_to_linalg_pipeline(registry); + register_parallel_to_tbb_pipeline(registry); } } diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 459abae51e6..8309aaa2c92 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -1,6 +1,7 @@ #include "pipelines/lower_to_llvm.hpp" #include +#include #include #include #include @@ -356,6 +357,12 @@ struct ReturnOpLowering : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite(mlir::ReturnOp op, mlir::PatternRewriter& rewriter) const { + auto parent = op->getParentOfType(); + if (nullptr == parent || parent.isPrivate()) + { + return mlir::failure(); + } + auto insert_ret = [&]() { auto ctx = op.getContext(); @@ -460,6 +467,258 @@ class LLVMFunctionPass : public mlir::OperationPass mlir::LLVM::LLVMFuncOp getFunction() { return this->getOperation(); } }; +struct LowerParallel : public mlir::OpRewritePattern +{ + LowerParallel(mlir::MLIRContext* context): + OpRewritePattern(context), + converter(context) {} + + mlir::LogicalResult + matchAndRewrite(plier::ParallelOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector context_vars; + llvm::SmallVector context_constants; + llvm::DenseSet context_vars_set; + auto add_context_var = [&](mlir::Value value) + { + if (0 != context_vars_set.count(value)) + { + return; + } + context_vars_set.insert(value); + if (auto op = value.getDefiningOp()) + { + mlir::ConstantOp a; + if (op->hasTrait()) + { + context_constants.emplace_back(op); + return; + } + } + context_vars.emplace_back(value); + }; + + auto is_defined_inside = [&](mlir::Value value) + { + auto& this_region = op.getLoopBody(); + auto op_region = value.getParentRegion(); + assert(nullptr != op_region); + do + { + if (op_region == &this_region) + { + return true; + } + op_region = op_region->getParentRegion(); + } + while (nullptr != op_region); + return false; + }; + + if (op->walk([&](mlir::Operation* inner)->mlir::WalkResult + { + if (op != inner) + { + for (auto arg : inner->getOperands()) + { + if (!is_defined_inside(arg)) + { + add_context_var(arg); + } + } + } + return mlir::WalkResult::advance(); + }).wasInterrupted()) + { + return mlir::failure(); + } + + auto context_type = [&]()->mlir::LLVM::LLVMStructType + { + llvm::SmallVector fields; + fields.reserve(context_vars.size()); + for (auto var : context_vars) + { + auto type = converter.convertType(var.getType()); + if (!type) + { + return {}; + } + fields.emplace_back(type.cast()); + } + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), fields); + }(); + + if (!context_type) + { + return mlir::failure(); + } + auto context_ptr_type = mlir::LLVM::LLVMPointerType::get(context_type); + + auto loc = op.getLoc(); + auto llvm_i32_type = mlir::LLVM::LLVMIntegerType::get(op.getContext(), 32); + auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); + auto one = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(1)); + auto context = rewriter.create(loc, context_ptr_type, one, 0); + for (auto it : llvm::enumerate(context_vars)) + { + auto type = context_type.getBody()[it.index()]; + auto llvm_val = rewriter.create(loc, type, it.value()); + auto i = rewriter.getI32IntegerAttr(static_cast(it.index())); + mlir::Value indices[] = { + zero, + rewriter.create(loc, llvm_i32_type, i) + }; + auto pointer_type = mlir::LLVM::LLVMPointerType::get(type); + auto ptr = rewriter.create(loc, pointer_type, context, indices); + rewriter.create(loc, llvm_val, ptr); + } + auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::LLVM::LLVMIntegerType::get(op.getContext(), 8)); + auto context_abstract = rewriter.create(loc, void_ptr_type, context); + + auto index_type = rewriter.getIndexType(); + auto func_type = [&]() + { + mlir::Type args[] = { + index_type, // lower_bound + index_type, // upper_bound + index_type, // thread index + void_ptr_type // context + }; + return mlir::FunctionType::get(op.getContext(), args, {}); + }(); + + auto mod = op.getParentOfType(); + auto outlined_func = [&]()->mlir::FuncOp + { + auto func = [&]() + { + auto func_name = [&]() + { + auto old_name = op.getParentOfType().getName(); + for (int i = 0;;++i) + { + auto name = (0 == i ? + (llvm::Twine(old_name) + "_outlined").str() : + (llvm::Twine(old_name) + "_outlined_" + llvm::Twine(i)).str()); + if (!mod.lookupSymbol(name)) + { + return name; + } + } + }(); + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(mod.getBody(), + std::prev(mod.getBody()->end())); + auto func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); + func.setPrivate(); + return func; + }(); + mlir::BlockAndValueMapping mapping; + auto& old_entry = op.getLoopBody().front(); + auto entry = func.addEntryBlock(); + auto loc = rewriter.getUnknownLoc(); + mlir::OpBuilder::InsertionGuard guard(rewriter); + mapping.map(old_entry.getArgument(0), entry->getArgument(0)); + mapping.map(old_entry.getArgument(1), entry->getArgument(1)); + mapping.map(old_entry.getArgument(2), entry->getArgument(2)); + rewriter.setInsertionPointToStart(entry); + for (auto arg : context_constants) + { + rewriter.clone(*arg, mapping); + } + auto context_ptr = rewriter.create(loc, context_ptr_type, entry->getArgument(3)); + auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); + for (auto it : llvm::enumerate(context_vars)) + { + auto index = it.index(); + auto old_val = it.value(); + mlir::Value indices[] = { + zero, + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(index))) + }; + auto pointer_type = mlir::LLVM::LLVMPointerType::get(context_type.getBody()[index]); + auto ptr = rewriter.create(loc, pointer_type, context_ptr, indices); + auto llvm_val = rewriter.create(loc, ptr); + auto val = rewriter.create(loc, old_val.getType(), llvm_val); + mapping.map(old_val, val); + } + op.getLoopBody().cloneInto(&func.getBody(), mapping); + auto& orig_entry = *std::next(func.getBody().begin()); + rewriter.create(loc, &orig_entry); + for (auto& block : func.getBody()) + { + if (auto term = mlir::dyn_cast(block.getTerminator())) + { + rewriter.eraseOp(term); + rewriter.setInsertionPointToEnd(&block); + rewriter.create(loc); + } + } + return func; + }(); + + auto parallel_for = [&]() + { + auto func_name = "numba_parallel_for2"; + if (auto sym = mod.lookupSymbol(func_name)) + { + return sym; + } + mlir::Type args[] = { + index_type, // lower bound + index_type, // upper bound + index_type, // step + func_type, + void_ptr_type + }; + auto func_type = mlir::FunctionType::get(op.getContext(), args, {}); + mlir::OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(mod.getBody(), + std::prev(mod.getBody()->end())); + auto func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); + func.setPrivate(); + return func; + }(); + auto func_addr = rewriter.create(loc, func_type, rewriter.getSymbolRefAttr(outlined_func)); + mlir::Value pf_args[] = { + op.lowerBound(), + op.upperBound(), + op.step(), + func_addr, + context_abstract + }; + rewriter.create(loc, parallel_for, pf_args); + rewriter.eraseOp(op); + return mlir::success(); + } + +private: + mutable mlir::LLVMTypeConverter converter; // TODO +}; + +struct LowerParallelToCFGPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override final + { + mlir::OwningRewritePatternList patterns; + patterns.insert(&getContext()); + + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; + struct PreLLVMLowering : public mlir::PassWrapper { virtual void getDependentDialects( @@ -480,7 +739,7 @@ struct PreLLVMLowering : public mlir::PassWrapper(&getContext(), type_helper.get_type_converter()); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -569,8 +828,9 @@ struct LLVMLoweringPass : public mlir::PassWrapper()); pm.addPass(mlir::createLowerToCFGPass()); - pm.addPass(std::make_unique()); +// pm.addPass(std::make_unique()); pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); pm.addNestedPass(std::make_unique()); diff --git a/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/src/pipelines/parallel_to_tbb.cpp new file mode 100644 index 00000000000..d7a8eaf31b6 --- /dev/null +++ b/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -0,0 +1,207 @@ +#include "pipelines/parallel_to_tbb.hpp" + +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" + +#include "pipeline_registry.hpp" +#include "pipelines/base_pipeline.hpp" +#include "pipelines/lower_to_llvm.hpp" + +namespace +{ +mlir::MemRefType getReduceType(mlir::Type type, int64_t count) +{ + if (type.isIntOrFloat()) + { + return mlir::MemRefType::get(count, type); + } + return {}; +} + +mlir::Value getZeroVal(mlir::OpBuilder& builder, mlir::Location loc, mlir::Type type) +{ + if (type.isa()) + { + return builder.create(loc, 0, type.cast()); + } + if (type.isa()) + { + return builder.create(loc, llvm::APFloat(0.0), type.cast()); + } + llvm_unreachable("Unhandled type"); +} + +struct ParallelToTbb : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ParallelOp op, mlir::PatternRewriter &rewriter) const override + { + if (mlir::isa(op->getParentOp())) + { + return mlir::failure(); + } + if (op.getNumLoops() != 1) + { + return mlir::failure(); + } + + int64_t max_concurrency = 0; + auto mod = op.getParentOfType(); + if (auto mc = mod.getAttrOfType(plier::attributes::getMaxConcurrencyName())) + { + max_concurrency = mc.getInt(); + } + + if (max_concurrency <= 1) + { + return mlir::failure(); + } + for (auto type : op.getResultTypes()) + { + if (!getReduceType(type, max_concurrency)) + { + return mlir::failure(); + } + } + + auto loc = op.getLoc(); + mlir::BlockAndValueMapping mapping; + llvm::SmallVector reduce_vars(op.getNumResults()); + for (auto it : llvm::enumerate(op.getResultTypes())) + { + auto type = it.value(); + auto reduce_type = getReduceType(type, max_concurrency); + assert(reduce_type); + auto reduce = rewriter.create(loc, reduce_type); + auto index = static_cast(it.index()); + reduce_vars[index] = reduce; + auto zero = getZeroVal(rewriter, loc, type); + mapping.map(op.initVals()[index], zero); + for (unsigned i = 0; i < max_concurrency; ++i) + { + mlir::Value index = rewriter.create(loc, i); + rewriter.create(loc, zero, reduce, index); + } + } + + auto& old_body = op.getLoopBody().front(); + auto orig_lower_bound = op.lowerBound().front(); + auto orig_upper_bound = op.upperBound().front(); + auto orig_step = op.step().front(); + auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index) + { + mapping.map(orig_lower_bound, lower_bound); + mapping.map(orig_upper_bound, upper_bound); + for (auto it : llvm::enumerate(op.initVals())) + { + auto reduce_var = reduce_vars[it.index()]; + auto val = builder.create(loc, reduce_var, thread_index); + mapping.map(it.value(), val); + } + auto new_op = builder.clone(*op, mapping); + assert(new_op->getNumResults() == reduce_vars.size()); + for (auto it : llvm::enumerate(new_op->getResults())) + { + auto reduce_var = reduce_vars[it.index()]; + builder.create(loc, it.value(), reduce_var, thread_index); + } + }; + + rewriter.create(loc, orig_lower_bound, orig_upper_bound, orig_step, body_builder); + + auto reduce_lower_bound = rewriter.create(loc, 0); + auto reduce_upper_bound = rewriter.create(loc, max_concurrency); + auto reduce_step = rewriter.create(loc, 1); + + auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args) + { + assert(args.size() == reduce_vars.size()); + mapping.clear(); + auto reduce_ops = llvm::make_filter_range(old_body.without_terminator(), [](auto& op) + { + return mlir::isa(op); + }); + llvm::SmallVector yield_args; + yield_args.reserve(args.size()); + for (auto it : llvm::enumerate(reduce_ops)) + { + auto& reduce_var = reduce_vars[it.index()]; + auto arg = args[static_cast(it.index())]; + auto reduce_op = mlir::cast(it.value()); + auto& reduce_op_body = reduce_op.reductionOperator().front(); + assert(reduce_op_body.getNumArguments() == 2); + auto prev_val = builder.create(loc, reduce_var, index); + mapping.map(reduce_op_body.getArgument(0), arg); + mapping.map(reduce_op_body.getArgument(1), prev_val); + for (auto& old_reduce_op : reduce_op_body.without_terminator()) + { + builder.clone(old_reduce_op, mapping); + } + auto result = mlir::cast(reduce_op_body.getTerminator()).result(); + result = mapping.lookupOrNull(result); + assert(result); + yield_args.emplace_back(result); + } + builder.create(loc, yield_args); + }; + + auto reduce_loop = rewriter.create(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, op.initVals(), reduce_body_builder); + rewriter.replaceOp(op, reduce_loop.getResults()); + + return mlir::success(); + } +}; + +struct ParallelToTbbPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +void ParallelToTbbPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + patterns.insert< + ParallelToTbb + >(&getContext()); + + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +void populate_parallel_to_tbb_pipeline(mlir::OpPassManager& pm) +{ + pm.addNestedPass(std::make_unique()); +} +} + +void register_parallel_to_tbb_pipeline(PipelineRegistry& registry) +{ + registry.register_pipeline([](auto sink) + { + auto stage = get_lower_lowering_stage(); + auto llvm_pipeline = lower_to_llvm_pipeline_name(); + sink(parallel_to_tbb_pipeline_name(), {stage.begin}, {llvm_pipeline}, {}, &populate_parallel_to_tbb_pipeline); + }); +} + +llvm::StringRef parallel_to_tbb_pipeline_name() +{ + return "parallel_to_tbb"; +} diff --git a/mlir-compiler/src/pipelines/parallel_to_tbb.hpp b/mlir-compiler/src/pipelines/parallel_to_tbb.hpp new file mode 100644 index 00000000000..e19f9fde47b --- /dev/null +++ b/mlir-compiler/src/pipelines/parallel_to_tbb.hpp @@ -0,0 +1,12 @@ +#pragma once + +class PipelineRegistry; + +namespace llvm +{ +class StringRef; +} + +void register_parallel_to_tbb_pipeline(PipelineRegistry& registry); + +llvm::StringRef parallel_to_tbb_pipeline_name(); diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/src/rewrites/promote_to_parallel.cpp index 0d7d10f8629..c69c2b009d9 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/src/rewrites/promote_to_parallel.cpp @@ -8,7 +8,7 @@ namespace { -bool hasMemWrites(mlir::Operation *op) +bool hasSideEffects(mlir::Operation *op) { return op->walk([&](mlir::Operation *op) { @@ -19,7 +19,11 @@ bool hasMemWrites(mlir::Operation *op) return mlir::WalkResult::interrupt(); } } - if (mlir::dyn_cast(op)) +// if (op->hasTrait()) +// { +// return mlir::WalkResult::interrupt(); +// } + if (mlir::isa(op)) { return mlir::WalkResult::interrupt(); } @@ -31,7 +35,7 @@ bool hasMemWrites(mlir::Operation *op) mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { auto has_parallel_attr = op->hasAttr(plier::attributes::getParallelName()); - if (!has_parallel_attr && hasMemWrites(op)) + if (!has_parallel_attr && hasSideEffects(op)) { return mlir::failure(); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 4dd79c922ee..53ec79236f2 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -500,6 +500,7 @@ def run_pass(self, state): noalias=flags.noalias) fn_name = fndesc.mangled_name + from numba.np.ufunc.parallel import get_thread_count ctx = {} ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} ctx['typemap'] = lambda op: state.typemap[op.name] @@ -508,9 +509,12 @@ def run_pass(self, state): ctx['fnname'] = lambda: fn_name ctx['resolve_func'] = self._resolve_func_name ctx['fastmath'] = lambda: state.targetctx.fastmath + ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 import mlir_compiler mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) setattr(state, 'mlir_blob', mod) + _reload_parfors() + state.reload_init.append(_reload_parfors) return True def _resolve_func_name(self, obj): diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 628037bf020..c963373e16b 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -223,9 +223,19 @@ def py_func(a): res = res + i return res - jit_func = njit(py_func) + jit_func = njit(py_func, parallel=True) assert_equal(py_func(10), jit_func(10)) + def test_prange2(self): + def py_func(a, b): + res = 0 + for i in numba.prange(a, b): + res = res + i + return res + + jit_func = njit(py_func, parallel=True) + assert_equal(py_func(10, 20), jit_func(10, 20)) + if __name__ == '__main__': unittest.main() diff --git a/numba/np/ufunc/parallel.py b/numba/np/ufunc/parallel.py index 1b4a3b53ddf..fbf250c6585 100644 --- a/numba/np/ufunc/parallel.py +++ b/numba/np/ufunc/parallel.py @@ -497,6 +497,7 @@ def raise_with_hint(required): raise_with_hint(requirements) ll.add_symbol('numba_parallel_for', lib.parallel_for) + ll.add_symbol('numba_parallel_for2', lib.parallel_for2) ll.add_symbol('do_scheduling_signed', lib.do_scheduling_signed) ll.add_symbol('do_scheduling_unsigned', lib.do_scheduling_unsigned) diff --git a/numba/np/ufunc/tbbpool.cpp b/numba/np/ufunc/tbbpool.cpp index faff4790fb7..d5ade473509 100644 --- a/numba/np/ufunc/tbbpool.cpp +++ b/numba/np/ufunc/tbbpool.cpp @@ -15,10 +15,14 @@ Implement parallel vectorize workqueue on top of Intel TBB. #include #include #include +#include #include "workqueue.h" #include "gufunc_scheduler.h" +#undef min +#undef max + /* TBB 2019 U5 is the minimum required version as this is needed: * https://github.com/intel/tbb/blob/18070344d755ece04d169e6cc40775cae9288cee/CHANGES#L133-L134 * and therefore @@ -202,6 +206,36 @@ parallel_for(void *fn, char **args, size_t *dimensions, size_t *steps, void *dat }); } +using parallel_for2_fptr = void(*)(size_t, size_t, size_t, void*); +static void parallel_for2(size_t lower_bound, size_t upper_bound, size_t step, parallel_for2_fptr func, void* ctx) +{ + auto num_threads = get_num_threads(); + if(_DEBUG) + { + printf("parallel_for2 %d %d %d %d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)num_threads); + } + tbb::task_arena limited(num_threads); + fix_tls_observer observer(limited, num_threads); + + limited.execute([&] + { + size_t count = (upper_bound - lower_bound - 1) / step + 1; + size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64))); + tbb::parallel_for(tbb::blocked_range(0, count, grain), + [&](const tbb::blocked_range& r) + { + auto thread_index = static_cast(tbb::this_task_arena::current_thread_index()); + auto begin = lower_bound + r.begin() * step; + auto end = lower_bound + r.end() * step; + if(_DEBUG) + { + printf("parallel_for2 body %d %d %d\n", (int)begin, (int)end, (int)thread_index); + } + func(begin, end, thread_index, ctx); + }, tbb::auto_partitioner()); + }); +} + void ignore_blocking_terminate_assertion( const char*, int, const char*, const char * ) { tbb::internal::runtime_warning("Unable to wait for threads to shut down before fork(). It can break multithreading in child process\n"); @@ -307,6 +341,8 @@ MOD_INIT(tbbpool) PyLong_FromVoidPtr((void*)&add_task)); PyObject_SetAttrString(m, "parallel_for", PyLong_FromVoidPtr((void*)¶llel_for)); + PyObject_SetAttrString(m, "parallel_for2", + PyLong_FromVoidPtr((void*)¶llel_for2)); PyObject_SetAttrString(m, "do_scheduling_signed", PyLong_FromVoidPtr((void*)&do_scheduling_signed)); PyObject_SetAttrString(m, "do_scheduling_unsigned", From 93256d79117eb3b5fe396529e3e1d279f74020ab Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 6 Jan 2021 21:03:15 +0300 Subject: [PATCH 207/259] some fixes to while-to-for lowering (#150) --- mlir-compiler/src/rewrites/promote_to_parallel.cpp | 2 +- mlir-compiler/src/transforms/loop_utils.cpp | 6 +++--- numba/mlir/tests/test_basic.py | 14 +++++++++++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/src/rewrites/promote_to_parallel.cpp index c69c2b009d9..09cef94f295 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/src/rewrites/promote_to_parallel.cpp @@ -126,7 +126,7 @@ mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir auto first_op_operands = first_op->getOperands(); auto reduce_operand = (first_op_operands[0] == reduce_arg ? first_op_operands[1] : first_op_operands[0]); assert(reduce_operand != reduce_arg); - reduce_operand = mapping.lookupOrNull(reduce_operand); + reduce_operand = mapping.lookupOrDefault(reduce_operand); assert(reduce_operand); builder.create(loc, reduce_operand, reduce_body_builder); } diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp index cc6fa0ea2b6..ebae0c332e3 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -73,7 +73,7 @@ mlir::LogicalResult lower_while_to_for( } return op; }; - if (!iternext || !pairfirst || !pairsecond || !before_term || + if (!iternext || !pairsecond || !before_term || skip_casts(before_term.condition()) != pairsecond) { continue; @@ -91,7 +91,7 @@ mlir::LogicalResult lower_while_to_for( { auto block_arg = std::get<0>(it); auto term_arg = std::get<1>(it); - if (term_arg == pairfirst) // iter arg + if (pairfirst && term_arg == pairfirst) // iter arg { auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); mapper.map(block_arg, iter_val); @@ -153,7 +153,7 @@ mlir::LogicalResult lower_while_to_for( break; } } - if (operand == pairfirst && !old_res.getUsers().empty()) + if (pairfirst && operand == pairfirst && !old_res.getUsers().empty()) { auto val = get_last_iter_value(builder, loc, lower_bound, upper_bound, step); auto new_res = builder.create(loc, old_res.getType(), val); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index c963373e16b..79a7b5df9ee 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -204,7 +204,19 @@ def py_func(n): jit_func = njit(py_func) assert_equal(py_func(10), jit_func(10)) - def test_range_nested(self): + def test_range_nested1(self): + def py_func(a, b, c): + res = 0 + for i in range(a): + for j in range(b): + for k in range(c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + + def test_range_nested2(self): def py_func(a, b, c): res = 0 for i in range(a): From fe438e8ab873037a8e46c7b2b1baa208ec5ab422 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 7 Jan 2021 14:53:34 +0300 Subject: [PATCH 208/259] [MLIR] Some fixes (#151) --- mlir-compiler/CMakeLists.txt | 2 + mlir-compiler/src/lowering.cpp | 7 +++- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 42 +++++++++++-------- mlir-compiler/src/pipelines/plier_to_std.cpp | 9 +--- mlir-compiler/src/transforms/func_utils.cpp | 18 ++++++++ mlir-compiler/src/transforms/func_utils.hpp | 17 ++++++++ 6 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 mlir-compiler/src/transforms/func_utils.cpp create mode 100644 mlir-compiler/src/transforms/func_utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index a7657a21b95..13449da466c 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -30,6 +30,7 @@ set(SOURCES_LIST src/rewrites/cast_lowering.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp + src/transforms/func_utils.cpp src/transforms/loop_utils.cpp src/transforms/pipeline_utils.cpp src/compiler.cpp @@ -52,6 +53,7 @@ set(HEADERS_LIST src/rewrites/cast_lowering.hpp src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp + src/transforms/func_utils.hpp src/transforms/loop_utils.hpp src/transforms/pipeline_utils.hpp src/compiler.hpp diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 519dd69112d..71585414b62 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -146,8 +146,11 @@ struct plier_lowerer final { func.setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx)); } - auto max_concurrency = builder.getI64IntegerAttr(compilation_context["max_concurrency"]().cast()); - mod.setAttr(plier::attributes::getMaxConcurrencyName(), max_concurrency); + auto max_concurrency = compilation_context["max_concurrency"]().cast(); + if (max_concurrency > 0) + { + mod.setAttr(plier::attributes::getMaxConcurrencyName(), builder.getI64IntegerAttr(max_concurrency)); + } lower_func_body(func_ir); mod.push_back(func); return mod; diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 8309aaa2c92..fdea76031fa 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -22,6 +22,8 @@ #include "plier/dialect.hpp" +#include "transforms/func_utils.hpp" + #include "base_pipeline.hpp" #include "pipeline_registry.hpp" @@ -205,11 +207,9 @@ struct MemRefConversionCache auto func_name = gen_conversion_func_name(memref_type); auto func_type = mlir::FunctionType::get(builder.getContext(),src_type, dst_type); auto loc = builder.getUnknownLoc(); - auto new_func = mlir::FuncOp::create(loc, func_name, func_type); - new_func.setPrivate(); + auto new_func = add_function(builder, module, func_name, func_type); auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); new_func.setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); - module.push_back(new_func); cache.insert({memref_type, new_func}); mlir::OpBuilder::InsertionGuard guard(builder); auto block = new_func.addEntryBlock(); @@ -467,6 +467,22 @@ class LLVMFunctionPass : public mlir::OperationPass mlir::LLVM::LLVMFuncOp getFunction() { return this->getOperation(); } }; +void copyAttrs(mlir::Operation* src, mlir::Operation* dst) +{ + const mlir::StringRef attrs[] = { + plier::attributes::getFastmathName(), + plier::attributes::getParallelName(), + plier::attributes::getMaxConcurrencyName(), + }; + for (auto name : attrs) + { + if (auto attr = src->getAttr(name)) + { + dst->setAttr(name, attr); + } + } +} + struct LowerParallel : public mlir::OpRewritePattern { LowerParallel(mlir::MLIRContext* context): @@ -593,9 +609,11 @@ struct LowerParallel : public mlir::OpRewritePattern { auto func = [&]() { + auto parent_func = op.getParentOfType(); + assert(parent_func); auto func_name = [&]() { - auto old_name = op.getParentOfType().getName(); + auto old_name = parent_func.getName(); for (int i = 0;;++i) { auto name = (0 == i ? @@ -608,12 +626,8 @@ struct LowerParallel : public mlir::OpRewritePattern } }(); - mlir::OpBuilder::InsertionGuard guard(rewriter); - // Insert before module terminator. - rewriter.setInsertionPoint(mod.getBody(), - std::prev(mod.getBody()->end())); - auto func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); - func.setPrivate(); + auto func = add_function(rewriter, mod, func_name, func_type); + copyAttrs(parent_func, func); return func; }(); mlir::BlockAndValueMapping mapping; @@ -675,13 +689,7 @@ struct LowerParallel : public mlir::OpRewritePattern void_ptr_type }; auto func_type = mlir::FunctionType::get(op.getContext(), args, {}); - mlir::OpBuilder::InsertionGuard guard(rewriter); - // Insert before module terminator. - rewriter.setInsertionPoint(mod.getBody(), - std::prev(mod.getBody()->end())); - auto func = rewriter.create(rewriter.getUnknownLoc(), func_name, func_type); - func.setPrivate(); - return func; + return add_function(rewriter, mod, func_name, func_type); }(); auto func_addr = rewriter.create(loc, func_type, rewriter.getSymbolRefAttr(outlined_func)); mlir::Value pf_args[] = { diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index ce333ba26ca..11c364b9700 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -18,6 +18,7 @@ #include "rewrites/call_lowering.hpp" #include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" +#include "transforms/func_utils.hpp" #include "transforms/loop_utils.hpp" #include "base_pipeline.hpp" @@ -1107,13 +1108,7 @@ mlir::FuncOp get_lib_symbol( return op; } - mlir::OpBuilder::InsertionGuard guard(rewriter); - // Insert before module terminator. - rewriter.setInsertionPoint(mod.getBody(), - std::prev(mod.getBody()->end())); - auto func = rewriter.create(rewriter.getUnknownLoc(), name, type); - func.setPrivate(); - return func; + return add_function(rewriter, mod, name, type); } mlir::LogicalResult lower_math_func( diff --git a/mlir-compiler/src/transforms/func_utils.cpp b/mlir-compiler/src/transforms/func_utils.cpp new file mode 100644 index 00000000000..cd7205dde7b --- /dev/null +++ b/mlir-compiler/src/transforms/func_utils.cpp @@ -0,0 +1,18 @@ +#include "transforms/func_utils.hpp" + +#include +#include + +#include + +mlir::FuncOp add_function(mlir::OpBuilder& builder, mlir::ModuleOp module, + llvm::StringRef name, mlir::FunctionType type) +{ + mlir::OpBuilder::InsertionGuard guard(builder); + // Insert before module terminator. + builder.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + auto func = builder.create(builder.getUnknownLoc(), name, type); + func.setPrivate(); + return func; +} diff --git a/mlir-compiler/src/transforms/func_utils.hpp b/mlir-compiler/src/transforms/func_utils.hpp new file mode 100644 index 00000000000..242696e1115 --- /dev/null +++ b/mlir-compiler/src/transforms/func_utils.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace mlir +{ +class ModuleOp; +class FuncOp; +class OpBuilder; +class FunctionType; +} + +namespace llvm +{ +class StringRef; +} + +mlir::FuncOp add_function(mlir::OpBuilder& builder, mlir::ModuleOp module, + llvm::StringRef name, mlir::FunctionType type); From af6b0daa8907a7eb72bffd70cc7a24a62bf587c7 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 7 Jan 2021 18:48:31 +0300 Subject: [PATCH 209/259] update to llvm master (#152) --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 38 +++++++++---------- mlir-compiler/src/pipelines/plier_to_std.cpp | 2 +- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 40226def5fd..2b0a3860e74 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -5abfeccf10bcbc0d673ece21ddd8d4ac4a0e7594 +c1d58c2b0023cd41f0da128f5190fa887d8f6c69 diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index fdea76031fa..8f045e4d3b6 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -58,15 +58,15 @@ struct LLVMTypeHelper LLVMTypeHelper(mlir::MLIRContext& ctx): type_converter(&ctx) {} - mlir::LLVM::LLVMType i(unsigned bits) + mlir::Type i(unsigned bits) { return mlir::LLVM::LLVMIntegerType::get(&type_converter.getContext(), bits); } - mlir::LLVM::LLVMType ptr(mlir::Type type) + mlir::Type ptr(mlir::Type type) { assert(static_cast(type)); - auto ll_type = type_converter.convertType(type).cast(); + auto ll_type = type_converter.convertType(type); assert(static_cast(ll_type)); return mlir::LLVM::LLVMPointerType::get(ll_type); } @@ -87,7 +87,7 @@ struct LLVMTypeHelper mlir::Type getExceptInfoType(LLVMTypeHelper& type_helper) { - mlir::LLVM::LLVMType elems[] = { + mlir::Type elems[] = { type_helper.ptr(type_helper.i(8)), type_helper.i(32), type_helper.ptr(type_helper.i(8)), @@ -101,10 +101,10 @@ mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir:: auto ctx = type.getContext(); auto i8p = mlir::LLVM::LLVMPointerType::get(mlir::LLVM::LLVMIntegerType::get(ctx, 8)); auto i64 = mlir::LLVM::LLVMIntegerType::get(ctx, 64); - auto data_type = converter.convertType(type.getElementType()).cast(); + auto data_type = converter.convertType(type.getElementType()); assert(data_type); auto shape_type = mlir::LLVM::LLVMArrayType::get(i64, static_cast(type.getRank())); - const mlir::LLVM::LLVMType members[] = { + const mlir::Type members[] = { i8p, // 0, meminfo i8p, // 1, parent i64, // 2, nitems @@ -117,7 +117,7 @@ mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir:: } template -void flatten_type(mlir::LLVM::LLVMType type, F&& func) +void flatten_type(mlir::Type type, F&& func) { if (auto struct_type = type.dyn_cast()) { @@ -142,7 +142,7 @@ void flatten_type(mlir::LLVM::LLVMType type, F&& func) } template -mlir::Value unflatten(mlir::LLVM::LLVMType type, mlir::Location loc, mlir::OpBuilder& builder, F&& next_func) +mlir::Value unflatten(mlir::Type type, mlir::Location loc, mlir::OpBuilder& builder, F&& next_func) { namespace mllvm = mlir::LLVM; if (auto struct_type = type.dyn_cast()) @@ -209,7 +209,7 @@ struct MemRefConversionCache auto loc = builder.getUnknownLoc(); auto new_func = add_function(builder, module, func_name, func_type); auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); - new_func.setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); cache.insert({memref_type, new_func}); mlir::OpBuilder::InsertionGuard guard(builder); auto block = new_func.addEntryBlock(); @@ -272,9 +272,9 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) { return; } - if (func.getAttr(plier::attributes::getFastmathName())) + if (func->getAttr(plier::attributes::getFastmathName())) { - func.setAttr("passthrough", get_fastmath_attrs(*func.getContext())); + func->setAttr("passthrough", get_fastmath_attrs(*func.getContext())); } auto old_type = func.getType(); assert(old_type.getNumResults() <= 1); @@ -320,7 +320,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) return ret; }); - auto mod = mlir::cast(func.getParentOp()); + auto mod = mlir::cast(func->getParentOp()); auto dst_type = type_helper.get_type_converter().convertType(memref_type); assert(dst_type); auto conv_func = conversion_cache.get_conversion_func(mod, builder, memref_type, arr_type, dst_type.cast()); @@ -373,7 +373,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern }; rewriter.setInsertionPoint(op); - auto addr = op.getParentRegion()->front().getArgument(0); + auto addr = op->getParentRegion()->front().getArgument(0); if (op.getNumOperands() == 0) { assert(addr.getType().isa()); @@ -551,7 +551,7 @@ struct LowerParallel : public mlir::OpRewritePattern auto context_type = [&]()->mlir::LLVM::LLVMStructType { - llvm::SmallVector fields; + llvm::SmallVector fields; fields.reserve(context_vars.size()); for (auto var : context_vars) { @@ -560,7 +560,7 @@ struct LowerParallel : public mlir::OpRewritePattern { return {}; } - fields.emplace_back(type.cast()); + fields.emplace_back(type); } return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), fields); }(); @@ -604,12 +604,12 @@ struct LowerParallel : public mlir::OpRewritePattern return mlir::FunctionType::get(op.getContext(), args, {}); }(); - auto mod = op.getParentOfType(); + auto mod = op->getParentOfType(); auto outlined_func = [&]()->mlir::FuncOp { auto func = [&]() { - auto parent_func = op.getParentOfType(); + auto parent_func = op->getParentOfType(); assert(parent_func); auto func_name = [&]() { @@ -826,8 +826,8 @@ struct LLVMLoweringPass : public mlir::PassWrappersetAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), + StringAttr::get(options.dataLayout.getStringRepresentation(), m.getContext())); } private: diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 11c364b9700..e12a8d78759 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -241,7 +241,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern mlir::ReturnOp op, mlir::PatternRewriter &rewriter) const override { auto operands = op.getOperands(); - auto func = mlir::cast(op.getParentOp()); + auto func = mlir::cast(op->getParentOp()); auto res_types = func.getType().getResults(); assert(res_types.size() == operands.size() || res_types.empty()); bool converted = (res_types.size() != operands.size()); From 3bd3eb6059dbee7dfaad3a78d75ff57ede4d0af6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 7 Jan 2021 19:20:43 +0300 Subject: [PATCH 210/259] [MLIR] fastmath support (#153) --- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 60 ++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 8f045e4d3b6..0d7bc6dcba0 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -262,6 +262,7 @@ mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) add_pair("no-nans-fp-math", "true"), add_pair("no-signed-zeros-fp-math", "true"), add_pair("unsafe-fp-math", "true"), + add_pair(plier::attributes::getFastmathName(), "1"), }; return mlir::ArrayAttr::get(attrs, &ctx); } @@ -403,6 +404,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern mlir::TypeConverter& type_converter; }; +// Remove redundant bitcasts we have created on PreLowering struct RemoveBitcasts : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -419,6 +421,51 @@ struct RemoveBitcasts : public mlir::OpRewritePattern } }; +template +struct ApplyFastmathFlags : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter& rewriter) const + { + auto parent = mlir::cast(op->getParentOp()); + bool changed = false; + + rewriter.startRootUpdate(op); + auto fmf = op.fastmathFlags(); + getFastmathFlags(parent, [&](auto flag) + { + if (!mlir::LLVM::bitEnumContains(fmf, flag)) + { + fmf = fmf | flag; + changed = true; + } + }); + if (changed) + { + op.fastmathFlagsAttr(mlir::LLVM::FMFAttr::get(fmf, op.getContext())); + rewriter.finalizeRootUpdate(op); + } + else + { + rewriter.cancelRootUpdate(op); + } + + return mlir::success(changed); + } + +private: + template + static void getFastmathFlags(mlir::LLVM::LLVMFuncOp func, F&& sink) + { + if (func->hasAttr(plier::attributes::getFastmathName())) + { + sink(mlir::LLVM::FastmathFlags::fast); + } + } +}; + class CheckForPlierTypes : public mlir::PassWrapper> { @@ -764,8 +811,16 @@ struct PostLLVMLowering : { mlir::OwningRewritePatternList patterns; - // Remove redundant bitcasts we have created on PreLowering - patterns.insert(&getContext()); + patterns.insert< + RemoveBitcasts, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags, + ApplyFastmathFlags + >(&getContext()); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -841,6 +896,7 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) // pm.addPass(std::make_unique()); pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); +// pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); pm.addNestedPass(std::make_unique()); } } From a38cfc2ea5d1edb8aa2bf02e136b1a0cbafe55e2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 8 Jan 2021 18:15:28 +0300 Subject: [PATCH 211/259] [MLIR] refactor compiler machinery (#154) --- mlir-compiler/CMakeLists.txt | 3 +- mlir-compiler/src/dialect.cpp | 1 - mlir-compiler/src/lowering.cpp | 65 ++++++++++++--- mlir-compiler/src/lowering.hpp | 10 ++- mlir-compiler/src/module.cpp | 18 ----- mlir-compiler/src/py_module.cpp | 12 +++ mlir-compiler/src/py_module.hpp | 1 + .../src/rewrites/type_conversion.cpp | 28 +++++++ .../src/transforms/pipeline_utils.cpp | 10 +-- numba/core/typed_passes.py | 80 +++++++++++-------- 10 files changed, 153 insertions(+), 75 deletions(-) delete mode 100644 mlir-compiler/src/module.cpp create mode 100644 mlir-compiler/src/py_module.cpp create mode 100644 mlir-compiler/src/py_module.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 13449da466c..b82b39d2042 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -36,8 +36,8 @@ set(SOURCES_LIST src/compiler.cpp src/dialect.cpp src/lowering.cpp - src/module.cpp src/pipeline_registry.cpp + src/py_module.cpp src/utils.cpp ) set(HEADERS_LIST @@ -59,6 +59,7 @@ set(HEADERS_LIST src/compiler.hpp src/lowering.hpp src/pipeline_registry.hpp + src/py_module.hpp src/utils.hpp ) diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index dd7d8612cc4..9e0636529e6 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -111,7 +111,6 @@ mlir::OpFoldResult ArgOp::fold(llvm::ArrayRef /*operands*/) if (ind >= func.getNumArguments() || func.getArgument(ind).getType() != getType()) { - emitError("Invalid function args"); return nullptr; } return func.getArgument(ind); diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 71585414b62..f430cf3a8fa 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -134,9 +134,9 @@ struct plier_lowerer final ctx.loadDialect(); } - mlir::ModuleOp lower(const py::object& compilation_context, const py::object& func_ir) + mlir::FuncOp lower(const py::object& compilation_context, mlir::ModuleOp mod, const py::object& func_ir) { - auto mod = mlir::ModuleOp::create(builder.getUnknownLoc()); + typemap = compilation_context["typemap"]; func_name_resolver = compilation_context["resolve_func"]; auto name = compilation_context["fnname"]().cast(); @@ -144,16 +144,16 @@ struct plier_lowerer final func = mlir::FuncOp::create(builder.getUnknownLoc(), name, typ); if (compilation_context["fastmath"]().cast()) { - func.setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx)); + func->setAttr(plier::attributes::getFastmathName(), mlir::UnitAttr::get(&ctx)); } auto max_concurrency = compilation_context["max_concurrency"]().cast(); if (max_concurrency > 0) { - mod.setAttr(plier::attributes::getMaxConcurrencyName(), builder.getI64IntegerAttr(max_concurrency)); + mod->setAttr(plier::attributes::getMaxConcurrencyName(), builder.getI64IntegerAttr(max_concurrency)); } lower_func_body(func_ir); mod.push_back(func); - return mod; + return func; } private: mlir::MLIRContext& ctx; @@ -653,18 +653,57 @@ void create_pipeline(PipelineRegistry& registry) register_plier_to_linalg_pipeline(registry); register_parallel_to_tbb_pipeline(registry); } -} -py::bytes lower_function(const py::object& compilation_context, const py::object& func_ir) +struct Module { -// mlir::registerDialect(); -// mlir::registerDialect(); mlir::MLIRContext context; - auto mod = plier_lowerer(context).lower(compilation_context, func_ir); PipelineRegistry registry; - create_pipeline(registry); + mlir::ModuleOp module; + + Module() + { + create_pipeline(registry); + } +}; + +mlir::FuncOp run_compiler(Module& mod, const py::object& compilation_context, const py::object& func_ir) +{ + auto& context = mod.context; + auto& module = mod.module; + auto& registry = mod.registry; + auto func = plier_lowerer(context).lower(compilation_context, module, func_ir); + auto settings = get_settings(compilation_context["compiler_settings"]); CompilerContext compiler(context, settings, registry); - compiler.run(mod); - return gen_ll_module(mod); + compiler.run(module); + return func; +} +} + +py::capsule create_module() +{ + auto mod = std::make_unique(); + { + mlir::OpBuilder builder(&mod->context); + mod->module = mlir::ModuleOp::create(builder.getUnknownLoc()); + } + py::capsule capsule(mod.get(), [](void* ptr) + { + delete static_cast(ptr); + }); + mod.release(); + return capsule; +} + +py::capsule lower_function(const py::object& compilation_context, const py::capsule& py_mod, const py::object& func_ir) +{ + auto mod = static_cast(py_mod); + auto func = run_compiler(*mod, compilation_context, func_ir); + return py::capsule(func.getOperation()); // no dtor, func owned by module +} + +py::bytes serialize_module(const py::capsule& py_mod) +{ + auto mod = static_cast(py_mod); + return gen_ll_module(mod->module); } diff --git a/mlir-compiler/src/lowering.hpp b/mlir-compiler/src/lowering.hpp index ce709920840..db21dc1d866 100644 --- a/mlir-compiler/src/lowering.hpp +++ b/mlir-compiler/src/lowering.hpp @@ -3,8 +3,14 @@ namespace pybind11 { class bytes; +class capsule; class object; } -pybind11::bytes lower_function(const pybind11::object& compilation_context, - const pybind11::object& func_ir); +pybind11::capsule create_module(); + +pybind11::capsule lower_function(const pybind11::object& compilation_context, + const pybind11::capsule& py_mod, + const pybind11::object& func_ir); + +pybind11::bytes serialize_module(const pybind11::capsule& py_mod); diff --git a/mlir-compiler/src/module.cpp b/mlir-compiler/src/module.cpp deleted file mode 100644 index fb262dc6e6b..00000000000 --- a/mlir-compiler/src/module.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include - -#include "lowering.hpp" - -namespace py = pybind11; - -namespace -{ -py::bytes lower_normal_function(py::object compilation_context, py::object func_ir) -{ - return lower_function(compilation_context, func_ir); -} -} - -PYBIND11_MODULE(mlir_compiler, m) -{ - m.def("lower_normal_function", &lower_normal_function, "todo"); -} diff --git a/mlir-compiler/src/py_module.cpp b/mlir-compiler/src/py_module.cpp new file mode 100644 index 00000000000..a4e80d15da8 --- /dev/null +++ b/mlir-compiler/src/py_module.cpp @@ -0,0 +1,12 @@ +#include + +#include "py_module.hpp" + +#include "lowering.hpp" + +PYBIND11_MODULE(mlir_compiler, m) +{ + m.def("create_module", &create_module, "todo"); + m.def("lower_function", &lower_function, "todo"); + m.def("serialize_module", &serialize_module, "todo"); +} diff --git a/mlir-compiler/src/py_module.hpp b/mlir-compiler/src/py_module.hpp new file mode 100644 index 00000000000..6f70f09beec --- /dev/null +++ b/mlir-compiler/src/py_module.hpp @@ -0,0 +1 @@ +#pragma once diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 9810cbafc6f..3bf4ef839ee 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -125,12 +125,40 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( return mlir::failure(); } + bool ret_type_changed = false; // Update the function signature in-place. rewriter.updateRootInPlace(funcOp, [&] { + ret_type_changed = (newResults != funcOp.getType().getResults()); funcOp.setType(mlir::FunctionType::get( funcOp.getContext(), result.getConvertedTypes(), newResults)); auto res = convertRegionTypes(&funcOp.getBody(), converter, true); assert(mlir::succeeded(res)); }); + if (ret_type_changed) + { + auto ret_types = funcOp.getType().getResults(); + funcOp.walk([&](mlir::ReturnOp ret) + { + if (ret->getParentOp() == funcOp) + { + mlir::OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(ret); + for (auto it : llvm::enumerate(llvm::zip(ret.getOperandTypes(), ret_types))) + { + auto prev_type = std::get<0>(it.value()); + auto new_type = std::get<1>(it.value()); + if (prev_type != new_type) + { + auto index = static_cast(it.index()); + auto cast = rewriter.create(ret.getLoc(), new_type, ret.getOperand(index)); + rewriter.updateRootInPlace(ret, [&]() + { + ret.setOperand(index, cast); + }); + } + } + } + }); + } return mlir::success(); } diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/src/transforms/pipeline_utils.cpp index 154de789863..91f59faf0b1 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/src/transforms/pipeline_utils.cpp @@ -7,7 +7,7 @@ mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) { - return module.getAttrOfType(plier::attributes::getJumpMarkersName()); + return module->getAttrOfType(plier::attributes::getJumpMarkersName()); } void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) @@ -17,7 +17,7 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) auto jump_markers = plier::attributes::getJumpMarkersName(); llvm::SmallVector name_list; - if (auto old_attr = module.getAttrOfType(jump_markers)) + if (auto old_attr = module->getAttrOfType(jump_markers)) { name_list.assign(old_attr.begin(), old_attr.end()); } @@ -34,7 +34,7 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) { name_list.insert(it, name); } - module.setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); + module->setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); } @@ -45,7 +45,7 @@ void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) auto jump_markers = plier::attributes::getJumpMarkersName(); llvm::SmallVector name_list; - if (auto old_attr = module.getAttrOfType(jump_markers)) + if (auto old_attr = module->getAttrOfType(jump_markers)) { name_list.assign(old_attr.begin(), old_attr.end()); } @@ -56,5 +56,5 @@ void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) }); assert(it != name_list.end()); name_list.erase(it); - module.setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); + module->setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 53ec79236f2..366dfd9567e 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -468,39 +468,35 @@ def run_pass(self, state): return True import numba.mlir.settings +_mlir_last_compiled_func = None +_mlir_active_module = None -@register_pass(mutates_CFG=True, analysis_only=False) -class MlirBackend(LoweringPass): - - _name = "mlir_backend" +class MlirBackendBase(FunctionPass): def __init__(self): - # LoweringPass.__init__(self) import numba.mlir.func_registry self._get_func_name = numba.mlir.func_registry.get_func_name + FunctionPass.__init__(self) - def run_pass(self, state): - targetctx = state.targetctx - library = state.library - interp = state.func_ir # why is it called this?! - typemap = state.typemap - restype = state.return_type - calltypes = state.calltypes - flags = state.flags - metadata = state.metadata + def _resolve_func_name(self, obj): + if isinstance(obj, types.Function): + func = obj.typing_key + return self._get_func_name(func) + if isinstance(obj, types.BoundFunction): + return str(obj.typing_key) + return None - msg = ("Function %s failed at nopython " - "mode lowering" % (state.func_id.func_name,)) - with fallback_context(state, msg): - # Lowering - fndesc = \ - funcdesc.PythonFunctionDescriptor.from_specialized_function( - interp, typemap, restype, calltypes, - mangler=targetctx.mangler, inline=flags.forceinline, - noalias=flags.noalias) - fn_name = fndesc.mangled_name + def _get_func_context(self, state): + mangler = state.targetctx.mangler + mangler = default_mangler if mangler is None else mangler + unique_name = state.func_ir.func_id.unique_name + modname = state.func_ir.func_id.func.__module__ + from numba.core.funcdesc import qualifying_prefix + qualprefix = qualifying_prefix(modname, unique_name) + fn_name = mangler(qualprefix, state.args) from numba.np.ufunc.parallel import get_thread_count + ctx = {} ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} ctx['typemap'] = lambda op: state.typemap[op.name] @@ -510,22 +506,36 @@ def run_pass(self, state): ctx['resolve_func'] = self._resolve_func_name ctx['fastmath'] = lambda: state.targetctx.fastmath ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 + return ctx + + + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackend(MlirBackendBase): + + _name = "mlir_backend" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass(self, state): import mlir_compiler - mod = mlir_compiler.lower_normal_function(ctx, state.func_ir) - setattr(state, 'mlir_blob', mod) + global _mlir_active_module; + old_module = _mlir_active_module + try: + module = mlir_compiler.create_module() + _mlir_active_module = module + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + mod_ir = mlir_compiler.serialize_module(module) + finally: + _mlir_active_module = old_module + setattr(state, 'mlir_blob', mod_ir) _reload_parfors() state.reload_init.append(_reload_parfors) return True - def _resolve_func_name(self, obj): - if isinstance(obj, types.Function): - func = obj.typing_key - return self._get_func_name(func) - if isinstance(obj, types.BoundFunction): - return str(obj.typing_key) - return None - - @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): """ From 60e78202c7e7d28f4b83e0cd1c3882a56cc75ffa Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 12 Jan 2021 18:31:22 +0300 Subject: [PATCH 212/259] fixes for linux (#155) --- mlir-compiler/CMakeLists.txt | 12 +++-- mlir-compiler/src/lowering.cpp | 52 +++++++++---------- .../src/pipelines/parallel_to_tbb.cpp | 4 +- mlir-compiler/src/utils.cpp | 3 +- 4 files changed, 37 insertions(+), 34 deletions(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index b82b39d2042..344e97689ed 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -4,6 +4,12 @@ project(mlir_compiler LANGUAGES CXX) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +if(UNIX) + add_link_options("-Wl,--exclude-libs,ALL") +endif() + find_package(pybind11 REQUIRED) @@ -75,15 +81,13 @@ target_link_libraries(${PROJECT_NAME} PRIVATE LLVM${LLVM_NATIVE_ARCH}CodeGen LLVM${LLVM_NATIVE_ARCH}Desc LLVMTarget - MLIRSupport + MLIRIR MLIRLLVMIR - MLIRStandard MLIRTargetLLVMIR MLIRTransforms - MLIRStandardToLLVM + MLIRStandardOpsTransforms MLIRLinalgTransforms MLIRSCFToStandard - MLIRTensor MLIRTensorTransforms ) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index f430cf3a8fa..02a9fc8ca3b 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -58,6 +58,29 @@ py::list get_body(const py::handle& block) return block.attr("body").cast(); } +struct OpId +{ + llvm::StringRef op; + llvm::StringRef name; +}; + +static const constexpr OpId inst_ops_names[] = { + {"+", "add"}, // binary + {"+", "pos"}, // unary + {"-", "sub"}, // binary + {"-", "neg"}, // unary + {"*", "mul"}, + {"/", "truediv"}, + {"//", "floordiv"}, + + {">", "gt"}, + {">=", "ge"}, + {"<", "lt"}, + {"<=", "le"}, + {"!=", "ne"}, + {"==", "eq"}, +}; + struct inst_handles { inst_handles() @@ -78,7 +101,7 @@ struct inst_handles auto ops = py::module::import("operator"); - for (auto elem : llvm::zip(ops_names, ops_handles)) + for (auto elem : llvm::zip(inst_ops_names, ops_handles)) { auto name = std::get<0>(elem).name; std::get<1>(elem) = ops.attr(name.data()); @@ -98,30 +121,7 @@ struct inst_handles py::handle Const; py::handle Global; - struct OpId - { - llvm::StringRef op; - llvm::StringRef name; - }; - - static const constexpr OpId ops_names[] = { - {"+", "add"}, // binary - {"+", "pos"}, // unary - {"-", "sub"}, // binary - {"-", "neg"}, // unary - {"*", "mul"}, - {"/", "truediv"}, - {"//", "floordiv"}, - - {">", "gt"}, - {">=", "ge"}, - {"<", "lt"}, - {"<=", "le"}, - {"!=", "ne"}, - {"==", "eq"}, - }; - - std::array ops_handles; + std::array ops_handles; }; struct plier_lowerer final @@ -436,7 +436,7 @@ struct plier_lowerer final llvm::StringRef resolve_op(const py::handle& op) { - for (auto elem : llvm::zip(insts.ops_names, insts.ops_handles)) + for (auto elem : llvm::zip(inst_ops_names, insts.ops_handles)) { if (op.is(std::get<1>(elem))) { diff --git a/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index d7a8eaf31b6..8634ec09779 100644 --- a/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -54,8 +54,8 @@ struct ParallelToTbb : public mlir::OpRewritePattern } int64_t max_concurrency = 0; - auto mod = op.getParentOfType(); - if (auto mc = mod.getAttrOfType(plier::attributes::getMaxConcurrencyName())) + auto mod = op->getParentOfType(); + if (auto mc = mod->getAttrOfType(plier::attributes::getMaxConcurrencyName())) { max_concurrency = mc.getInt(); } diff --git a/mlir-compiler/src/utils.cpp b/mlir-compiler/src/utils.cpp index 6b760cc1125..8d3b7f19d0b 100644 --- a/mlir-compiler/src/utils.cpp +++ b/mlir-compiler/src/utils.cpp @@ -6,6 +6,5 @@ void report_error(const llvm::Twine& msg) { - auto str = msg.str(); - throw std::exception(str.c_str()); + throw std::runtime_error(msg.str()); } From 6067d0c148436dfd0446cac9498957ec545b7f84 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 14 Jan 2021 16:42:14 +0300 Subject: [PATCH 213/259] option to dump plier before any passes (#156) --- mlir-compiler/src/lowering.cpp | 10 ++++++++++ mlir-compiler/src/lowering.hpp | 3 +++ mlir-compiler/src/py_module.cpp | 1 + numba/core/compiler.py | 12 +++++++----- numba/core/typed_passes.py | 14 ++++++++++++++ numba/mlir/settings.py | 1 + 6 files changed, 36 insertions(+), 5 deletions(-) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 02a9fc8ca3b..6939bf4e9d8 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -707,3 +707,13 @@ py::bytes serialize_module(const py::capsule& py_mod) auto mod = static_cast(py_mod); return gen_ll_module(mod->module); } + +py::str module_str(const py::capsule& py_mod) +{ + auto mod = static_cast(py_mod); + std::string ret; + llvm::raw_string_ostream ss(ret); + mod->module.print(ss); + ss.flush(); + return py::str(ss.str()); +} diff --git a/mlir-compiler/src/lowering.hpp b/mlir-compiler/src/lowering.hpp index db21dc1d866..e6e454c48e7 100644 --- a/mlir-compiler/src/lowering.hpp +++ b/mlir-compiler/src/lowering.hpp @@ -5,6 +5,7 @@ namespace pybind11 class bytes; class capsule; class object; +class str; } pybind11::capsule create_module(); @@ -14,3 +15,5 @@ pybind11::capsule lower_function(const pybind11::object& compilation_context, const pybind11::object& func_ir); pybind11::bytes serialize_module(const pybind11::capsule& py_mod); + +pybind11::str module_str(const pybind11::capsule& py_mod); diff --git a/mlir-compiler/src/py_module.cpp b/mlir-compiler/src/py_module.cpp index a4e80d15da8..3609c75eaa0 100644 --- a/mlir-compiler/src/py_module.cpp +++ b/mlir-compiler/src/py_module.cpp @@ -9,4 +9,5 @@ PYBIND11_MODULE(mlir_compiler, m) m.def("create_module", &create_module, "todo"); m.def("lower_function", &lower_function, "todo"); m.def("serialize_module", &serialize_module, "todo"); + m.def("module_str", &module_str, "todo"); } diff --git a/numba/core/compiler.py b/numba/core/compiler.py index c7f44dd7403..df3a4fa70c1 100644 --- a/numba/core/compiler.py +++ b/numba/core/compiler.py @@ -29,14 +29,11 @@ ParforPass, DumpParforDiagnostics, IRLegalization, NoPythonBackend, InlineOverloads, PreLowerStripPhis, - MlirBackend) + MlirDumpPlier, MlirBackend) from numba.core.object_mode_passes import (ObjectModeFrontEnd, ObjectModeBackEnd, CompileInterpMode) - -from numba.core.lowering import _use_mlir - class Flags(utils.ConfigOptions): # These options are all false by default, but the defaults are # different with the @jit decorator (see targets.options.TargetOptions). @@ -506,7 +503,12 @@ def define_typed_pipeline(state, name="typed"): pm.add_pass(NopythonTypeInference, "nopython frontend") pm.add_pass(AnnotateTypes, "annotate types") - if _use_mlir: + import numba.mlir.settings + + if numba.mlir.settings.DUMP_PLIER: + pm.add_pass(MlirDumpPlier, "mlir dump plier") + + if numba.mlir.settings.USE_MLIR: pm.add_pass(MlirBackend, "mlir backend") # strip phis diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 366dfd9567e..c102f8b4507 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -508,7 +508,21 @@ def _get_func_context(self, state): ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 return ctx +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirDumpPlier(MlirBackendBase): + + _name = "mlir_dump_plier" + + def __init__(self): + MlirBackendBase.__init__(self) + def run_pass(self, state): + import mlir_compiler + module = mlir_compiler.create_module() + ctx = self._get_func_context(state) + mlir_compiler.lower_function(ctx, module, state.func_ir) + print(mlir_compiler.module_str(module)) + return True @register_pass(mutates_CFG=True, analysis_only=False) class MlirBackend(MlirBackendBase): diff --git a/numba/mlir/settings.py b/numba/mlir/settings.py index ec027eaaaac..8f09d635b3c 100644 --- a/numba/mlir/settings.py +++ b/numba/mlir/settings.py @@ -12,4 +12,5 @@ def _readenv(name, ctor, default): return default USE_MLIR = _readenv('NUMBA_MLIR_ENABLE', int, 1) +DUMP_PLIER = _readenv('NUMBA_MLIR_DUMP_PLIER', int, 0) PRINT_IR = _readenv('NUMBA_MLIR_PRINT_IR', int, 0) From ae8f354b68f05b9863d9c1f459d0622205604d82 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 19 Jan 2021 15:58:20 +0300 Subject: [PATCH 214/259] simple CSE rewrite (#157) --- mlir-compiler/CMakeLists.txt | 2 + .../src/pipelines/plier_to_linalg.cpp | 4 +- mlir-compiler/src/rewrites/cse.cpp | 88 +++++++++++++++++++ mlir-compiler/src/rewrites/cse.hpp | 25 ++++++ numba/mlir/tests/test_numpy.py | 14 +++ 5 files changed, 132 insertions(+), 1 deletion(-) create mode 100644 mlir-compiler/src/rewrites/cse.cpp create mode 100644 mlir-compiler/src/rewrites/cse.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 344e97689ed..edbabbc8be4 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -34,6 +34,7 @@ set(SOURCES_LIST src/rewrites/call_lowering.cpp src/rewrites/canonicalize_reductions.cpp src/rewrites/cast_lowering.cpp + src/rewrites/cse.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp src/transforms/func_utils.cpp @@ -57,6 +58,7 @@ set(HEADERS_LIST src/rewrites/call_lowering.hpp src/rewrites/canonicalize_reductions.hpp src/rewrites/cast_lowering.hpp + src/rewrites/cse.hpp src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp src/transforms/func_utils.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index f630cc920ec..ab861b40265 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -25,6 +25,7 @@ #include "rewrites/call_lowering.hpp" #include "rewrites/canonicalize_reductions.hpp" #include "rewrites/cast_lowering.hpp" +#include "rewrites/cse.hpp" #include "rewrites/promote_to_parallel.hpp" #include "rewrites/type_conversion.hpp" #include "transforms/loop_utils.hpp" @@ -637,7 +638,8 @@ void PostLinalgOptPass::runOnOperation() patterns.insert< CanonicalizeReduction, // LoopInvariantCodeMotion, TODO - PromoteToParallel + PromoteToParallel, + CSERewrite >(&getContext()); diff --git a/mlir-compiler/src/rewrites/cse.cpp b/mlir-compiler/src/rewrites/cse.cpp new file mode 100644 index 00000000000..ee4dc3eb203 --- /dev/null +++ b/mlir-compiler/src/rewrites/cse.cpp @@ -0,0 +1,88 @@ +#include "rewrites/cse.hpp" + +#include +#include +#include + +#include +#include + +namespace +{ +struct SimpleOperationInfo : public llvm::DenseMapInfo { + static unsigned getHashValue(const mlir::Operation *opC) { + return static_cast(mlir::OperationEquivalence::computeHash(const_cast(opC))); + } + static bool isEqual(const mlir::Operation *lhsC, const mlir::Operation *rhsC) { + auto *lhs = const_cast(lhsC); + auto *rhs = const_cast(rhsC); + if (lhs == rhs) + return true; + if (lhs == getTombstoneKey() || lhs == getEmptyKey() || + rhs == getTombstoneKey() || rhs == getEmptyKey()) + return false; + return mlir::OperationEquivalence::isEquivalentTo(const_cast(lhsC), + const_cast(rhsC)); + } +}; + +using AllocatorTy = llvm::RecyclingAllocator< + llvm::BumpPtrAllocator, + llvm::ScopedHashTableVal>; +using ScopedMapTy = llvm::ScopedHashTable; + +mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir::PatternRewriter& rewriter) +{ + if (region.empty() || std::next(region.begin()) != region.end()) + { + return mlir::failure(); + } + + bool success = false; + for (auto &inst : llvm::make_early_inc_range(region.front())) + { + if (inst.isKnownTerminator()) + { + break; + } + if (!mlir::MemoryEffectOpInterface::hasNoEffect(&inst)) + { + continue; + } + if (!inst.getRegions().empty()) + { + for (auto& reg : inst.getRegions()) + { + ScopedMapTy::ScopeTy scope(map); + if (mlir::succeeded(simplifyRegion(map, reg, rewriter))) + { + success = true; + } + } + continue; + } + + auto* previous_op = map.lookup(&inst); + if (previous_op != nullptr) + { + rewriter.replaceOp(&inst, previous_op->getResults()); + success = true; + } + else + { + map.insert(&inst, &inst); + } + } + return mlir::success(success); +} +} + + + +mlir::LogicalResult CSE::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter) +{ + ScopedMapTy map; + ScopedMapTy::ScopeTy scope(map); + return simplifyRegion(map, region, rewriter); +} diff --git a/mlir-compiler/src/rewrites/cse.hpp b/mlir-compiler/src/rewrites/cse.hpp new file mode 100644 index 00000000000..1536e8c0c9f --- /dev/null +++ b/mlir-compiler/src/rewrites/cse.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace CSE +{ +namespace detail +{ +mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter); +} +} + +template +struct CSERewrite : public mlir::OpRewritePattern +{ + CSERewrite(mlir::MLIRContext *context): + OpRewritePattern(context, /*benefit*/0) {} + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter &rewriter) const override + { + return ::CSE::detail::applyCSE(op.getRegion(), rewriter); + } +}; diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index a90d0ac7c69..d6fcef955a5 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -79,5 +79,19 @@ def py_func(a): arr = np.asarray([3,2,1]) assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + def test_array_bounds(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if i >= len(a): + res = res + 1 + else: + res = res + a[i] + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + if __name__ == '__main__': unittest.main() From 720e1fa5b2669262dc64a7499fdeb5f23d92b7ce Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 19 Jan 2021 16:59:04 +0300 Subject: [PATCH 215/259] [MLIR] Getattr+global rewrite and CallRewrite simplification (#158) * rewrite `global + getattr` to `global` * refactor call lowering --- mlir-compiler/include/plier/PlierOps.td | 2 + mlir-compiler/src/dialect.cpp | 29 ++++++++++++++ .../src/pipelines/plier_to_linalg.cpp | 2 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 6 +-- mlir-compiler/src/rewrites/call_lowering.cpp | 38 ++----------------- mlir-compiler/src/rewrites/call_lowering.hpp | 2 +- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index bf43e332c40..a4ef3339d04 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -205,6 +205,8 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { let results = (outs AnyType); + let hasCanonicalizer = 1; + let builders = [ OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::StringRef":$name)> ]; diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 9e0636529e6..e7b66373647 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include @@ -272,6 +273,34 @@ void GetattrOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, value, name); } +namespace +{ +struct GetattrGlobalRewrite : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + GetattrOp op, mlir::PatternRewriter &rewriter) const override + { + auto prev_op = mlir::dyn_cast_or_null(op.getOperand().getDefiningOp()); + if (prev_op) + { + auto new_name = llvm::Twine(prev_op.name() + "." + op.name()).str(); + auto new_op = rewriter.create(op.getLoc(), op.getType(), new_name); + rewriter.replaceOp(op, new_op.getResult()); + return mlir::success(); + } + return mlir::failure(); + } +}; +} + +void GetattrOp::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) +{ + results.insert(context); +} + mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef ops) { for (mlir::Operation *op : ops) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index ab861b40265..c92e8d99247 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -555,7 +555,7 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< CallOpLowering - >(type_converter, &getContext(), &call_rewrite); + >(type_converter, &getContext(), call_rewrite); patterns.insert< GetitemOpLowering, diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index e12a8d78759..a55f1a1175f 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1204,13 +1204,13 @@ void PlierToStdPass::runOnOperation() FixupWhileTypes >(type_converter, context); - patterns.insert< + patterns.insert< CastOpLowering >(type_converter, context, &do_cast); - patterns.insert< + patterns.insert< CallOpLowering - >(type_converter, context, &basic_rewrite); + >(type_converter, context, basic_rewrite); mlir::populateStdExpandOpsPatterns(context, patterns); diff --git a/mlir-compiler/src/rewrites/call_lowering.cpp b/mlir-compiler/src/rewrites/call_lowering.cpp index d107845d5dc..36295f03ee9 100644 --- a/mlir-compiler/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/src/rewrites/call_lowering.cpp @@ -1,26 +1,5 @@ #include "call_lowering.hpp" -namespace -{ -llvm::StringRef extract_bound_func_name(llvm::StringRef name) -{ - assert(!name.empty()); - auto len = name.find(' '); - return name.substr(0, len); -} - -bool check_class_name(llvm::StringRef& str, llvm::StringRef prefix) -{ - llvm::StringRef temp = str; - if (temp.consume_front(prefix) && temp.consume_front("(") && temp.consume_back(")")) - { - str = temp; - return true; - } - return false; -} -} - CallOpLowering::CallOpLowering( mlir::TypeConverter&, mlir::MLIRContext* context, CallOpLowering::resolver_t resolver): @@ -38,33 +17,24 @@ mlir::LogicalResult CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::Pa { return mlir::failure(); } - auto name = func_type.cast().getName(); + llvm::SmallVector arg_types; llvm::SmallVector args; - if (check_class_name(name, "Function")) + auto getattr = mlir::dyn_cast_or_null(operands[0].getDefiningOp()); + if (!getattr) { llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); // TODO kwargs } - else if (check_class_name(name, "BoundFunction")) + else { - auto getattr = mlir::dyn_cast(operands[0].getDefiningOp()); - if (!getattr) - { - return mlir::failure(); - } arg_types.push_back(getattr.getOperand().getType()); args.push_back(getattr.getOperand()); llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - name = extract_bound_func_name(name); // TODO kwargs } - else - { - return mlir::failure(); - } return resolver(op, op.func_name(), args, rewriter); } diff --git a/mlir-compiler/src/rewrites/call_lowering.hpp b/mlir-compiler/src/rewrites/call_lowering.hpp index 93046a5795a..06200938f16 100644 --- a/mlir-compiler/src/rewrites/call_lowering.hpp +++ b/mlir-compiler/src/rewrites/call_lowering.hpp @@ -13,7 +13,7 @@ class TypeConverter; struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = std::function, mlir::PatternRewriter&)>; + using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, From c30bf744155992a327f7728e91e7385dae37ab66 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 20 Jan 2021 02:28:34 +0300 Subject: [PATCH 216/259] [MLIR] call njit funcs from other njit funcs (#159) --- mlir-compiler/CMakeLists.txt | 4 + mlir-compiler/src/dialect.cpp | 15 +- mlir-compiler/src/lowering.cpp | 16 +- mlir-compiler/src/lowering.hpp | 3 +- mlir-compiler/src/mangle.cpp | 244 ++++++++++++++++++ mlir-compiler/src/mangle.hpp | 18 ++ mlir-compiler/src/pipelines/plier_to_std.cpp | 112 ++++++-- mlir-compiler/src/py_func_resolver.cpp | 144 +++++++++++ mlir-compiler/src/py_func_resolver.hpp | 29 +++ mlir-compiler/src/py_module.cpp | 2 +- .../src/rewrites/type_conversion.cpp | 21 ++ numba/core/typed_passes.py | 57 +++- numba/mlir/func_registry.py | 30 ++- numba/mlir/inner_compiler.py | 34 +++ numba/mlir/tests/test_basic.py | 26 ++ 15 files changed, 710 insertions(+), 45 deletions(-) create mode 100644 mlir-compiler/src/mangle.cpp create mode 100644 mlir-compiler/src/mangle.hpp create mode 100644 mlir-compiler/src/py_func_resolver.cpp create mode 100644 mlir-compiler/src/py_func_resolver.hpp create mode 100644 numba/mlir/inner_compiler.py diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index edbabbc8be4..5ceaf24b03f 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -43,7 +43,9 @@ set(SOURCES_LIST src/compiler.cpp src/dialect.cpp src/lowering.cpp + src/mangle.cpp src/pipeline_registry.cpp + src/py_func_resolver.cpp src/py_module.cpp src/utils.cpp ) @@ -66,7 +68,9 @@ set(HEADERS_LIST src/transforms/pipeline_utils.hpp src/compiler.hpp src/lowering.hpp + src/mangle.hpp src/pipeline_registry.hpp + src/py_func_resolver.hpp src/py_module.hpp src/utils.hpp ) diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index e7b66373647..cb620782964 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -107,14 +107,17 @@ void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::OpFoldResult ArgOp::fold(llvm::ArrayRef /*operands*/) { - auto func = getParentOfType(); - auto ind = index(); - if (ind >= func.getNumArguments() || - func.getArgument(ind).getType() != getType()) + auto func = getOperation()->getParentOfType(); + if (func) { - return nullptr; + auto ind = index(); + if (ind < func.getNumArguments() && + func.getArgument(ind).getType() == getType()) + { + return func.getArgument(ind); + } } - return func.getArgument(ind); + return nullptr; } void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 6939bf4e9d8..5299f03fcf5 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -98,6 +98,7 @@ struct inst_handles Var = mod.attr("Var"); Const = mod.attr("Const"); Global = mod.attr("Global"); + FreeVar = mod.attr("FreeVar"); auto ops = py::module::import("operator"); @@ -120,6 +121,7 @@ struct inst_handles py::handle Var; py::handle Const; py::handle Global; + py::handle FreeVar; std::array ops_handles; }; @@ -276,7 +278,8 @@ struct plier_lowerer final { return get_const(value.attr("value")); } - if (py::isinstance(value, insts.Global)) + if (py::isinstance(value, insts.Global) || + py::isinstance(value, insts.FreeVar)) { auto name = value.attr("name").cast(); return builder.create(get_current_loc(), @@ -666,17 +669,15 @@ struct Module } }; -mlir::FuncOp run_compiler(Module& mod, const py::object& compilation_context, const py::object& func_ir) +void run_compiler(Module& mod, const py::object& compilation_context) { auto& context = mod.context; auto& module = mod.module; auto& registry = mod.registry; - auto func = plier_lowerer(context).lower(compilation_context, module, func_ir); auto settings = get_settings(compilation_context["compiler_settings"]); CompilerContext compiler(context, settings, registry); compiler.run(module); - return func; } } @@ -698,13 +699,16 @@ py::capsule create_module() py::capsule lower_function(const py::object& compilation_context, const py::capsule& py_mod, const py::object& func_ir) { auto mod = static_cast(py_mod); - auto func = run_compiler(*mod, compilation_context, func_ir); + auto& context = mod->context; + auto& module = mod->module; + auto func = plier_lowerer(context).lower(compilation_context, module, func_ir); return py::capsule(func.getOperation()); // no dtor, func owned by module } -py::bytes serialize_module(const py::capsule& py_mod) +py::bytes compile_module(const py::object& compilation_context, const py::capsule& py_mod) { auto mod = static_cast(py_mod); + run_compiler(*mod, compilation_context); return gen_ll_module(mod->module); } diff --git a/mlir-compiler/src/lowering.hpp b/mlir-compiler/src/lowering.hpp index e6e454c48e7..0580e7e58f3 100644 --- a/mlir-compiler/src/lowering.hpp +++ b/mlir-compiler/src/lowering.hpp @@ -14,6 +14,7 @@ pybind11::capsule lower_function(const pybind11::object& compilation_context, const pybind11::capsule& py_mod, const pybind11::object& func_ir); -pybind11::bytes serialize_module(const pybind11::capsule& py_mod); +pybind11::bytes compile_module(const pybind11::object& compilation_context, + const pybind11::capsule& py_mod); pybind11::str module_str(const pybind11::capsule& py_mod); diff --git a/mlir-compiler/src/mangle.cpp b/mlir-compiler/src/mangle.cpp new file mode 100644 index 00000000000..826194c7d57 --- /dev/null +++ b/mlir-compiler/src/mangle.cpp @@ -0,0 +1,244 @@ +#include "mangle.hpp" + +#include +#include + +#include +#include +#include + +#include + +namespace +{ +static const constexpr auto PREFIX = "_Z"; + +template +bool mangle_int(llvm::raw_ostream& res, mlir::Type type) +{ + if (auto i = type.dyn_cast()) + { + if (i.getWidth() == Width && i.getSignedness() == Sign) + { + res << Symbol; + return true; + } + } + return false; +} + +template +bool mangle_float(llvm::raw_ostream& res, mlir::Type type) +{ + if (auto i = type.dyn_cast()) + { + if (i.getWidth() == Width) + { + res << Symbol; + return true; + } + } + return false; +} + +void mangle_memref_impl(llvm::raw_ostream& res, mlir::MemRefType type); + +bool mangle_memref(llvm::raw_ostream& res, mlir::Type type) +{ + if (auto m = type.dyn_cast()) + { + mangle_memref_impl(res, m); + return true; + } + return false; +} + +using type_mangler_t = bool(*)(llvm::raw_ostream&, mlir::Type); + +static const constexpr type_mangler_t type_manglers[] = { + &mangle_int<1, mlir::IntegerType::Signed, 'b'>, + &mangle_int<1, mlir::IntegerType::Unsigned, 'b'>, + &mangle_int<1, mlir::IntegerType::Signless, 'b'>, + + &mangle_int<8, mlir::IntegerType::Signed, 'a'>, + &mangle_int<8, mlir::IntegerType::Unsigned, 'h'>, + &mangle_int<8, mlir::IntegerType::Signless, 'c'>, + + &mangle_int<16, mlir::IntegerType::Signed, 's'>, + &mangle_int<16, mlir::IntegerType::Unsigned, 't'>, + &mangle_int<16, mlir::IntegerType::Signless, 's'>, + + &mangle_int<32, mlir::IntegerType::Signed, 'i'>, + &mangle_int<32, mlir::IntegerType::Unsigned, 'j'>, + &mangle_int<32, mlir::IntegerType::Signless, 'i'>, + + &mangle_int<64, mlir::IntegerType::Signed, 'x'>, + &mangle_int<64, mlir::IntegerType::Unsigned, 'm'>, + &mangle_int<64, mlir::IntegerType::Signless, 'x'>, + + &mangle_int<128, mlir::IntegerType::Signed, 'n'>, + &mangle_int<128, mlir::IntegerType::Unsigned, 'o'>, + &mangle_int<128, mlir::IntegerType::Signless, 'n'>, + + &mangle_float<32, 'f'>, + &mangle_float<64, 'd'>, + &mangle_float<80, 'e'>, + &mangle_float<128, 'g'>, + + &mangle_memref, +}; + +bool check_type(mlir::Type type) +{ + llvm::raw_null_ostream ss; + for (auto mangler : type_manglers) + { + if (mangler(ss, type)) + { + return true; + } + } + return false; +} + +bool is_valid_char(char c) +{ + return (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') || + (c == '_'); +} + +std::string escape_string(llvm::StringRef str) +{ + std::string ret; + llvm::raw_string_ostream ss(ret); + for (auto c : str) + { + if (is_valid_char(c)) + { + ss << c; + } + else + { + ss << "$" << llvm::format_hex_no_prefix(static_cast(c), 2); + } + } + ss.flush(); + return ret; +} + +template +void mangle_ident_impl(llvm::raw_ostream& res, llvm::StringRef ident, F&& template_params) +{ + assert(!ident.empty()); + llvm::SmallVector parts; + ident.split(parts, '.'); + assert(!parts.empty()); + auto write_part = [&](auto part) + { + auto escaped = escape_string(part); + if (std::isdigit(escaped.front())) + { + res << escaped.size() + 1 << '_' << escaped; + } + else + { + res << escaped.size() << escaped; + } + }; + if (parts.size() == 1) + { + write_part(parts.front()); + template_params(res); + } + else + { + res << 'N'; + for (auto& part : parts) + { + write_part(part); + } + template_params(res); + res << 'E'; + } +} + +void mangle_ident(llvm::raw_ostream& res, llvm::StringRef ident) +{ + auto dummy = [](auto&) {}; + mangle_ident_impl(res, ident, dummy); +} + +template +void mangle_ident(llvm::raw_ostream& res, llvm::StringRef ident, F&& template_params) +{ + auto wrap_template = [&](llvm::raw_ostream& s) + { + s << 'I'; + template_params(s); + s << 'E'; + }; + mangle_ident_impl(res, ident, wrap_template); +} + +void mangle_type(llvm::raw_ostream& res, mlir::Type type) +{ + for(auto m : type_manglers) + { + if (m(res, type)) + { + return; + } + } + llvm_unreachable("Cannot mangle type"); +} + +void mangle_memref_impl(llvm::raw_ostream& res, mlir::MemRefType type) +{ + auto params = [&](llvm::raw_ostream& s) + { + mangle_type(s, type.getElementType()); + s << "Li"<< type.getRank() << "E"; + mangle_ident(s, "C"); + }; + mangle_ident(res, "array", params); +} + +void mangle_types(llvm::raw_ostream& res, mlir::TypeRange types) +{ + for (auto type : types) + { + mangle_type(res, type); + } +} + +} + +bool mangle(llvm::raw_ostream& res, llvm::StringRef ident, mlir::TypeRange types) +{ + for (auto type : types) + { + if (!check_type(type)) + { + return false; + } + } + res << PREFIX; + mangle_ident(res, ident); + mangle_types(res, types); + return true; +} + + +std::string mangle(llvm::StringRef ident, mlir::TypeRange types) +{ + std::string ret; + llvm::raw_string_ostream ss(ret); + if (!mangle(ss, ident, types)) + { + return {}; + } + ss.flush(); + return ret; +} diff --git a/mlir-compiler/src/mangle.hpp b/mlir-compiler/src/mangle.hpp new file mode 100644 index 00000000000..6d050392757 --- /dev/null +++ b/mlir-compiler/src/mangle.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace llvm +{ +class StringRef; +class raw_ostream; +} + +namespace mlir +{ +class TypeRange; +} + +bool mangle(llvm::raw_ostream& res, llvm::StringRef ident, mlir::TypeRange types); + +std::string mangle(llvm::StringRef ident, mlir::TypeRange types); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index a55f1a1175f..7d2cbb5b878 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -23,6 +23,8 @@ #include "base_pipeline.hpp" #include "pipeline_registry.hpp" +#include "py_func_resolver.hpp" +#include "mangle.hpp" namespace { @@ -231,6 +233,41 @@ struct ConstOpLowering : public mlir::OpRewritePattern } }; +struct ArgOpLowering : public mlir::OpRewritePattern +{ + ArgOpLowering(mlir::TypeConverter &typeConverter, + mlir::MLIRContext *context): + OpRewritePattern(context), converter(typeConverter) {} + + mlir::LogicalResult matchAndRewrite( + plier::ArgOp op, mlir::PatternRewriter &rewriter) const override + { + auto func = op->getParentOfType(); + if (!func) + { + return mlir::failure(); + } + + auto index= op.index(); + if (index >= func.getNumArguments()) + { + return mlir::failure(); + } + + auto arg = func.getArgument(index); + if(converter.convertType(op.getType()) != arg.getType()) + { + return mlir::failure(); + } + rewriter.replaceOp(op, arg); + return mlir::success(); + } +private: + mlir::TypeConverter& converter; +}; + + + struct ReturnOpLowering : public mlir::OpRewritePattern { ReturnOpLowering(mlir::TypeConverter &/*typeConverter*/, @@ -1102,7 +1139,7 @@ mlir::FuncOp get_lib_symbol( mlir::PatternRewriter& rewriter) { assert(!name.empty()); - if (auto op = mlir::dyn_cast_or_null(mod.lookupSymbol(name))) + if (auto op = mod.lookupSymbol(name)) { assert(op.getType() == type); return op; @@ -1125,7 +1162,7 @@ mlir::LogicalResult lower_math_func( { auto is_float = ret_type.isa(); auto func_type = mlir::FunctionType::get(op.getContext(), args[0].getType(), ret_type); - auto module = op.getParentOfType(); + auto module = op->getParentOfType(); mlir::FuncOp func; if (is_float) { @@ -1143,28 +1180,60 @@ mlir::LogicalResult lower_math_func( return mlir::failure(); } -mlir::LogicalResult basic_rewrite( - plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) +struct CallLowerer { - if (mlir::succeeded(lower_math_func(op, name, args, rewriter))) + mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name, + llvm::ArrayRef args, mlir::PatternRewriter& rewriter) { - return mlir::success(); - } - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); - std::pair handlers[] = { - {"bool", lower_bool_cast}, - {"range", lower_range}, - }; - for (auto& handler : handlers) - { - if (handler.first == name) + if (mlir::succeeded(lower_math_func(op, name, args, rewriter))) + { + return mlir::success(); + } + + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + std::pair handlers[] = { + {"bool", lower_bool_cast}, + {"range", lower_range}, + }; + for (auto& handler : handlers) { - return handler.second(op, args, rewriter); + if (handler.first == name) + { + return handler.second(op, args, rewriter); + } } + + mlir::ValueRange r(args); + auto mangled_name = mangle(name, r.getTypes()); + if (!mangled_name.empty()) + { + auto mod = op->getParentOfType(); + assert(mod); + auto func = mod.lookupSymbol(mangled_name); + if (!func) + { + func = py_resolver.get_func(name, r.getTypes()); + if (func) + { + func.setPrivate(); + func.setName(mangled_name); + } + } + if (func) + { + assert(func.getType().getNumResults() == op->getNumResults()); + auto new_func_call = rewriter.create(op.getLoc(), func, args); + rewriter.replaceOp(op, new_func_call.getResults()); + return mlir::success(); + } + } + + return mlir::failure(); } - return mlir::failure(); -} + +private: + PyFuncResolver py_resolver; +}; struct PlierToStdPass : public mlir::PassWrapper> @@ -1193,6 +1262,7 @@ void PlierToStdPass::runOnOperation() patterns.insert< FuncOpSignatureConversion, + ArgOpLowering, ReturnOpLowering, ConstOpLowering, SelectOpLowering, @@ -1208,9 +1278,11 @@ void PlierToStdPass::runOnOperation() CastOpLowering >(type_converter, context, &do_cast); + CallLowerer callLowerer; + patterns.insert< CallOpLowering - >(type_converter, context, basic_rewrite); + >(type_converter, context, callLowerer); mlir::populateStdExpandOpsPatterns(context, patterns); diff --git a/mlir-compiler/src/py_func_resolver.cpp b/mlir-compiler/src/py_func_resolver.cpp new file mode 100644 index 00000000000..851498b5e07 --- /dev/null +++ b/mlir-compiler/src/py_func_resolver.cpp @@ -0,0 +1,144 @@ +#include "py_func_resolver.hpp" + +#include + +#include + +namespace py = pybind11; + +namespace +{ + +template +bool is_int(mlir::Type type) +{ + if (auto t = type.dyn_cast()) + { + if (t.getWidth() == Width && t.getSignedness() == Signed) + { + return true; + } + } + return false; +} + +template +bool is_float(mlir::Type type) +{ + if (auto f = type.dyn_cast()) + { + if (f.getWidth() == Width) + { + return true; + } + } + return false; +} + +py::handle map_type(const py::handle& types_mod, mlir::Type type) +{ + using fptr_t = bool(*)(mlir::Type); + const std::pair primitive_types[] = { + {&is_int<1, mlir::IntegerType::Signed>, "boolean"}, + {&is_int<1, mlir::IntegerType::Signless>, "boolean"}, + {&is_int<1, mlir::IntegerType::Unsigned>, "boolean"}, + + {&is_int<8, mlir::IntegerType::Signed>, "int8"}, + {&is_int<8, mlir::IntegerType::Signless>, "int8"}, + {&is_int<8, mlir::IntegerType::Unsigned>, "uint8"}, + + {&is_int<16, mlir::IntegerType::Signed>, "int16"}, + {&is_int<16, mlir::IntegerType::Signless>, "int16"}, + {&is_int<16, mlir::IntegerType::Unsigned>, "uint16"}, + + {&is_int<32, mlir::IntegerType::Signed>, "int32"}, + {&is_int<32, mlir::IntegerType::Signless>, "int32"}, + {&is_int<32, mlir::IntegerType::Unsigned>, "uint32"}, + + {&is_int<64, mlir::IntegerType::Signed>, "int64"}, + {&is_int<64, mlir::IntegerType::Signless>, "int64"}, + {&is_int<64, mlir::IntegerType::Unsigned>, "uint64"}, + + {&is_float<32>, "float"}, + {&is_float<64>, "double"}, + }; + + for (auto h : primitive_types) + { + if (h.first(type)) + { + auto name = h.second; + return types_mod.attr(py::str(name.data(), name.size())); + } + } + + if (auto m = type.dyn_cast()) + { + auto elem_type = map_type(types_mod, m.getElementType()); + if (!elem_type) + { + return {}; + } + auto ndims = py::int_(m.getRank()); + auto array_type = types_mod.attr("Array"); + return array_type(elem_type, ndims, py::str("C")); + } + return {}; +} + +py::list map_types(const py::handle& types_mod, mlir::TypeRange types) +{ + py::list ret; + for (auto type : types) + { + auto elem = map_type(types_mod, type); + if (!elem) + { + return py::none(); + } + ret.append(std::move(elem)); + } + return ret; +} +} + +struct PyFuncResolver::Context +{ + py::handle resolver; + py::handle compiler; + py::handle types; +}; + +PyFuncResolver::PyFuncResolver(): + context(std::make_unique()) +{ + auto registry_mod = py::module::import("numba.mlir.func_registry"); + auto compiler_mod = py::module::import("numba.mlir.inner_compiler"); + context->resolver = registry_mod.attr("find_active_func"); + context->compiler = compiler_mod.attr("compile_func"); + context->types = py::module::import("numba.core.types"); +} + +PyFuncResolver::~PyFuncResolver() +{ + +} + +mlir::FuncOp PyFuncResolver::get_func(llvm::StringRef name, mlir::TypeRange types) +{ + assert(!name.empty()); + auto py_name = py::str(name.data(), name.size()); + auto py_func = context->resolver(py_name); + if (py_func.is_none()) + { + return {}; + } + auto py_types = map_types(context->types, types); + if (py_types.is_none()) + { + return {}; + } + auto res = static_cast(context->compiler(py_func, py_types).cast()); + auto func = (res ? mlir::cast(res) : nullptr); + return func; +} diff --git a/mlir-compiler/src/py_func_resolver.hpp b/mlir-compiler/src/py_func_resolver.hpp new file mode 100644 index 00000000000..403011a1790 --- /dev/null +++ b/mlir-compiler/src/py_func_resolver.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include + +namespace llvm +{ +class StringRef; +} + +namespace mlir +{ +class FuncOp; +class TypeRange; +} + +class PyFuncResolver +{ +public: + PyFuncResolver(); + ~PyFuncResolver(); + + PyFuncResolver(const PyFuncResolver&) = delete; + + mlir::FuncOp get_func(llvm::StringRef name, mlir::TypeRange types); + +private: + struct Context; + std::unique_ptr context; +}; diff --git a/mlir-compiler/src/py_module.cpp b/mlir-compiler/src/py_module.cpp index 3609c75eaa0..20fb44dd6b6 100644 --- a/mlir-compiler/src/py_module.cpp +++ b/mlir-compiler/src/py_module.cpp @@ -8,6 +8,6 @@ PYBIND11_MODULE(mlir_compiler, m) { m.def("create_module", &create_module, "todo"); m.def("lower_function", &lower_function, "todo"); - m.def("serialize_module", &serialize_module, "todo"); + m.def("compile_module", &compile_module, "todo"); m.def("module_str", &module_str, "todo"); } diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/src/rewrites/type_conversion.cpp index 3bf4ef839ee..4cd262275fe 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/src/rewrites/type_conversion.cpp @@ -134,6 +134,7 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( auto res = convertRegionTypes(&funcOp.getBody(), converter, true); assert(mlir::succeeded(res)); }); + if (ret_type_changed) { auto ret_types = funcOp.getType().getResults(); @@ -159,6 +160,26 @@ mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( } } }); + auto mod = funcOp->getParentOfType(); + auto uses = funcOp.getSymbolUses(mod); + if (uses) + { + for (auto use : *uses) + { + if (auto call = mlir::dyn_cast(use.getUser())) + { + rewriter.updateRootInPlace(call, [&]() + { + for (auto it : llvm::zip(call.getResults(), ret_types)) + { + auto res = std::get<0>(it); + auto type = std::get<1>(it); + res.setType(type); + } + }); + } + } + } } return mlir::success(); } diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index c102f8b4507..0aa0e74b8cc 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -468,6 +468,8 @@ def run_pass(self, state): return True import numba.mlir.settings +import numba.mlir.func_registry +import numba.core.types.functions _mlir_last_compiled_func = None _mlir_active_module = None @@ -478,13 +480,30 @@ def __init__(self): self._get_func_name = numba.mlir.func_registry.get_func_name FunctionPass.__init__(self) + def run_pass(self, state): + numba.mlir.func_registry.push_active_funcs_stack() + try: + res = self.run_pass_impl(state) + finally: + numba.mlir.func_registry.pop_active_funcs_stack() + return res + def _resolve_func_name(self, obj): + name, func = self._resolve_func_name_impl(obj) + if not (name is None or func is None): + numba.mlir.func_registry.add_active_funcs(name, func) + return name + + def _resolve_func_name_impl(self, obj): if isinstance(obj, types.Function): func = obj.typing_key - return self._get_func_name(func) + return (self._get_func_name(func), None) if isinstance(obj, types.BoundFunction): - return str(obj.typing_key) - return None + return (str(obj.typing_key), None) + if isinstance(obj, numba.core.types.functions.Dispatcher): + func = obj.dispatcher.py_func + return (func.__module__ + "." + func.__qualname__, func) + return (None, None) def _get_func_context(self, state): mangler = state.targetctx.mangler @@ -524,6 +543,10 @@ def run_pass(self, state): print(mlir_compiler.module_str(module)) return True +def get_mlir_func(): + global _mlir_last_compiled_func + return _mlir_last_compiled_func + @register_pass(mutates_CFG=True, analysis_only=False) class MlirBackend(MlirBackendBase): @@ -532,17 +555,18 @@ class MlirBackend(MlirBackendBase): def __init__(self): MlirBackendBase.__init__(self) - def run_pass(self, state): + def run_pass_impl(self, state): import mlir_compiler - global _mlir_active_module; + global _mlir_active_module old_module = _mlir_active_module + try: module = mlir_compiler.create_module() _mlir_active_module = module global _mlir_last_compiled_func ctx = self._get_func_context(state) _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) - mod_ir = mlir_compiler.serialize_module(module) + mod_ir = mlir_compiler.compile_module(ctx, module) finally: _mlir_active_module = old_module setattr(state, 'mlir_blob', mod_ir) @@ -550,6 +574,27 @@ def run_pass(self, state): state.reload_init.append(_reload_parfors) return True +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackendInner(MlirBackendBase): + + _name = "mlir_backend_inner" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass_impl(self, state): + import mlir_compiler + global _mlir_active_module + module = _mlir_active_module + assert not module is None + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + from numba.core.compiler import compile_result + state.cr = compile_result() + return True + + @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): """ diff --git a/numba/mlir/func_registry.py b/numba/mlir/func_registry.py index 926906b56d7..3570a90f038 100644 --- a/numba/mlir/func_registry.py +++ b/numba/mlir/func_registry.py @@ -1,10 +1,6 @@ _mlir_func_names = {} - # id(range) : 'range', - # id(len) : 'len', - # id(bool) : 'bool', - # id(numpy.add) : 'numpy.add' - # } +_active_funcs_stack = [] def add_func(func, name): key = id(func) @@ -13,3 +9,27 @@ def add_func(func, name): def get_func_name(func): return _mlir_func_names.get(id(func), None) + +def push_active_funcs_stack(): + global _active_funcs_stack + _active_funcs_stack.append({}) + +def pop_active_funcs_stack(): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + _active_funcs_stack.pop() + +def add_active_funcs(name, func): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + top = _active_funcs_stack[-1] + top[name] = func + +def find_active_func(name): + global _active_funcs_stack + assert(len(_active_funcs_stack) > 0) + for elem in reversed(_active_funcs_stack): + res = elem.get(name) + if not res is None: + return res + return None diff --git a/numba/mlir/inner_compiler.py b/numba/mlir/inner_compiler.py new file mode 100644 index 00000000000..de83883a705 --- /dev/null +++ b/numba/mlir/inner_compiler.py @@ -0,0 +1,34 @@ +from numba.core.typed_passes import get_mlir_func, NopythonTypeInference, AnnotateTypes, MlirBackendInner +from numba.core.compiler import CompilerBase, DefaultPassBuilder, DEFAULT_FLAGS, compile_extra +from numba.core.compiler_machinery import PassManager +from numba.core import typing, cpu +# from numba import njit + +class MlirTempCompiler(CompilerBase): # custom compiler extends from CompilerBase + + def define_pipelines(self): + dpb = DefaultPassBuilder + pm = PassManager('MlirTempCompiler') + untyped_passes = dpb.define_untyped_pipeline(self.state) + pm.passes.extend(untyped_passes.passes) + + pm.add_pass(NopythonTypeInference, "nopython frontend") + pm.add_pass(AnnotateTypes, "annotate types") + pm.add_pass(MlirBackendInner, "mlir backend") + + pm.finalize() + return [pm] + +def _compile_isolated(func, args, return_type=None, flags=DEFAULT_FLAGS, + locals={}): + from numba.core.registry import cpu_target + typingctx = typing.Context() + targetctx = cpu.CPUContext(typingctx) + # Register the contexts in case for nested @jit or @overload calls + with cpu_target.nested_context(typingctx, targetctx): + return compile_extra(typingctx, targetctx, func, args, return_type, + flags, locals, pipeline_class=MlirTempCompiler) + +def compile_func(func, args): + _compile_isolated(func, args) + return get_mlir_func() diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 79a7b5df9ee..76e2de1fa98 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -248,6 +248,32 @@ def py_func(a, b): jit_func = njit(py_func, parallel=True) assert_equal(py_func(10, 20), jit_func(10, 20)) + def test_func_call1(self): + def py_func1(b): + return b + 3 + + jit_func1 = njit(py_func1) + + def py_func2(a): + return jit_func1(a) * 4 + + jit_func2 = njit(py_func2) + + assert_equal(py_func2(10), jit_func2(10)) + + def test_func_call2(self): + def py_func1(b): + return b + 3 + + jit_func1 = njit(py_func1) + + def py_func2(a): + return jit_func1(a) * jit_func1(a + 1) + + jit_func2 = njit(py_func2) + + assert_equal(py_func2(10), jit_func2(10)) + if __name__ == '__main__': unittest.main() From b593a501fbd9a580c99f6f2a837cd95990d92fb6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 20 Jan 2021 16:07:30 +0300 Subject: [PATCH 217/259] update to MLIR master (#160) --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/pipelines/lower_to_llvm.cpp | 14 +++++++------- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 2b0a3860e74..c8a3e9f598a 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -c1d58c2b0023cd41f0da128f5190fa887d8f6c69 +de4ba7073bd7e200aca704e6a26403e07bc246a5 diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 0d7bc6dcba0..12f7415e378 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -60,7 +60,7 @@ struct LLVMTypeHelper mlir::Type i(unsigned bits) { - return mlir::LLVM::LLVMIntegerType::get(&type_converter.getContext(), bits); + return mlir::IntegerType::get(&type_converter.getContext(), bits); } mlir::Type ptr(mlir::Type type) @@ -99,8 +99,8 @@ mlir::LLVM::LLVMStructType get_array_type(mlir::TypeConverter& converter, mlir:: { assert(type); auto ctx = type.getContext(); - auto i8p = mlir::LLVM::LLVMPointerType::get(mlir::LLVM::LLVMIntegerType::get(ctx, 8)); - auto i64 = mlir::LLVM::LLVMIntegerType::get(ctx, 64); + auto i8p = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(ctx, 8)); + auto i64 = mlir::IntegerType::get(ctx, 64); auto data_type = converter.convertType(type.getElementType()); assert(data_type); auto shape_type = mlir::LLVM::LLVMArrayType::get(i64, static_cast(type.getRank())); @@ -225,7 +225,7 @@ struct MemRefConversionCache auto ptr = extract(4); auto shape = extract(5); auto strides = extract(6); - auto i64 = mllvm::LLVMIntegerType::get(builder.getContext(), 64); + auto i64 = mlir::IntegerType::get(builder.getContext(), 64); auto offset = builder.create(loc, i64, builder.getI64IntegerAttr(0)); mlir::Value res = builder.create(loc, dst_type); auto insert = [&](unsigned index, mlir::Value val) @@ -368,7 +368,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern { auto ctx = op.getContext(); auto ret_type = mlir::IntegerType::get(ctx, 32); - auto ll_ret_type = mlir::LLVM::LLVMIntegerType::get(ctx, 32); + auto ll_ret_type = mlir::IntegerType::get(ctx, 32); mlir::Value ret = rewriter.create(op.getLoc(), ll_ret_type, mlir::IntegerAttr::get(ret_type, 0)); rewriter.replaceOpWithNewOp(op, ret); }; @@ -619,7 +619,7 @@ struct LowerParallel : public mlir::OpRewritePattern auto context_ptr_type = mlir::LLVM::LLVMPointerType::get(context_type); auto loc = op.getLoc(); - auto llvm_i32_type = mlir::LLVM::LLVMIntegerType::get(op.getContext(), 32); + auto llvm_i32_type = mlir::IntegerType::get(op.getContext(), 32); auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); auto one = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(1)); auto context = rewriter.create(loc, context_ptr_type, one, 0); @@ -636,7 +636,7 @@ struct LowerParallel : public mlir::OpRewritePattern auto ptr = rewriter.create(loc, pointer_type, context, indices); rewriter.create(loc, llvm_val, ptr); } - auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::LLVM::LLVMIntegerType::get(op.getContext(), 8)); + auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(op.getContext(), 8)); auto context_abstract = rewriter.create(loc, void_ptr_type, context); auto index_type = rewriter.getIndexType(); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index c92e8d99247..903d6419f08 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -247,7 +247,7 @@ mlir::LogicalResult call_rewrite( auto elem_type = mlir::IntegerType::get(op.getContext(), 64); auto res_type = mlir::RankedTensorType::get(1, elem_type); mlir::Value zero = rewriter.create(loc, get_zero(elem_type)); - mlir::Value init = rewriter.create(loc, zero); + mlir::Value init = rewriter.create(loc, zero); mlir::AffineMap map[] = { mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), mlir::AffineMap::get(1, 0, mlir::getAffineConstantExpr(0, op.getContext())), From fda25f06ccdd2f419eb259eb093ccf34b60c71f4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 24 Jan 2021 00:06:12 +0300 Subject: [PATCH 218/259] [MLIR] Some fixes (#161) * more tests --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- numba/mlir/tests/test_basic.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 903d6419f08..9c03bdcacd4 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -432,7 +432,7 @@ struct SetitemOpLoweringSSA : public mlir::OpRewritePattern // return mlir::failure(); } - auto new_tensor = rewriter.create(loc, value); + auto new_tensor = rewriter.create(loc, value); auto new_index = index_cast(index, loc, rewriter); mlir::Value one = rewriter.create(loc, 1); auto new_value = rewriter.create(loc, new_tensor, target, new_index, one, one); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 76e2de1fa98..5322d4bd182 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -97,7 +97,7 @@ def py_func2(): assert_equal(py_func1(), jit_func1()) assert_equal(py_func2(), jit_func2()) - def test_jump(self): + def test_if1(self): def py_func(a, b): c = 3 if a > 5: @@ -109,6 +109,17 @@ def py_func(a, b): for a, b in itertools.product(_test_values, _test_values): assert_equal(py_func(a, b), jit_func(a, b)) + def test_if2(self): + def py_func(a, b): + if a > b: + return a + b + else: + return a - b + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + @unittest.skip def test_tuple(self): def py_func(a, b, c): From 2682ad900bf2d5e9efadd1e78acdb7606bbe2c97 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 27 Jan 2021 21:09:21 +0300 Subject: [PATCH 219/259] fix numpy.add (#162) --- mlir-compiler/src/pipelines/plier_to_linalg.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 9c03bdcacd4..53e63bde6de 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -223,9 +223,12 @@ mlir::LogicalResult call_rewrite( }; mlir::StringRef iterators[] = { "parallel" }; + auto dim = rewriter.create(loc, args[0], 0).getResult(); + auto init_tensor = rewriter.create(loc, dim, elem_type).getResult(); + auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) { - assert(args.size() == 2); + assert(args.size() == 3); mlir::Value res = builder.create(loc, args[0], args[1]); builder.create(loc, res); }; @@ -233,7 +236,7 @@ mlir::LogicalResult call_rewrite( loc, mlir::TypeRange(res_type), mlir::ValueRange(inputs), - mlir::ValueRange(), // outputs + mlir::ValueRange(init_tensor), llvm::makeArrayRef(map), llvm::makeArrayRef(iterators), body).getResult(0); From 2fcb7e5419e01bfaeae6234f818e26f6cfd5193c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 29 Jan 2021 23:26:45 +0300 Subject: [PATCH 220/259] [MLIR] Range negative step (#163) --- .../src/pipelines/plier_to_linalg.cpp | 13 +- mlir-compiler/src/pipelines/plier_to_std.cpp | 2 + mlir-compiler/src/transforms/loop_utils.cpp | 144 ++++++++++++------ numba/mlir/tests/test_basic.py | 30 ++++ 4 files changed, 144 insertions(+), 45 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 53e63bde6de..c57bdad2f88 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -536,6 +536,8 @@ struct SetitemOpLowering : public mlir::OpRewritePattern void PlierToLinalgPass::runOnOperation() { + auto context = &getContext(); + mlir::TypeConverter type_converter; // Convert unknown types to itself type_converter.addConversion([](mlir::Type type) { return type; }); @@ -554,11 +556,11 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< FuncOpSignatureConversion, CastOpLowering - >(type_converter, &getContext()); + >(type_converter, context); patterns.insert< CallOpLowering - >(type_converter, &getContext(), call_rewrite); + >(type_converter, context, call_rewrite); patterns.insert< GetitemOpLowering, @@ -566,6 +568,13 @@ void PlierToLinalgPass::runOnOperation() SetitemOpLowering >(&getContext()); + // range/prange lowering need dead branch pruning to properly + // handle negative steps + for (auto *op : context->getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, context); + } + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 7d2cbb5b878..13c4e50b313 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1286,6 +1286,8 @@ void PlierToStdPass::runOnOperation() mlir::populateStdExpandOpsPatterns(context, patterns); + // range/prange lowering need dead branch pruning to properly + // handle negative steps for (auto *op : context->getRegisteredOperations()) { op->getCanonicalizationPatterns(patterns, context); diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp index ebae0c332e3..2952826b89e 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -37,6 +37,7 @@ mlir::Value get_last_iter_value( auto inc = builder.create(loc, count, step); return builder.create(loc, lower_bound, inc); } + } mlir::LogicalResult lower_while_to_for( @@ -54,6 +55,24 @@ mlir::LogicalResult lower_while_to_for( } } + auto loc = getiter.getLoc(); + mlir::Value zero_val; + auto get_zero_index = [&]() + { + if (!zero_val) + { + mlir::OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(getiter); + zero_val = builder.create(loc, 0); + } + return zero_val; + }; + + auto get_neg = [&](mlir::Value value) + { + return builder.create(loc, get_zero_index(), value); + }; + bool changed = false; for (auto while_op : to_process) { @@ -81,58 +100,102 @@ mlir::LogicalResult lower_while_to_for( auto& after_block = while_op.after().front(); - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) + auto index_cast = [&](mlir::Value val)->mlir::Value { - mlir::BlockAndValueMapping mapper; - assert(before_block.getNumArguments() == iterargs.size()); - assert(after_block.getNumArguments() == before_term.args().size()); - mapper.map(before_block.getArguments(), iterargs); - for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) + if (!val.getType().isa()) { - auto block_arg = std::get<0>(it); - auto term_arg = std::get<1>(it); - if (pairfirst && term_arg == pairfirst) // iter arg - { - auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); - mapper.map(block_arg, iter_val); - } - else - { - mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); - } + return builder.create(loc, val, mlir::IndexType::get(val.getContext())); } + return val; + }; + + auto bounds = get_bounds(builder, loc); + auto orig_lower_bound = index_cast(std::get<0>(bounds)); + auto orig_upper_bound = index_cast(std::get<1>(bounds)); + auto orig_step = index_cast(std::get<2>(bounds)); + + // scf::ForOp/ParallelOp doesn't support negative step, so generate + // IfOp and 2 version for different step signs + // branches for const steps will be pruned later + auto gen_for = [&](bool positive) + { + auto get_loop_body_builder = [&](bool positive) + { + return [&, positive](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) + { + if (!positive) + { + iv = get_neg(iv); + } + mlir::BlockAndValueMapping mapper; + assert(before_block.getNumArguments() == iterargs.size()); + assert(after_block.getNumArguments() == before_term.args().size()); + mapper.map(before_block.getArguments(), iterargs); + for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) + { + auto block_arg = std::get<0>(it); + auto term_arg = std::get<1>(it); + if (pairfirst && term_arg == pairfirst) // iter arg + { + auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); + mapper.map(block_arg, iter_val); + } + else + { + mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); + } + } - for (auto& op : after_block) // with terminator + for (auto& op : after_block) // with terminator + { + builder.clone(op, mapper); + } + }; + }; + + auto lower_bound = orig_lower_bound; + auto upper_bound = orig_upper_bound; + auto step = orig_step; + + if (!positive) { - builder.clone(op, mapper); + lower_bound = get_neg(lower_bound); + upper_bound = get_neg(upper_bound); + step = get_neg(step); } + + return builder.create( + loc, + lower_bound, + upper_bound, + step, + while_op.getOperands(), // iterArgs + get_loop_body_builder(positive) + ); }; - auto loc = getiter.getLoc(); - auto index_cast = [&](mlir::Value val)->mlir::Value + auto get_if_body_builder = [&](bool positive) { - if (!val.getType().isa()) + return [&, positive](mlir::OpBuilder& builder, mlir::Location loc) { - return builder.create(loc, val, mlir::IndexType::get(val.getContext())); - } - return val; + auto loop_op = gen_for(positive); + if (results) + { + results(loop_op); + } + builder.create(loc, loop_op.getResults()); + }; }; - auto bounds = get_bounds(builder, loc); - auto lower_bound = index_cast(std::get<0>(bounds)); - auto upper_bound = index_cast(std::get<1>(bounds)); - auto step = index_cast(std::get<2>(bounds)); - builder.setInsertionPoint(while_op); - auto loop_op = builder.create( + auto step_sign = builder.create(loc, mlir::CmpIPredicate::sge, orig_step, get_zero_index()); + auto loop_op = builder.create( loc, - lower_bound, - upper_bound, - step, - while_op.getOperands(), // iterArgs - body - ); + while_op.getOperands().getTypes(), + step_sign, + get_if_body_builder(true), + get_if_body_builder(false)); assert(while_op.getNumResults() >= loop_op.getNumResults()); builder.updateRootInPlace(while_op, [&]() @@ -155,7 +218,7 @@ mlir::LogicalResult lower_while_to_for( } if (pairfirst && operand == pairfirst && !old_res.getUsers().empty()) { - auto val = get_last_iter_value(builder, loc, lower_bound, upper_bound, step); + auto val = get_last_iter_value(builder, loc, orig_lower_bound, orig_upper_bound, orig_step); auto new_res = builder.create(loc, old_res.getType(), val); old_res.replaceAllUsesWith(new_res); } @@ -166,11 +229,6 @@ mlir::LogicalResult lower_while_to_for( assert(while_op.getOperation()->getUsers().empty()); builder.eraseOp(while_op); changed = true; - - if (results) - { - results(loop_op); - } } if (getiter.getOperation()->getUsers().empty()) diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 5322d4bd182..1f0f6bb509e 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -160,6 +160,36 @@ def py_func(a, b, c): jit_func = njit(py_func) assert_equal(py_func(10, 20, 2), jit_func(10, 20, 2)) + def test_range_negative_step(self): + def py_func(a, b, c): + res = 0 + for i in range(a, b, c): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(5, -8, -2), jit_func(5, -8, -2)) + + def test_range_const_step1(self): + def py_func(a, b): + res = 0 + for i in range(a, b, -2): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(5, -8), jit_func(5, -8)) + + def test_range_const_step2(self): + def py_func(a, b): + res = 0 + for i in range(a, b, 2): + res = res + i + return res + + jit_func = njit(py_func) + assert_equal(py_func(-5, 8), jit_func(-5, 8)) + def test_range_use_index_after(self): def py_func(n): res = 0 From 3939fcde71826b90a822421de6a3a245589c258e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 30 Jan 2021 19:53:49 +0300 Subject: [PATCH 221/259] [MLIR] High level linalg api (#164) --- mlir-compiler/CMakeLists.txt | 6 + mlir-compiler/include/plier/dialect.hpp | 1 + mlir-compiler/src/dialect.cpp | 7 +- .../src/pipelines/plier_to_linalg.cpp | 135 ++--- mlir-compiler/src/pipelines/plier_to_std.cpp | 7 +- mlir-compiler/src/py_func_resolver.cpp | 100 +--- mlir-compiler/src/py_func_resolver.hpp | 2 - mlir-compiler/src/py_linalg_resolver.cpp | 480 ++++++++++++++++++ mlir-compiler/src/py_linalg_resolver.hpp | 36 ++ mlir-compiler/src/py_map_types.cpp | 108 ++++ mlir-compiler/src/py_map_types.hpp | 17 + mlir-compiler/src/rewrites/force_inline.cpp | 43 ++ mlir-compiler/src/rewrites/force_inline.hpp | 16 + numba/mlir/__init__.py | 8 +- numba/mlir/linalg_builder.py | 62 +++ numba/mlir/numpy/funcs.py | 36 ++ 16 files changed, 873 insertions(+), 191 deletions(-) create mode 100644 mlir-compiler/src/py_linalg_resolver.cpp create mode 100644 mlir-compiler/src/py_linalg_resolver.hpp create mode 100644 mlir-compiler/src/py_map_types.cpp create mode 100644 mlir-compiler/src/py_map_types.hpp create mode 100644 mlir-compiler/src/rewrites/force_inline.cpp create mode 100644 mlir-compiler/src/rewrites/force_inline.hpp create mode 100644 numba/mlir/linalg_builder.py create mode 100644 numba/mlir/numpy/funcs.py diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 5ceaf24b03f..a01eea32174 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -35,6 +35,7 @@ set(SOURCES_LIST src/rewrites/canonicalize_reductions.cpp src/rewrites/cast_lowering.cpp src/rewrites/cse.cpp + src/rewrites/force_inline.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp src/transforms/func_utils.cpp @@ -46,6 +47,8 @@ set(SOURCES_LIST src/mangle.cpp src/pipeline_registry.cpp src/py_func_resolver.cpp + src/py_linalg_resolver.cpp + src/py_map_types.cpp src/py_module.cpp src/utils.cpp ) @@ -61,6 +64,7 @@ set(HEADERS_LIST src/rewrites/canonicalize_reductions.hpp src/rewrites/cast_lowering.hpp src/rewrites/cse.hpp + src/rewrites/force_inline.hpp src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp src/transforms/func_utils.hpp @@ -71,6 +75,8 @@ set(HEADERS_LIST src/mangle.hpp src/pipeline_registry.hpp src/py_func_resolver.hpp + src/py_linalg_resolver.hpp + src/py_map_types.hpp src/py_module.hpp src/utils.hpp ) diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/include/plier/dialect.hpp index 4f46d010bb9..a4a8696892f 100644 --- a/mlir-compiler/include/plier/dialect.hpp +++ b/mlir-compiler/include/plier/dialect.hpp @@ -32,6 +32,7 @@ llvm::StringRef getFastmathName(); llvm::StringRef getJumpMarkersName(); llvm::StringRef getParallelName(); llvm::StringRef getMaxConcurrencyName(); +llvm::StringRef getForceInlineName(); } namespace detail diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index cb620782964..549c25a6b87 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -30,7 +30,12 @@ llvm::StringRef attributes::getParallelName() llvm::StringRef attributes::getMaxConcurrencyName() { - return "#plier.max_concurrency"; + return "#plier.max_concurrency"; +} + +llvm::StringRef attributes::getForceInlineName() +{ + return "#plier.force_inline"; } diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index c57bdad2f88..7fa3fcf87f6 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -28,10 +28,12 @@ #include "rewrites/cse.hpp" #include "rewrites/promote_to_parallel.hpp" #include "rewrites/type_conversion.hpp" +#include "rewrites/force_inline.hpp" #include "transforms/loop_utils.hpp" #include "base_pipeline.hpp" #include "pipeline_registry.hpp" +#include "py_linalg_resolver.hpp" #include @@ -194,99 +196,55 @@ mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef return mlir::success(); } -mlir::LogicalResult call_rewrite( - plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) +struct CallLowerer { - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); - std::pair handlers[] = { - {"numba.prange", lower_prange}, - }; - for (auto& handler : handlers) + mlir::LogicalResult operator()( + plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, + mlir::PatternRewriter& rewriter) { - if (handler.first == name) + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + std::pair handlers[] = { + {"numba.prange", lower_prange}, + }; + for (auto& handler : handlers) { - return handler.second(op, args, rewriter); + if (handler.first == name) + { + return handler.second(op, args, rewriter); + } } - } - if (name == "numpy.add" && check_numpy_args(args, 2)) - { - auto loc = op.getLoc(); - mlir::Value inputs[] = { args[0], args[1] }; - auto elem_type = get_elem_type(args[0].getType()); - mlir::Type res_type = mlir::RankedTensorType::get(-1, elem_type); - mlir::AffineMap map[] = { - mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - }; - mlir::StringRef iterators[] = { "parallel" }; - - auto dim = rewriter.create(loc, args[0], 0).getResult(); - auto init_tensor = rewriter.create(loc, dim, elem_type).getResult(); - - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args)) { - assert(args.size() == 3); - mlir::Value res = builder.create(loc, args[0], args[1]); - builder.create(loc, res); - }; - auto res = rewriter.create( - loc, - mlir::TypeRange(res_type), - mlir::ValueRange(inputs), - mlir::ValueRange(init_tensor), - llvm::makeArrayRef(map), - llvm::makeArrayRef(iterators), - body).getResult(0); - rewriter.replaceOp(op, res); - return mlir::success(); - } - if (name == "array.sum" && check_numpy_args(args, 1)) - { - auto loc = op.getLoc(); - mlir::Value inputs[] = { args[0] }; - auto elem_type = mlir::IntegerType::get(op.getContext(), 64); - auto res_type = mlir::RankedTensorType::get(1, elem_type); - mlir::Value zero = rewriter.create(loc, get_zero(elem_type)); - mlir::Value init = rewriter.create(loc, zero); - mlir::AffineMap map[] = { - mlir::AffineMap::getMultiDimIdentityMap(1, op.getContext()), - mlir::AffineMap::get(1, 0, mlir::getAffineConstantExpr(0, op.getContext())), - }; - mlir::StringRef iterators[] = { "reduction" }; - auto body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + assert(result->size() == op->getNumResults()); + rerun_std_pipeline(op); + if (result->empty()) + { + rewriter.eraseOp(op); + } + else + { + rewriter.replaceOp(op, *result); + } + return mlir::success(); + } + + if (name == "len" && check_numpy_args(args, 1)) { - assert(args.size() == 2); - auto val = builder.create(loc, args[0], elem_type); - mlir::Value res = builder.create(loc, val, args[1]); - builder.create(loc, res); - }; - auto val = rewriter.create( - loc, - mlir::TypeRange(res_type), - mlir::ValueRange(inputs), - mlir::ValueRange(init), // outputs - llvm::makeArrayRef(map), - llvm::makeArrayRef(iterators), - body).getResult(0); - mlir::Value index = rewriter.create(loc, 0); - mlir::Value res = rewriter.create(loc, val, index); - rewriter.replaceOp(op, res); - return mlir::success(); - } - if (name == "len" && check_numpy_args(args, 1)) - { - auto loc = op.getLoc(); - mlir::Value dim = rewriter.create(loc, args[0], 0); - mlir::Value res = rewriter.create(loc, op.getType(), dim); - rerun_std_pipeline(op); - rewriter.replaceOp(op, res); - return mlir::success(); + auto loc = op.getLoc(); + mlir::Value dim = rewriter.create(loc, args[0], 0); + mlir::Value res = rewriter.create(loc, op.getType(), dim); + rerun_std_pipeline(op); + rewriter.replaceOp(op, res); + return mlir::success(); + } + return mlir::failure(); } - return mlir::failure(); -} + +private: + + PyLinalgResolver linalg_resolver; +}; template struct GetitemOpLowering : public mlir::OpRewritePattern @@ -558,14 +516,17 @@ void PlierToLinalgPass::runOnOperation() CastOpLowering >(type_converter, context); + CallLowerer callLowerer; + patterns.insert< CallOpLowering - >(type_converter, context, call_rewrite); + >(type_converter, context, callLowerer); patterns.insert< GetitemOpLowering, GetitemOpLowering, - SetitemOpLowering + SetitemOpLowering, + ForceInline >(&getContext()); // range/prange lowering need dead branch pruning to properly diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 13c4e50b313..3b0c4773f8c 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -499,7 +499,12 @@ template void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) { assert(nullptr != op); - rewriter.replaceOpWithNewOp(op, new_type, operands); + llvm::SmallVector new_operands(operands.size()); + for (auto it : llvm::enumerate(operands)) + { + new_operands[it.index()] = do_cast(new_type, it.value(), rewriter); + } + rewriter.replaceOpWithNewOp(op, new_type, new_operands); } void replace_itruediv_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) diff --git a/mlir-compiler/src/py_func_resolver.cpp b/mlir-compiler/src/py_func_resolver.cpp index 851498b5e07..f927c028203 100644 --- a/mlir-compiler/src/py_func_resolver.cpp +++ b/mlir-compiler/src/py_func_resolver.cpp @@ -4,103 +4,9 @@ #include -namespace py = pybind11; - -namespace -{ - -template -bool is_int(mlir::Type type) -{ - if (auto t = type.dyn_cast()) - { - if (t.getWidth() == Width && t.getSignedness() == Signed) - { - return true; - } - } - return false; -} - -template -bool is_float(mlir::Type type) -{ - if (auto f = type.dyn_cast()) - { - if (f.getWidth() == Width) - { - return true; - } - } - return false; -} - -py::handle map_type(const py::handle& types_mod, mlir::Type type) -{ - using fptr_t = bool(*)(mlir::Type); - const std::pair primitive_types[] = { - {&is_int<1, mlir::IntegerType::Signed>, "boolean"}, - {&is_int<1, mlir::IntegerType::Signless>, "boolean"}, - {&is_int<1, mlir::IntegerType::Unsigned>, "boolean"}, - - {&is_int<8, mlir::IntegerType::Signed>, "int8"}, - {&is_int<8, mlir::IntegerType::Signless>, "int8"}, - {&is_int<8, mlir::IntegerType::Unsigned>, "uint8"}, - - {&is_int<16, mlir::IntegerType::Signed>, "int16"}, - {&is_int<16, mlir::IntegerType::Signless>, "int16"}, - {&is_int<16, mlir::IntegerType::Unsigned>, "uint16"}, +#include "py_map_types.hpp" - {&is_int<32, mlir::IntegerType::Signed>, "int32"}, - {&is_int<32, mlir::IntegerType::Signless>, "int32"}, - {&is_int<32, mlir::IntegerType::Unsigned>, "uint32"}, - - {&is_int<64, mlir::IntegerType::Signed>, "int64"}, - {&is_int<64, mlir::IntegerType::Signless>, "int64"}, - {&is_int<64, mlir::IntegerType::Unsigned>, "uint64"}, - - {&is_float<32>, "float"}, - {&is_float<64>, "double"}, - }; - - for (auto h : primitive_types) - { - if (h.first(type)) - { - auto name = h.second; - return types_mod.attr(py::str(name.data(), name.size())); - } - } - - if (auto m = type.dyn_cast()) - { - auto elem_type = map_type(types_mod, m.getElementType()); - if (!elem_type) - { - return {}; - } - auto ndims = py::int_(m.getRank()); - auto array_type = types_mod.attr("Array"); - return array_type(elem_type, ndims, py::str("C")); - } - return {}; -} - -py::list map_types(const py::handle& types_mod, mlir::TypeRange types) -{ - py::list ret; - for (auto type : types) - { - auto elem = map_type(types_mod, type); - if (!elem) - { - return py::none(); - } - ret.append(std::move(elem)); - } - return ret; -} -} +namespace py = pybind11; struct PyFuncResolver::Context { @@ -133,7 +39,7 @@ mlir::FuncOp PyFuncResolver::get_func(llvm::StringRef name, mlir::TypeRange type { return {}; } - auto py_types = map_types(context->types, types); + auto py_types = map_types_to_numba(context->types, types); if (py_types.is_none()) { return {}; diff --git a/mlir-compiler/src/py_func_resolver.hpp b/mlir-compiler/src/py_func_resolver.hpp index 403011a1790..8c94dabb96e 100644 --- a/mlir-compiler/src/py_func_resolver.hpp +++ b/mlir-compiler/src/py_func_resolver.hpp @@ -19,8 +19,6 @@ class PyFuncResolver PyFuncResolver(); ~PyFuncResolver(); - PyFuncResolver(const PyFuncResolver&) = delete; - mlir::FuncOp get_func(llvm::StringRef name, mlir::TypeRange types); private: diff --git a/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/src/py_linalg_resolver.cpp new file mode 100644 index 00000000000..3071149f56a --- /dev/null +++ b/mlir-compiler/src/py_linalg_resolver.cpp @@ -0,0 +1,480 @@ +#include "py_linalg_resolver.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "plier/dialect.hpp" +#include "py_map_types.hpp" +#include "utils.hpp" + +namespace py = pybind11; + +struct PyBuilderContext +{ + mlir::Location loc; + mlir::OpBuilder& builder; + PyLinalgResolver::Context& context; +}; + +namespace +{ +bool is_compatible_types(mlir::TypeRange types) +{ + return !types.empty() && llvm::all_of(types, [](mlir::Type t){ return t.isa(); }); +} + +py::handle get_dim(int64_t val) +{ + if (val == -1) + { + return py::none(); + } + return py::int_(val); +} + +size_t py_func_arg_count(py::handle signature, py::handle func) +{ + return py::len(signature(func).attr("parameters")); +} + +template +py::capsule wrap_mlir(T val) +{ + return py::capsule(val.getAsOpaquePointer()); +} + +template +T unwrap_mlir(py::capsule obj) +{ + return T::getFromOpaquePointer(static_cast(obj)); +} + +auto unwrap_ssa_val(py::handle obj) +{ + return unwrap_mlir(obj.attr("_ssa_val").cast()); +} + +auto unwrap_shape(py::list shape) +{ + llvm::SmallVector ret; + ret.reserve(shape.size()); + for (auto elem : shape) + { + ret.push_back(unwrap_ssa_val(elem)); + } + return ret; +} + +size_t container_size(py::handle obj) +{ + if (py::isinstance(obj)) + { + return obj.cast().size(); + } + if (py::isinstance(obj)) + { + return obj.cast().size(); + } + return 1; +} + +template +void container_iterate(py::handle obj, F&& func) +{ + auto impl = [&](auto cont) + { + for (auto it : llvm::enumerate(cont)) + { + func(it.index(), it.value()); + } + }; + if (py::isinstance(obj)) + { + impl(obj.cast()); + } + else if (py::isinstance(obj)) + { + impl(obj.cast()); + } + else + { + func(std::size_t(0), obj); + } +} +} + +struct PyLinalgResolver::Context +{ + py::handle var; + py::handle val; + py::handle builder; + py::handle signature; + py::handle types_mod; + py::handle compile_func; + py::handle lookup_func; + + py::object create_var(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value value) + { + if (value.getType().isa()) + { + auto make_dim_val = [&](auto dim, auto ssa_val) + { + return val(get_dim(dim), wrap_mlir(ssa_val)); + }; + auto mlir_type = value.getType().cast(); + auto shape = mlir_type.getShape(); + auto elem_type = mlir_type.getElementType(); + py::list py_shape(shape.size()); + for (auto it2 : llvm::enumerate(shape)) + { + mlir::Value mlir_dim = builder.create(loc, value, it2.index()); + py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); + } + return var(wrap_mlir(value), py_shape, wrap_mlir(elem_type)); + } + return var(wrap_mlir(value), py::none(), wrap_mlir(value.getType())); + } + + mlir::FuncOp compile_body(py::handle body, py::list arg_types) + { + auto func = compile_func(body, arg_types).cast(); + auto mlir_func = mlir::cast(static_cast(func)); + mlir_func.setPrivate(); + mlir_func->setAttr(plier::attributes::getForceInlineName(), mlir::UnitAttr::get(mlir_func->getContext())); + return mlir_func; + } + + py::object wrap_result(mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange values) + { + if (values.empty()) + { + return py::none(); + } + if (values.size() == 1) + { + return create_var(loc, builder, values.front()); + } + py::tuple ret(values.size()); + for (auto it : llvm::enumerate(values)) + { + ret[it.index()] = create_var(loc, builder, it.value()); + } + return std::move(ret); + } +}; + +namespace +{ + +PyBuilderContext& get_py_context(py::capsule& ctx) +{ + return *static_cast(ctx); +} + +mlir::Value get_var_value(py::handle var) +{ + return unwrap_mlir(var.attr("_ssa_val").cast()); +} + +auto get_types(mlir::ValueRange values) +{ + return values.getTypes(); +} + +auto get_agrs_from_tuple(py::handle args) +{ + llvm::SmallVector ret; + if (py::isinstance(args)) + { + auto tuple = args.cast(); + ret.resize(tuple.size()); + for (auto it : llvm::enumerate(tuple)) + { + ret[it.index()] = get_var_value(it.value()); + } + } + else + { + ret.emplace_back(get_var_value(args)); + } + return ret; +} + +auto get_iterators(py::list iterators, mlir::MLIRContext& ctx) +{ + llvm::SmallVector ret(iterators.size()); + for (auto it : llvm::enumerate(iterators)) + { + ret[it.index()] = mlir::StringAttr::get(it.value().cast(), &ctx).getValue(); + } + return ret; +} + +auto get_affine_maps(py::list maps, mlir::MLIRContext& ctx) +{ + llvm::SmallVector ret(maps.size()); + for (auto it : llvm::enumerate(maps)) + { + auto str = (llvm::Twine("affine_map<") + it.value().cast() + ">").str(); + auto attr = mlir::parseAttribute(str, &ctx); + ret[it.index()] = attr.cast().getValue(); + } + return ret; +} + +auto get_generic_op_body_types(mlir::ValueRange inputs, mlir::ValueRange outputs) +{ + llvm::SmallVector ret; + ret.reserve(inputs.size() + outputs.size()); + for (auto r : {inputs, outputs}) + { + for (auto type : r.getTypes()) + { + auto elem_type = type.cast().getElementType(); + ret.emplace_back(elem_type); + } + } + return ret; +} + +auto generic_op_body_result_types(mlir::ValueRange outputs) +{ + llvm::SmallVector ret; + ret.reserve(outputs.size()); + for (auto type : outputs.getTypes()) + { + auto elem_type = type.cast().getElementType(); + ret.emplace_back(elem_type); + } + return ret; +} + +py::object broadcast_impl(py::capsule /*context*/, py::tuple args) +{ + if (1 == args.size()) + { + return args[0]; + } + else + { + return std::move(args); + } +} + +py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dtype) +{ + auto& ctx = get_py_context(context); + auto elem_type = unwrap_mlir(dtype); + auto init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); + return ctx.context.create_var(ctx.loc, ctx.builder, init); +} + +py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto& mlir_context = *builder.getContext(); + + auto inputs_args = get_agrs_from_tuple(inputs); + auto output_args = get_agrs_from_tuple(outputs); + auto ret_types = get_types(output_args); + auto affine_maps = get_affine_maps(maps, mlir_context); + auto mlir_iterators = get_iterators(iterators, mlir_context); + + auto func_types = map_types_to_numba(ctx.context.types_mod, get_generic_op_body_types(inputs_args, output_args)); + auto body_func = ctx.context.compile_body(body, func_types); + auto body_builder = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + auto cast_values = [&](mlir::ValueRange vals, mlir::TypeRange types) + { + assert(vals.size() == types.size()); + llvm::SmallVector ret(vals.size()); + auto do_cast = [&](mlir::Value val, mlir::Type type) + { + if (val.getType() == type) + { + return val; + } + return builder.create(loc, type, val).getResult(); + }; + for (auto it : llvm::enumerate(vals)) + { + auto index = static_cast(it.index()); + ret[index] = do_cast(it.value(), types[index]); + } + return ret; + }; + auto func_type = body_func.getType(); + auto new_args = cast_values(args, func_type.getInputs()); + auto call = builder.create(loc, body_func, new_args); + auto new_results = cast_values(call.getResults(), generic_op_body_result_types(output_args)); + builder.create(loc, new_results); + }; + + auto generic_op = builder.create( + ctx.loc, + ret_types, + inputs_args, + output_args, + affine_maps, + mlir_iterators, + body_builder); + return ctx.context.wrap_result(ctx.loc, builder, generic_op.getResults()); +} + +py::object from_elements_impl(py::capsule context, py::handle values, py::capsule dtype) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto type = unwrap_mlir(dtype); + + llvm::SmallVector vals(container_size(values)); + container_iterate(values, [&](auto index, py::handle obj) + { + if (py::isinstance(obj, ctx.context.var)) + { + vals[index] = unwrap_ssa_val(obj); + } + else if (py::isinstance(obj) || + py::isinstance(obj)) + { + auto attr = [&]()->mlir::Attribute + { + if (type.isa()) + { + return mlir::IntegerAttr::get(type, obj.cast()); + } + if (type.isa()) + { + return mlir::FloatAttr::get(type, obj.cast()); + } + report_error("Invalid dtype"); + }(); + vals[index] = builder.create(loc, attr); + } + else + { + report_error("Invalid element type"); + } + }); + auto res = builder.create(loc, vals); + return ctx.context.create_var(ctx.loc, ctx.builder, res); +} + +py::object extract_impl(py::capsule context, py::handle value, py::handle indices) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + + llvm::SmallVector ind(container_size(indices)); + container_iterate(indices, [&](auto index, py::handle obj) + { + if (py::isinstance(obj, ctx.context.var)) + { + ind[index] = unwrap_ssa_val(obj); + } + else if (py::isinstance(obj)) + { + ind[index] = builder.create(loc, obj.cast()); + } + else + { + report_error("Invalid element type"); + } + }); + auto res = builder.create(loc, get_var_value(value), ind); + return ctx.context.create_var(ctx.loc, ctx.builder, res); +} + +void setup_py_builder(py::handle builder) +{ + py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); + py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); + py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); + py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); + py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); +} + +PyLinalgResolver::Values unpack_results(py::handle object) +{ + PyLinalgResolver::Values ret; + if (object.is_none()) + { + return ret; + } + if (py::isinstance(object)) + { + auto tuple = object.cast(); + ret.resize(tuple.size()); + for (auto it : llvm::enumerate(tuple)) + { + ret[it.index()] = unwrap_ssa_val(it.value()); + } + return ret; + } + ret.emplace_back(unwrap_ssa_val(object)); + return ret; +} +} + +PyLinalgResolver::PyLinalgResolver(): + context(std::make_unique()) +{ + auto builder_mod = py::module::import("numba.mlir.linalg_builder"); + context->var = builder_mod.attr("Var"); + context->val = builder_mod.attr("Val"); + context->builder = builder_mod.attr("Builder"); + context->signature = py::module::import("inspect").attr("signature"); + context->types_mod = py::module::import("numba.core.types"); + context->compile_func = builder_mod.attr("compile_func"); + context->lookup_func = builder_mod.attr("lookup_func"); +} + +PyLinalgResolver::~PyLinalgResolver() +{ + +} + +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args) +{ + assert(!name.empty()); + if (!is_compatible_types(args.getTypes())) + { + return {}; + } + + auto builder_func = context->lookup_func(py::str(name.data(), name.size())); + if (builder_func.is_none() || py_func_arg_count(context->signature, builder_func) != (args.size() + 1)) + { + return {}; + } + + PyBuilderContext py_builder_context{loc, builder, *context}; + auto py_builder = context->builder(py::capsule(&py_builder_context)); + setup_py_builder(py_builder); + + assert(!args.empty()); + auto module = args.front().getParentRegion()->getParentOfType(); + assert(module); + + py::list py_args(args.size()); + for (auto it : llvm::enumerate(args)) + { + auto index = static_cast(it.index()); + auto mlir_arg = it.value(); + py_args[index] = context->create_var(loc, builder, mlir_arg); + } + + auto result = builder_func(py_builder, *py_args); + return unpack_results(result); +} diff --git a/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/src/py_linalg_resolver.hpp new file mode 100644 index 00000000000..80ca93da3d0 --- /dev/null +++ b/mlir-compiler/src/py_linalg_resolver.hpp @@ -0,0 +1,36 @@ +#pragma once + +#include + +#include +#include + +namespace llvm +{ +class StringRef; +} + +namespace mlir +{ +class Value; +class FuncOp; +class ValueRange; +class OpBuilder; +class Location; +} + +class PyLinalgResolver +{ +public: + PyLinalgResolver(); + ~PyLinalgResolver(); + + using Values = llvm::SmallVector; + + llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args); + +private: + friend struct PyBuilderContext; + struct Context; + std::unique_ptr context; +}; diff --git a/mlir-compiler/src/py_map_types.cpp b/mlir-compiler/src/py_map_types.cpp new file mode 100644 index 00000000000..2e85ee500aa --- /dev/null +++ b/mlir-compiler/src/py_map_types.cpp @@ -0,0 +1,108 @@ +#include "py_map_types.hpp" + +#include + +#include +#include + +namespace py = pybind11; + +namespace +{ +template +bool is_int(mlir::Type type) +{ + if (auto t = type.dyn_cast()) + { + if (t.getWidth() == Width && t.getSignedness() == Signed) + { + return true; + } + } + return false; +} + +template +bool is_float(mlir::Type type) +{ + if (auto f = type.dyn_cast()) + { + if (f.getWidth() == Width) + { + return true; + } + } + return false; +} + +py::object map_type(const py::handle& types_mod, mlir::Type type) +{ + using fptr_t = bool(*)(mlir::Type); + const std::pair primitive_types[] = { + {&is_int<1, mlir::IntegerType::Signed>, "boolean"}, + {&is_int<1, mlir::IntegerType::Signless>, "boolean"}, + {&is_int<1, mlir::IntegerType::Unsigned>, "boolean"}, + + {&is_int<8, mlir::IntegerType::Signed>, "int8"}, + {&is_int<8, mlir::IntegerType::Signless>, "int8"}, + {&is_int<8, mlir::IntegerType::Unsigned>, "uint8"}, + + {&is_int<16, mlir::IntegerType::Signed>, "int16"}, + {&is_int<16, mlir::IntegerType::Signless>, "int16"}, + {&is_int<16, mlir::IntegerType::Unsigned>, "uint16"}, + + {&is_int<32, mlir::IntegerType::Signed>, "int32"}, + {&is_int<32, mlir::IntegerType::Signless>, "int32"}, + {&is_int<32, mlir::IntegerType::Unsigned>, "uint32"}, + + {&is_int<64, mlir::IntegerType::Signed>, "int64"}, + {&is_int<64, mlir::IntegerType::Signless>, "int64"}, + {&is_int<64, mlir::IntegerType::Unsigned>, "uint64"}, + + {&is_float<32>, "float"}, + {&is_float<64>, "double"}, + }; + + for (auto h : primitive_types) + { + if (h.first(type)) + { + auto name = h.second; + return types_mod.attr(py::str(name.data(), name.size())); + } + } + + if (auto m = type.dyn_cast()) + { + auto elem_type = map_type(types_mod, m.getElementType()); + if (!elem_type) + { + return {}; + } + auto ndims = py::int_(m.getRank()); + auto array_type = types_mod.attr("Array"); + return array_type(elem_type, ndims, py::str("C")); + } + return {}; +} +} +pybind11::object map_type_to_numba(pybind11::handle types_mod, mlir::Type type) +{ + auto elem = map_type(types_mod, type); + if (!elem) + { + return py::none(); + } + return elem; +} + +pybind11::list map_types_to_numba(pybind11::handle types_mod, mlir::TypeRange types) +{ + py::list ret(types.size()); + for (auto it : llvm::enumerate(types)) + { + ret[it.index()] = map_type_to_numba(types_mod, it.value()); + } + return ret; +} + diff --git a/mlir-compiler/src/py_map_types.hpp b/mlir-compiler/src/py_map_types.hpp new file mode 100644 index 00000000000..90cd42d4fb6 --- /dev/null +++ b/mlir-compiler/src/py_map_types.hpp @@ -0,0 +1,17 @@ +#pragma once + +namespace pybind11 +{ +class list; +class object; +class handle; +} + +namespace mlir +{ +class Type; +class TypeRange; +} + +pybind11::object map_type_to_numba(pybind11::handle types_mod, mlir::Type type); +pybind11::list map_types_to_numba(pybind11::handle types_mod, mlir::TypeRange types); diff --git a/mlir-compiler/src/rewrites/force_inline.cpp b/mlir-compiler/src/rewrites/force_inline.cpp new file mode 100644 index 00000000000..8eb16c3f48f --- /dev/null +++ b/mlir-compiler/src/rewrites/force_inline.cpp @@ -0,0 +1,43 @@ +#include "rewrites/force_inline.hpp" + +#include +#include + +#include "plier/dialect.hpp" + +mlir::LogicalResult ForceInline::matchAndRewrite(mlir::CallOp op, mlir::PatternRewriter& rewriter) const +{ + auto attr_name = plier::attributes::getForceInlineName(); + auto mod = op->getParentOfType(); + assert(mod); + auto func = mod.lookupSymbol(op.callee()); + if (!func) + { + return mlir::failure(); + } + if (!op->hasAttr(attr_name) && + !func->hasAttr(attr_name)) + { + return mlir::failure(); + } + + if (!llvm::hasNItems(func.getRegion(), 1)) + { + return mlir::failure(); + } + mlir::InlinerInterface inliner_interface(op->getContext()); + auto parent = op->getParentOp(); + rewriter.startRootUpdate(parent); + auto res = mlir::inlineCall(inliner_interface, op, func, &func.getRegion()); + if (mlir::succeeded(res)) + { + assert(op->getUsers().empty()); + rewriter.eraseOp(op); + rewriter.finalizeRootUpdate(parent); + } + else + { + rewriter.cancelRootUpdate(parent); + } + return res; +} diff --git a/mlir-compiler/src/rewrites/force_inline.hpp b/mlir-compiler/src/rewrites/force_inline.hpp new file mode 100644 index 00000000000..881ae3f493f --- /dev/null +++ b/mlir-compiler/src/rewrites/force_inline.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include + +namespace mlir +{ +class CallOp; +} + +struct ForceInline : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::CallOp op, mlir::PatternRewriter &rewriter) const override; +}; diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py index 7100473c946..fc4d383024a 100644 --- a/numba/mlir/__init__.py +++ b/numba/mlir/__init__.py @@ -1,8 +1,10 @@ from numba import runtests -import numba.mlir.builtin_funcs -import numba.mlir.numpy_funcs -import numba.mlir.math_funcs +from . import builtin_funcs +from . import numpy_funcs +from . import math_funcs + +from .numpy import funcs def test(*args, **kwargs): return runtests.main("numba.mlir.tests", *args, **kwargs) diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py new file mode 100644 index 00000000000..8d20a4e8ef7 --- /dev/null +++ b/numba/mlir/linalg_builder.py @@ -0,0 +1,62 @@ + + +class Var: + def __init__(self, ssa_val, shape, dtype): + self._ssa_val = ssa_val + self._shape = shape + self._dtype = dtype + + @property + def shape(self): + return self._shape + + @property + def dtype(self): + return self._dtype + + + +class Val: + def __init__(self, const_val, ssa_val): + self._const_val = const_val + self._ssa_val = ssa_val + + def is_const(self): + return not _const_val is None + +class Builder: + def __init__(self, context): + self._context = context + + def broadcast(self, *args): + return self._broadcast(self._context, args) + + def init_tensor(self, shape, dtype): + return self._init_tensor(self._context, shape, dtype) + + def generic(self, inputs, outputs, iterators, maps, body): + return self._generic(self._context, inputs, outputs, iterators, maps, body) + + def from_elements(self, values, dtype): + return self._from_elements(self._context, values, dtype) + + def extract(self, value, indices): + return self._extract(self._context, value, indices) + +def compile_func(*args, **kwargs): + import numba.mlir.inner_compiler + return numba.mlir.inner_compiler.compile_func(*args, **kwargs) + +_func_registry = {} + +def register_func(name): + def _decorator(func): + global _func_registry + assert not name in _func_registry + _func_registry[name] = func + return func + return _decorator + +def lookup_func(name): + global _func_registry + return _func_registry.get(name) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py new file mode 100644 index 00000000000..4ac9251867c --- /dev/null +++ b/numba/mlir/numpy/funcs.py @@ -0,0 +1,36 @@ +from ..linalg_builder import register_func + +@register_func('numpy.add') +def add_impl(builder, arg1, arg2): + a1, a2 = builder.broadcast(arg1, arg2) + shape = a1.shape + + num_dims = len(shape) + iterators = ['parallel' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr = f'({dims}) -> ({dims})' + maps = [expr,expr,expr] + init = builder.init_tensor(shape, a1.dtype) + + def body(a, b, c): + return a + b + + return builder.generic((a1,a2), init, iterators, maps, body) + +@register_func('array.sum') +def sum_impl(builder, arg): + shape = arg.shape + + num_dims = len(shape) + iterators = ['reduction' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr1 = f'({dims}) -> ({dims})' + expr2 = f'({dims}) -> (0)' + maps = [expr1,expr2] + init = builder.from_elements(0, arg.dtype) + + def body(a, b): + return a + b + + res = builder.generic(arg, init, iterators, maps, body) + return builder.extract(res, 0) From 7481a84a38d3b19ace923c94cb60836c55eec3ca Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 31 Jan 2021 14:01:33 +0300 Subject: [PATCH 222/259] inplace binop (#165) --- mlir-compiler/src/lowering.cpp | 12 ++++++++++++ numba/mlir/tests/test_basic.py | 9 +++++++++ 2 files changed, 21 insertions(+) diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/src/lowering.cpp index 5299f03fcf5..b3bd3bbcdc4 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/src/lowering.cpp @@ -295,6 +295,7 @@ struct plier_lowerer final using func_t = mlir::Value (plier_lowerer::*)(const py::handle&); const std::pair handlers[] = { {"binop", &plier_lowerer::lower_binop}, + {"inplace_binop", &plier_lowerer::lower_inplce_binop}, {"unary", &plier_lowerer::lower_unary}, {"cast", &plier_lowerer::lower_cast}, {"call", &plier_lowerer::lower_call}, @@ -428,6 +429,17 @@ struct plier_lowerer final return builder.create(get_current_loc(), lhs, rhs, op_name); } + mlir::Value lower_inplce_binop(const py::handle& expr) + { + auto op = expr.attr("immutable_fn"); + auto lhs_name = expr.attr("lhs"); + auto rhs_name = expr.attr("rhs"); + auto lhs = loadvar(lhs_name); + auto rhs = loadvar(rhs_name); + auto op_name = resolve_op(op); + return builder.create(get_current_loc(), lhs, rhs, op_name); + } + mlir::Value lower_unary(const py::handle& expr) { auto op = expr.attr("fn"); diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 1f0f6bb509e..2f496269c94 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -36,6 +36,15 @@ def test_ops(self): except ZeroDivisionError: pass + def test_inplace_op(self): + def py_func(a,b): + a += b + return a + + jit_func = njit(py_func) + for a, b in itertools.product(_test_values, _test_values): + assert_equal(py_func(a, b), jit_func(a, b)) + def test_unary_ops(self): py_funcs = [ lambda a: +a, From 62b2aa2a441df08b1cb364affdde2a612192078d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 31 Jan 2021 14:51:05 +0300 Subject: [PATCH 223/259] [MLIR] Extend linalg builder (#166) * refac numpy func registration * extend linalg builder to scalars --- mlir-compiler/src/py_linalg_resolver.cpp | 113 ++++++++++++++++------- numba/mlir/__init__.py | 1 - numba/mlir/linalg_builder.py | 6 +- numba/mlir/numpy/funcs.py | 4 +- numba/mlir/numpy_funcs.py | 5 - numba/mlir/tests/test_numpy.py | 9 ++ 6 files changed, 94 insertions(+), 44 deletions(-) delete mode 100644 numba/mlir/numpy_funcs.py diff --git a/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/src/py_linalg_resolver.cpp index 3071149f56a..bda3568a4a0 100644 --- a/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/src/py_linalg_resolver.cpp @@ -28,7 +28,10 @@ namespace { bool is_compatible_types(mlir::TypeRange types) { - return !types.empty() && llvm::all_of(types, [](mlir::Type t){ return t.isa(); }); + return !types.empty() && llvm::all_of(types, [](mlir::Type t) + { + return t.isIntOrFloat() || t.isa(); + }); } py::handle get_dim(int64_t val) @@ -140,7 +143,7 @@ struct PyLinalgResolver::Context } return var(wrap_mlir(value), py_shape, wrap_mlir(elem_type)); } - return var(wrap_mlir(value), py::none(), wrap_mlir(value.getType())); + return var(wrap_mlir(value), py::list(), wrap_mlir(value.getType())); } mlir::FuncOp compile_body(py::handle body, py::list arg_types) @@ -238,7 +241,14 @@ auto get_generic_op_body_types(mlir::ValueRange inputs, mlir::ValueRange outputs { for (auto type : r.getTypes()) { - auto elem_type = type.cast().getElementType(); + auto elem_type = [&]() + { + if (auto tensor = type.dyn_cast()) + { + return tensor.getElementType(); + } + return type; + }(); ret.emplace_back(elem_type); } } @@ -257,6 +267,19 @@ auto generic_op_body_result_types(mlir::ValueRange outputs) return ret; } +mlir::Attribute zero_attr(mlir::Type type) +{ + if (type.isa()) + { + return mlir::IntegerAttr::get(type, 0); + } + if (type.isa()) + { + return mlir::FloatAttr::get(type, 0.0); + } + report_error("zero_attr: unhandled type"); +} + py::object broadcast_impl(py::capsule /*context*/, py::tuple args) { if (1 == args.size()) @@ -273,61 +296,81 @@ py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dty { auto& ctx = get_py_context(context); auto elem_type = unwrap_mlir(dtype); - auto init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); + mlir::Value init; + if (shape.empty()) + { + // TODO: undef + init = ctx.builder.create(ctx.loc, zero_attr(elem_type)); + } + else + { + init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); + } return ctx.context.create_var(ctx.loc, ctx.builder, init); } py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) { auto& ctx = get_py_context(context); + auto loc = ctx.loc; auto& builder = ctx.builder; auto& mlir_context = *builder.getContext(); auto inputs_args = get_agrs_from_tuple(inputs); auto output_args = get_agrs_from_tuple(outputs); auto ret_types = get_types(output_args); - auto affine_maps = get_affine_maps(maps, mlir_context); auto mlir_iterators = get_iterators(iterators, mlir_context); auto func_types = map_types_to_numba(ctx.context.types_mod, get_generic_op_body_types(inputs_args, output_args)); auto body_func = ctx.context.compile_body(body, func_types); - auto body_builder = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + + auto cast_values = [&](mlir::ValueRange vals, mlir::TypeRange types) { - auto cast_values = [&](mlir::ValueRange vals, mlir::TypeRange types) + assert(vals.size() == types.size()); + llvm::SmallVector ret(vals.size()); + auto do_cast = [&](mlir::Value val, mlir::Type type) { - assert(vals.size() == types.size()); - llvm::SmallVector ret(vals.size()); - auto do_cast = [&](mlir::Value val, mlir::Type type) + if (val.getType() == type) { - if (val.getType() == type) - { - return val; - } - return builder.create(loc, type, val).getResult(); - }; - for (auto it : llvm::enumerate(vals)) - { - auto index = static_cast(it.index()); - ret[index] = do_cast(it.value(), types[index]); + return val; } - return ret; + return builder.create(loc, type, val).getResult(); }; - auto func_type = body_func.getType(); - auto new_args = cast_values(args, func_type.getInputs()); - auto call = builder.create(loc, body_func, new_args); - auto new_results = cast_values(call.getResults(), generic_op_body_result_types(output_args)); - builder.create(loc, new_results); + for (auto it : llvm::enumerate(vals)) + { + auto index = static_cast(it.index()); + ret[index] = do_cast(it.value(), types[index]); + } + return ret; }; + if (mlir_iterators.empty()) + { + inputs_args.append(output_args.begin(), output_args.end()); + auto res = builder.create(loc, body_func, inputs_args); + return ctx.context.wrap_result(loc, builder, cast_values(res.getResults(), ret_types)); + } + else + { + auto affine_maps = get_affine_maps(maps, mlir_context); + auto body_builder = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange args) + { + auto func_type = body_func.getType(); + auto new_args = cast_values(args, func_type.getInputs()); + auto call = builder.create(loc, body_func, new_args); + auto new_results = cast_values(call.getResults(), generic_op_body_result_types(output_args)); + builder.create(loc, new_results); + }; - auto generic_op = builder.create( - ctx.loc, - ret_types, - inputs_args, - output_args, - affine_maps, - mlir_iterators, - body_builder); - return ctx.context.wrap_result(ctx.loc, builder, generic_op.getResults()); + auto generic_op = builder.create( + loc, + ret_types, + inputs_args, + output_args, + affine_maps, + mlir_iterators, + body_builder); + return ctx.context.wrap_result(loc, builder, generic_op.getResults()); + } } py::object from_elements_impl(py::capsule context, py::handle values, py::capsule dtype) diff --git a/numba/mlir/__init__.py b/numba/mlir/__init__.py index fc4d383024a..d4ef05b47ad 100644 --- a/numba/mlir/__init__.py +++ b/numba/mlir/__init__.py @@ -1,7 +1,6 @@ from numba import runtests from . import builtin_funcs -from . import numpy_funcs from . import math_funcs from .numpy import funcs diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 8d20a4e8ef7..0d3c0bb8bfa 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -1,4 +1,4 @@ - +from .func_registry import add_func class Var: def __init__(self, ssa_val, shape, dtype): @@ -49,11 +49,13 @@ def compile_func(*args, **kwargs): _func_registry = {} -def register_func(name): +def register_func(name, orig_func = None): def _decorator(func): global _func_registry assert not name in _func_registry _func_registry[name] = func + if not orig_func is None: + add_func(orig_func, name) return func return _decorator diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 4ac9251867c..28655d7fde3 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,6 +1,8 @@ from ..linalg_builder import register_func -@register_func('numpy.add') +import numpy + +@register_func('numpy.add', numpy.add) def add_impl(builder, arg1, arg2): a1, a2 = builder.broadcast(arg1, arg2) shape = a1.shape diff --git a/numba/mlir/numpy_funcs.py b/numba/mlir/numpy_funcs.py deleted file mode 100644 index ed965917a1a..00000000000 --- a/numba/mlir/numpy_funcs.py +++ /dev/null @@ -1,5 +0,0 @@ -from numba.mlir.func_registry import add_func - -import numpy - -add_func(numpy.add, 'numpy.add') diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index d6fcef955a5..b7e5dfc3190 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -40,6 +40,15 @@ def py_func(a): arr = np.asarray([1,2,3]) assert_equal(py_func(arr), jit_func(arr)) + def test_add_scalar(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func) + arr1 = 1 + arr2 = 2 + assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + def test_sum_add(self): def py_func(a, b): return np.add(a, b).sum() From 6dd2143babf3004fb2e80f9b6c36e6d92ce88ddd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 4 Feb 2021 01:22:27 +0300 Subject: [PATCH 224/259] Mlir tuples (#167) --- mlir-compiler/include/plier/PlierOps.td | 2 + mlir-compiler/src/dialect.cpp | 45 ++++++++++++++----- .../src/pipelines/plier_to_linalg.cpp | 43 +++++++++++++++++- mlir-compiler/src/pipelines/plier_to_std.cpp | 20 +++++++++ numba/mlir/tests/test_basic.py | 10 ++++- numba/mlir/tests/test_numpy.py | 9 ++++ 6 files changed, 116 insertions(+), 13 deletions(-) diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/include/plier/PlierOps.td index a4ef3339d04..9324098fb71 100644 --- a/mlir-compiler/include/plier/PlierOps.td +++ b/mlir-compiler/include/plier/PlierOps.td @@ -121,6 +121,7 @@ def GetItemOp : Plier_Op<"getitem", []> { AnyType:$index); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index)> @@ -134,6 +135,7 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { UI32Attr:$index); let results = (outs AnyType); + let hasFolder = 1; let builders = [ OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index_var, "unsigned":$index)> diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 549c25a6b87..282c6abfbef 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -201,6 +201,23 @@ void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, // return mlir::failure(); //} +mlir::Value fold_build_tuple_getitem(mlir::Value val, llvm::ArrayRef operands) +{ + auto build_tuple = val.getDefiningOp(); + if (build_tuple) + { + if (auto val = operands[1].dyn_cast_or_null()) + { + auto index = val.getInt(); + if (index >= 0 && index < build_tuple.getNumOperands()) + { + return build_tuple.getOperand(static_cast(index)); + } + } + } + return {}; +} + void GetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value, ::mlir::Value index) { @@ -208,6 +225,15 @@ void GetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, PyType::getUndefined(state.getContext()), value, index); } +mlir::OpFoldResult GetItemOp::fold(llvm::ArrayRef operands) +{ + if (auto val = fold_build_tuple_getitem(value(), operands)) + { + return val; + } + return nullptr; +} + void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value, ::mlir::Value index_var, unsigned int index) @@ -217,17 +243,14 @@ void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat value, index_var, index); } -//mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef /*operands*/) -//{ -// auto index = this->index(); -// auto args = getOperands(); -// if ((index + 1) < args.size() && // skip last arg -// args[index].getType() == getResult().getType()) -// { -// return args[index]; -// } -// return nullptr; -//} +mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef operands) +{ + if (auto val = fold_build_tuple_getitem(value(), operands)) + { + return val; + } + return nullptr; +} void GetiterOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, ::mlir::Value value) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 7fa3fcf87f6..8d243be78c0 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -492,6 +492,46 @@ struct SetitemOpLowering : public mlir::OpRewritePattern } }; +struct ArrayShape : public mlir::OpRewritePattern +{ + ArrayShape(mlir::TypeConverter& type_converter, + mlir::MLIRContext* context): + OpRewritePattern(context), + converter(type_converter) {} + + mlir::LogicalResult matchAndRewrite( + plier::GetattrOp op, mlir::PatternRewriter &rewriter) const override + { + auto type = op.value().getType().dyn_cast(); + if (!type || op.name() != "shape" || !type.hasRank()) + { + return mlir::failure(); + } + + auto rank = static_cast(type.getRank()); + auto elem_type = converter.convertType(op.getType()).dyn_cast_or_null(); + if (!elem_type || elem_type.size() != rank) + { + return mlir::failure(); + } + + llvm::SmallVector dims(rank); + for (size_t i = 0; i < rank; ++i) + { + auto dim = rewriter.create(op.getLoc(), op.value(), i); + dims[i] = rewriter.create(op.getLoc(), elem_type.getType(i), dim); + } + auto res = rewriter.create(op.getLoc(), op.getType(), dims); + rerun_std_pipeline(op); + rewriter.replaceOp(op, res.getResult()); + return mlir::success(); + } + +private: + mlir::TypeConverter& converter; +}; + + void PlierToLinalgPass::runOnOperation() { auto context = &getContext(); @@ -513,7 +553,8 @@ void PlierToLinalgPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert< FuncOpSignatureConversion, - CastOpLowering + CastOpLowering, + ArrayShape >(type_converter, context); CallLowerer callLowerer; diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 3b0c4773f8c..07d0e002068 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1115,6 +1115,25 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef return mlir::success(); } +mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +{ + if (operands.size() != 1) + { + return mlir::failure(); + } + + auto build_tuple = operands[0].getDefiningOp(); + if (!build_tuple) + { + return mlir::failure(); + } + + auto size = rewriter.create(op.getLoc(), build_tuple.getNumOperands()); + auto cast = rewriter.create(op.getLoc(), op.getType(), size); + rewriter.replaceOp(op, cast.getResult()); + return mlir::success(); +} + mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { if (operands.size() != 1) @@ -1199,6 +1218,7 @@ struct CallLowerer std::pair handlers[] = { {"bool", lower_bool_cast}, {"range", lower_range}, + {"len", lower_len}, }; for (auto& handler : handlers) { diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 2f496269c94..01a0fae1fb0 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -129,7 +129,6 @@ def py_func(a, b): for a, b in itertools.product(_test_values, _test_values): assert_equal(py_func(a, b), jit_func(a, b)) - @unittest.skip def test_tuple(self): def py_func(a, b, c): t = (a,b,c) @@ -139,6 +138,15 @@ def py_func(a, b, c): for a, b, c in itertools.product(_test_values, _test_values, _test_values): assert_equal(py_func(a, b, c), jit_func(a, b, c)) + def test_tuple_len(self): + def py_func(a, b, c): + t = (a,b,c) + return len(t) + + jit_func = njit(py_func) + for a, b, c in itertools.product(_test_values, _test_values, _test_values): + assert_equal(py_func(a, b, c), jit_func(a, b, c)) + def test_range1(self): def py_func(a): res = 0 diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index b7e5dfc3190..91e6d381dea 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -102,5 +102,14 @@ def py_func(a): arr = np.asarray([3,2,1]) assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + def test_array_shape(self): + def py_func(a): + shape = a.shape + return shape[0] + shape[1] + + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr), jit_func(arr)) + if __name__ == '__main__': unittest.main() From 3637c3347470cc5e0a55217cacef0ca0f7d2f450 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 4 Feb 2021 22:40:41 +0300 Subject: [PATCH 225/259] [MLIR] Update to master (#168) --- mlir-compiler/CMakeLists.txt | 2 ++ mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/src/dialect.cpp | 12 ++++--- mlir-compiler/src/pipelines/plier_to_std.cpp | 35 +++++++++++++++++++- mlir-compiler/src/transforms/const_utils.cpp | 23 +++++++++++++ mlir-compiler/src/transforms/const_utils.hpp | 24 ++++++++++++++ 6 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 mlir-compiler/src/transforms/const_utils.cpp create mode 100644 mlir-compiler/src/transforms/const_utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index a01eea32174..63ec51ddfbe 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -38,6 +38,7 @@ set(SOURCES_LIST src/rewrites/force_inline.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp + src/transforms/const_utils.cpp src/transforms/func_utils.cpp src/transforms/loop_utils.cpp src/transforms/pipeline_utils.cpp @@ -67,6 +68,7 @@ set(HEADERS_LIST src/rewrites/force_inline.hpp src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp + src/transforms/const_utils.hpp src/transforms/func_utils.hpp src/transforms/loop_utils.hpp src/transforms/pipeline_utils.hpp diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index c8a3e9f598a..6ededfdfa6c 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -de4ba7073bd7e200aca704e6a26403e07bc246a5 +ba87f99168c93461b28a4aa2d05e238ff774d57a diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/src/dialect.cpp index 282c6abfbef..32ae0ff85fb 100644 --- a/mlir-compiler/src/dialect.cpp +++ b/mlir-compiler/src/dialect.cpp @@ -201,7 +201,7 @@ void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, // return mlir::failure(); //} -mlir::Value fold_build_tuple_getitem(mlir::Value val, llvm::ArrayRef operands) +mlir::Value fold_build_tuple_getitem(mlir::Value val, mlir::Type type, llvm::ArrayRef operands) { auto build_tuple = val.getDefiningOp(); if (build_tuple) @@ -211,7 +211,11 @@ mlir::Value fold_build_tuple_getitem(mlir::Value val, llvm::ArrayRef= 0 && index < build_tuple.getNumOperands()) { - return build_tuple.getOperand(static_cast(index)); + auto op = build_tuple.getOperand(static_cast(index)); + if (op.getType() == type) + { + return op; + } } } } @@ -227,7 +231,7 @@ void GetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::OpFoldResult GetItemOp::fold(llvm::ArrayRef operands) { - if (auto val = fold_build_tuple_getitem(value(), operands)) + if (auto val = fold_build_tuple_getitem(value(), getType(), operands)) { return val; } @@ -245,7 +249,7 @@ void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &stat mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef operands) { - if (auto val = fold_build_tuple_getitem(value(), operands)) + if (auto val = fold_build_tuple_getitem(value(), getType(), operands)) { return val; } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index 07d0e002068..b4e760988f4 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -18,6 +18,7 @@ #include "rewrites/call_lowering.hpp" #include "rewrites/cast_lowering.hpp" #include "rewrites/type_conversion.hpp" +#include "transforms/const_utils.hpp" #include "transforms/func_utils.hpp" #include "transforms/loop_utils.hpp" @@ -1080,6 +1081,36 @@ struct FixupWhileTypes : public mlir::OpRewritePattern } }; +template +struct FoldTupleGetitem : public mlir::OpRewritePattern +{ + FoldTupleGetitem(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter &rewriter) const override + { + auto build_tuple = op.value().template getDefiningOp(); + if (!build_tuple) + { + return mlir::failure(); + } + + if (auto val = getConstVal(op.getOperand(1))) + { + auto index = val.getInt(); + if (index >= 0 && index < build_tuple.getNumOperands()) + { + auto val = build_tuple.getOperand(static_cast(index)); + rewriter.replaceOp(op, val); + return mlir::success(); + } + } + return mlir::failure(); + } +}; + mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) { if ((operands.size() < 1 || operands.size() > 3) || @@ -1296,7 +1327,9 @@ void PlierToStdPass::runOnOperation() UnaryOpLowering, ScfIfRewrite, ScfWhileRewrite, - FixupWhileTypes + FixupWhileTypes, + FoldTupleGetitem, + FoldTupleGetitem >(type_converter, context); patterns.insert< diff --git a/mlir-compiler/src/transforms/const_utils.cpp b/mlir-compiler/src/transforms/const_utils.cpp new file mode 100644 index 00000000000..f7377077f51 --- /dev/null +++ b/mlir-compiler/src/transforms/const_utils.cpp @@ -0,0 +1,23 @@ +#include "transforms/const_utils.hpp" + +#include +#include + +mlir::Attribute getConstVal(mlir::Operation* op) +{ + if (!op->hasTrait()) + { + return {}; + } + + return op->getAttr("value"); +} + +mlir::Attribute getConstVal(mlir::Value op) +{ + if (auto parent_op = op.getDefiningOp()) + { + return getConstVal(parent_op); + } + return {}; +} diff --git a/mlir-compiler/src/transforms/const_utils.hpp b/mlir-compiler/src/transforms/const_utils.hpp new file mode 100644 index 00000000000..f2c62e507f5 --- /dev/null +++ b/mlir-compiler/src/transforms/const_utils.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace mlir +{ +class Operation; +class Value; +} + +mlir::Attribute getConstVal(mlir::Operation* op); +mlir::Attribute getConstVal(mlir::Value op); + +template +T getConstVal(mlir::Operation* op) +{ + return getConstVal(op).dyn_cast_or_null(); +} + +template +T getConstVal(mlir::Value op) +{ + return getConstVal(op).dyn_cast_or_null(); +} From c0869389fff2c03b8b014b5f594e22b07cc2233b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 6 Feb 2021 14:20:59 +0300 Subject: [PATCH 226/259] [MLIR] Some opts for array index access (#170) --- .../src/pipelines/plier_to_linalg.cpp | 270 +++++++++++++++++- mlir-compiler/src/rewrites/cse.hpp | 2 +- numba/mlir/tests/test_numpy.py | 46 ++- 3 files changed, 312 insertions(+), 6 deletions(-) diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 8d243be78c0..751f9549473 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -417,6 +417,256 @@ struct PlierToLinalgPass : void runOnOperation() override; }; +bool is_index_compatible(mlir::Type lhs_type, mlir::Type rhs_type) +{ + if (!lhs_type.isa() || lhs_type != rhs_type) + { + return false; + } + + if (lhs_type.cast().getWidth() < 64) + { + return false; + } + return true; +} + +template +struct ArithIndexCastSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter &rewriter) const override + { + auto lhs_type = op.lhs().getType(); + auto rhs_type = op.rhs().getType(); + if (!is_index_compatible(lhs_type, rhs_type)) + { + return mlir::failure(); + } + + auto get_cast = [](mlir::Value val)->mlir::Value + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getOperand(); + } + return {}; + }; + + auto get_const = [](mlir::Value val)->mlir::IntegerAttr + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getValue().cast(); + } + return {}; + }; + + auto lhs = get_cast(op.lhs()); + auto rhs = get_cast(op.rhs()); + auto lhs_const = get_const(op.lhs()); + auto rhs_const = get_const(op.rhs()); + if (lhs && rhs) + { + auto new_op = rewriter.create(op.getLoc(), lhs, rhs); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + if (lhs && rhs_const) + { + auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); + auto new_op = rewriter.create(op.getLoc(), lhs, new_const); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + if (lhs_const && rhs) + { + auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); + auto new_op = rewriter.create(op.getLoc(), new_const, rhs); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +struct CmpIndexCastSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::CmpIOp op, mlir::PatternRewriter &rewriter) const override + { + auto lhs_type = op.lhs().getType(); + auto rhs_type = op.rhs().getType(); + if (!is_index_compatible(lhs_type, rhs_type)) + { + return mlir::failure(); + } + + auto get_cast = [](mlir::Value val)->mlir::Value + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getOperand(); + } + return {}; + }; + + auto get_const = [](mlir::Value val)->mlir::IntegerAttr + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getValue().cast(); + } + return {}; + }; + + auto lhs = get_cast(op.lhs()); + auto rhs = get_cast(op.rhs()); + auto lhs_const = get_const(op.lhs()); + auto rhs_const = get_const(op.rhs()); + if (lhs && rhs) + { + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, rhs); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + if (lhs && rhs_const) + { + auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, new_const); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + if (lhs_const && rhs) + { + auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), new_const, rhs); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override + { + auto index_var = op.getLoopBody().front().getArgument(0); + if (auto step_var = mlir::dyn_cast_or_null(op.step().getDefiningOp())) + { + assert(step_var.value().cast().getInt() > 0); + } + bool matched = false; + for (auto user : llvm::make_early_inc_range(index_var.getUsers())) + { + auto cmp = mlir::dyn_cast(user); + if (cmp) + { + auto pred = cmp.predicate(); + auto lhs = cmp.lhs(); + auto rhs = cmp.rhs(); + // Normalize index and predicate (index always on the left) + using norm_fptr_t = bool(*)(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs); + using Predicate = mlir::CmpIPredicate; + const norm_fptr_t norm_handlers[] = { + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + }; + + for (auto h : norm_handlers) + { + if (h(pred, index_var, lhs, rhs)) + { + break; + } + } + + using fptr_t = llvm::Optional(*)(Predicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound); + const fptr_t handlers[] = { + &handler_impl, + &handler_impl, + &handler_impl, + &handler_impl, + }; + + for (auto h : handlers) + { + if (auto c = h(pred, lhs, rhs, index_var, op.lowerBound(), op.upperBound())) + { + auto type = rewriter.getI1Type(); + auto val = rewriter.getIntegerAttr(type, *c); + auto const_val = rewriter.create(cmp.getLoc(), val); + rewriter.replaceOp(cmp, const_val.getResult()); + matched = true; + break; + } + } + } + } + return mlir::success(matched); + } + +private: + template + static bool norm_impl2(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) + { + if (pred != SrcPred) + { + return false; + } + if (index != lhs) + { + std::swap(lhs, rhs); + pred = DstPred; + } + return true; + } + + template + static bool norm_impl(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) + { + return norm_impl2(pred, index, lhs, rhs) || + norm_impl2(pred, index, lhs, rhs); + } + + enum EBound + { + LowerBound, + UpperBound, + }; + template + static llvm::Optional handler_impl(mlir::CmpIPredicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound) + { + if (pred != Pred) + { + return {}; + } + auto bound = (Bound == LowerBound ? lowerBound : upperBound); + if(rhs == bound && lhs == index) + { + return Value; + } + return {}; + } +}; + template struct SetitemOpLowering : public mlir::OpRewritePattern { @@ -649,15 +899,29 @@ void PostLinalgOptPass::runOnOperation() { mlir::OwningRewritePatternList patterns; + auto& context = getContext(); + for (auto *op : context.getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, &context); + } + patterns.insert< CanonicalizeReduction, // LoopInvariantCodeMotion, TODO PromoteToParallel, + CmpIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + CmpLoopBoundsSimplify, CSERewrite - >(&getContext()); - + >(&context); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) diff --git a/mlir-compiler/src/rewrites/cse.hpp b/mlir-compiler/src/rewrites/cse.hpp index 1536e8c0c9f..9d2675160d5 100644 --- a/mlir-compiler/src/rewrites/cse.hpp +++ b/mlir-compiler/src/rewrites/cse.hpp @@ -15,7 +15,7 @@ template struct CSERewrite : public mlir::OpRewritePattern { CSERewrite(mlir::MLIRContext *context): - OpRewritePattern(context, /*benefit*/0) {} + OpRewritePattern(context, /*benefit*/1) {} // TODO: benefit=0 mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter &rewriter) const override diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 91e6d381dea..170e5b2e703 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -88,11 +88,11 @@ def py_func(a): arr = np.asarray([3,2,1]) assert_equal(py_func(arr.copy()), jit_func(arr.copy())) - def test_array_bounds(self): + def test_array_bounds1(self): def py_func(a): res = 0 for i in range(len(a)): - if i >= len(a): + if i >= len(a) or i < 0: res = res + 1 else: res = res + a[i] @@ -102,6 +102,48 @@ def py_func(a): arr = np.asarray([3,2,1]) assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + def test_array_bounds2(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if i < len(a) and i >= 0: + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds3(self): + def py_func(a): + res = 0 + for i in range(len(a)): + if 0 <= i < len(a): + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + + def test_array_bounds4(self): + def py_func(a): + res = 0 + for i in range(len(a) - 1): + if 0 <= i < (len(a) - 1): + res = res + a[i] + else: + res = res + 1 + return res + + jit_func = njit(py_func) + arr = np.asarray([3,2,1]) + assert_equal(py_func(arr.copy()), jit_func(arr.copy())) + def test_array_shape(self): def py_func(a): shape = a.shape From 1817405371ffef46d730c154b873c73be3e556eb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 6 Feb 2021 16:05:55 +0300 Subject: [PATCH 227/259] [MLIR] Move rewrites tos separate files (#171) --- mlir-compiler/CMakeLists.txt | 6 + .../src/pipelines/plier_to_linalg.cpp | 290 +----------------- mlir-compiler/src/pipelines/plier_to_std.cpp | 6 +- .../src/rewrites/index_type_propagation.cpp | 161 ++++++++++ .../src/rewrites/index_type_propagation.hpp | 9 + mlir-compiler/src/rewrites/loop_rewrites.cpp | 110 +++++++ mlir-compiler/src/rewrites/loop_rewrites.hpp | 19 ++ mlir-compiler/src/transforms/cast_utils.cpp | 19 ++ mlir-compiler/src/transforms/cast_utils.hpp | 12 + mlir-compiler/src/transforms/loop_utils.cpp | 8 +- 10 files changed, 347 insertions(+), 293 deletions(-) create mode 100644 mlir-compiler/src/rewrites/index_type_propagation.cpp create mode 100644 mlir-compiler/src/rewrites/index_type_propagation.hpp create mode 100644 mlir-compiler/src/rewrites/loop_rewrites.cpp create mode 100644 mlir-compiler/src/rewrites/loop_rewrites.hpp create mode 100644 mlir-compiler/src/transforms/cast_utils.cpp create mode 100644 mlir-compiler/src/transforms/cast_utils.hpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 63ec51ddfbe..be672b24219 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -36,8 +36,11 @@ set(SOURCES_LIST src/rewrites/cast_lowering.cpp src/rewrites/cse.cpp src/rewrites/force_inline.cpp + src/rewrites/index_type_propagation.cpp + src/rewrites/loop_rewrites.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp + src/transforms/cast_utils.cpp src/transforms/const_utils.cpp src/transforms/func_utils.cpp src/transforms/loop_utils.cpp @@ -66,8 +69,11 @@ set(HEADERS_LIST src/rewrites/cast_lowering.hpp src/rewrites/cse.hpp src/rewrites/force_inline.hpp + src/rewrites/index_type_propagation.hpp + src/rewrites/loop_rewrites.hpp src/rewrites/promote_to_parallel.hpp src/rewrites/type_conversion.hpp + src/transforms/cast_utils.hpp src/transforms/const_utils.hpp src/transforms/func_utils.hpp src/transforms/loop_utils.hpp diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 751f9549473..115a078e0b6 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -29,6 +29,9 @@ #include "rewrites/promote_to_parallel.hpp" #include "rewrites/type_conversion.hpp" #include "rewrites/force_inline.hpp" +#include "rewrites/index_type_propagation.hpp" +#include "rewrites/loop_rewrites.hpp" + #include "transforms/loop_utils.hpp" #include "base_pipeline.hpp" @@ -114,33 +117,6 @@ bool check_numpy_args(llvm::ArrayRef args, unsigned expected_count) return true; } -mlir::Attribute get_zero(mlir::Type type) -{ - assert(type); - if (auto int_type = type.dyn_cast()) - { - return mlir::IntegerAttr::get(type, 0); - } - if (auto float_type = type.dyn_cast()) - { - return mlir::FloatAttr::get(type, 0.0); - } - llvm_unreachable("get_zero: usupported type"); -} - -mlir::Type get_elem_type(mlir::Type type) -{ - if (auto memref = type.dyn_cast()) - { - return memref.getElementType(); - } - if (auto tensor = type.dyn_cast()) - { - return tensor.getElementType(); - } - llvm_unreachable("get_elem_type: unknown type"); -} - void rerun_std_pipeline(mlir::Operation* op) { assert(nullptr != op); @@ -417,256 +393,6 @@ struct PlierToLinalgPass : void runOnOperation() override; }; -bool is_index_compatible(mlir::Type lhs_type, mlir::Type rhs_type) -{ - if (!lhs_type.isa() || lhs_type != rhs_type) - { - return false; - } - - if (lhs_type.cast().getWidth() < 64) - { - return false; - } - return true; -} - -template -struct ArithIndexCastSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter &rewriter) const override - { - auto lhs_type = op.lhs().getType(); - auto rhs_type = op.rhs().getType(); - if (!is_index_compatible(lhs_type, rhs_type)) - { - return mlir::failure(); - } - - auto get_cast = [](mlir::Value val)->mlir::Value - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getOperand(); - } - return {}; - }; - - auto get_const = [](mlir::Value val)->mlir::IntegerAttr - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getValue().cast(); - } - return {}; - }; - - auto lhs = get_cast(op.lhs()); - auto rhs = get_cast(op.rhs()); - auto lhs_const = get_const(op.lhs()); - auto rhs_const = get_const(op.rhs()); - if (lhs && rhs) - { - auto new_op = rewriter.create(op.getLoc(), lhs, rhs); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - if (lhs && rhs_const) - { - auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); - auto new_op = rewriter.create(op.getLoc(), lhs, new_const); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - if (lhs_const && rhs) - { - auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); - auto new_op = rewriter.create(op.getLoc(), new_const, rhs); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - - return mlir::failure(); - } -}; - -struct CmpIndexCastSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::CmpIOp op, mlir::PatternRewriter &rewriter) const override - { - auto lhs_type = op.lhs().getType(); - auto rhs_type = op.rhs().getType(); - if (!is_index_compatible(lhs_type, rhs_type)) - { - return mlir::failure(); - } - - auto get_cast = [](mlir::Value val)->mlir::Value - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getOperand(); - } - return {}; - }; - - auto get_const = [](mlir::Value val)->mlir::IntegerAttr - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getValue().cast(); - } - return {}; - }; - - auto lhs = get_cast(op.lhs()); - auto rhs = get_cast(op.rhs()); - auto lhs_const = get_const(op.lhs()); - auto rhs_const = get_const(op.rhs()); - if (lhs && rhs) - { - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, rhs); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - if (lhs && rhs_const) - { - auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, new_const); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - if (lhs_const && rhs) - { - auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), new_const, rhs); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - - return mlir::failure(); - } -}; - -struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override - { - auto index_var = op.getLoopBody().front().getArgument(0); - if (auto step_var = mlir::dyn_cast_or_null(op.step().getDefiningOp())) - { - assert(step_var.value().cast().getInt() > 0); - } - bool matched = false; - for (auto user : llvm::make_early_inc_range(index_var.getUsers())) - { - auto cmp = mlir::dyn_cast(user); - if (cmp) - { - auto pred = cmp.predicate(); - auto lhs = cmp.lhs(); - auto rhs = cmp.rhs(); - // Normalize index and predicate (index always on the left) - using norm_fptr_t = bool(*)(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs); - using Predicate = mlir::CmpIPredicate; - const norm_fptr_t norm_handlers[] = { - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - }; - - for (auto h : norm_handlers) - { - if (h(pred, index_var, lhs, rhs)) - { - break; - } - } - - using fptr_t = llvm::Optional(*)(Predicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound); - const fptr_t handlers[] = { - &handler_impl, - &handler_impl, - &handler_impl, - &handler_impl, - }; - - for (auto h : handlers) - { - if (auto c = h(pred, lhs, rhs, index_var, op.lowerBound(), op.upperBound())) - { - auto type = rewriter.getI1Type(); - auto val = rewriter.getIntegerAttr(type, *c); - auto const_val = rewriter.create(cmp.getLoc(), val); - rewriter.replaceOp(cmp, const_val.getResult()); - matched = true; - break; - } - } - } - } - return mlir::success(matched); - } - -private: - template - static bool norm_impl2(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) - { - if (pred != SrcPred) - { - return false; - } - if (index != lhs) - { - std::swap(lhs, rhs); - pred = DstPred; - } - return true; - } - - template - static bool norm_impl(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) - { - return norm_impl2(pred, index, lhs, rhs) || - norm_impl2(pred, index, lhs, rhs); - } - - enum EBound - { - LowerBound, - UpperBound, - }; - template - static llvm::Optional handler_impl(mlir::CmpIPredicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound) - { - if (pred != Pred) - { - return {}; - } - auto bound = (Bound == LowerBound ? lowerBound : upperBound); - if(rhs == bound && lhs == index) - { - return Value; - } - return {}; - } -}; - template struct SetitemOpLowering : public mlir::OpRewritePattern { @@ -909,18 +635,12 @@ void PostLinalgOptPass::runOnOperation() CanonicalizeReduction, // LoopInvariantCodeMotion, TODO PromoteToParallel, - CmpIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, CmpLoopBoundsSimplify, CSERewrite >(&context); + populate_index_propagate_patterns(context, patterns); + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/src/pipelines/plier_to_std.cpp index b4e760988f4..27faec1d3a2 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -454,7 +454,7 @@ mlir::Value float_int_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRe return rewriter.create(val.getLoc(), val, dst_type); } -mlir::Value index_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) +mlir::Value index_cast_impl(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& rewriter) { return rewriter.create(val.getLoc(), val, dst_type); } @@ -481,8 +481,8 @@ mlir::Value do_cast(mlir::Type dst_type, mlir::Value val, mlir::PatternRewriter& {&is_int, &is_int, &int_cast}, {&is_int, &is_float, &int_float_cast}, {&is_float, &is_int, &float_int_cast}, - {&is_index, &is_int, &index_cast}, - {&is_int, &is_index, &index_cast}, + {&is_index, &is_int, &index_cast_impl}, + {&is_int, &is_index, &index_cast_impl}, }; for (auto& h : handlers) diff --git a/mlir-compiler/src/rewrites/index_type_propagation.cpp b/mlir-compiler/src/rewrites/index_type_propagation.cpp new file mode 100644 index 00000000000..3dbde2f600c --- /dev/null +++ b/mlir-compiler/src/rewrites/index_type_propagation.cpp @@ -0,0 +1,161 @@ +#include "rewrites/index_type_propagation.hpp" + +#include +#include + +namespace +{ +bool is_index_compatible(mlir::Type lhs_type, mlir::Type rhs_type) +{ + if (!lhs_type.isa() || lhs_type != rhs_type) + { + return false; + } + + if (lhs_type.cast().getWidth() < 64) + { + return false; + } + return true; +} + +template +struct ArithIndexCastSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + Op op, mlir::PatternRewriter &rewriter) const override + { + auto lhs_type = op.lhs().getType(); + auto rhs_type = op.rhs().getType(); + if (!is_index_compatible(lhs_type, rhs_type)) + { + return mlir::failure(); + } + + auto get_cast = [](mlir::Value val)->mlir::Value + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getOperand(); + } + return {}; + }; + + auto get_const = [](mlir::Value val)->mlir::IntegerAttr + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getValue().cast(); + } + return {}; + }; + + auto lhs = get_cast(op.lhs()); + auto rhs = get_cast(op.rhs()); + auto lhs_const = get_const(op.lhs()); + auto rhs_const = get_const(op.rhs()); + if (lhs && rhs) + { + auto new_op = rewriter.create(op.getLoc(), lhs, rhs); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + if (lhs && rhs_const) + { + auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); + auto new_op = rewriter.create(op.getLoc(), lhs, new_const); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + if (lhs_const && rhs) + { + auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); + auto new_op = rewriter.create(op.getLoc(), new_const, rhs); + auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); + rewriter.replaceOp(op, result.getResult()); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +struct CmpIndexCastSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::CmpIOp op, mlir::PatternRewriter &rewriter) const override + { + auto lhs_type = op.lhs().getType(); + auto rhs_type = op.rhs().getType(); + if (!is_index_compatible(lhs_type, rhs_type)) + { + return mlir::failure(); + } + + auto get_cast = [](mlir::Value val)->mlir::Value + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getOperand(); + } + return {}; + }; + + auto get_const = [](mlir::Value val)->mlir::IntegerAttr + { + if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) + { + return op.getValue().cast(); + } + return {}; + }; + + auto lhs = get_cast(op.lhs()); + auto rhs = get_cast(op.rhs()); + auto lhs_const = get_const(op.lhs()); + auto rhs_const = get_const(op.rhs()); + if (lhs && rhs) + { + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, rhs); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + if (lhs && rhs_const) + { + auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, new_const); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + if (lhs_const && rhs) + { + auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); + auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), new_const, rhs); + rewriter.replaceOp(op, new_cmp.getResult()); + return mlir::success(); + } + + return mlir::failure(); + } +}; +} + +void populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) +{ + patterns.insert< + CmpIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify, + ArithIndexCastSimplify + >(&context); +} diff --git a/mlir-compiler/src/rewrites/index_type_propagation.hpp b/mlir-compiler/src/rewrites/index_type_propagation.hpp new file mode 100644 index 00000000000..142c3b6a35c --- /dev/null +++ b/mlir-compiler/src/rewrites/index_type_propagation.hpp @@ -0,0 +1,9 @@ +#pragma once + +namespace mlir +{ +class OwningRewritePatternList; +class MLIRContext; +} + +void populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns); diff --git a/mlir-compiler/src/rewrites/loop_rewrites.cpp b/mlir-compiler/src/rewrites/loop_rewrites.cpp new file mode 100644 index 00000000000..eefa40b91cb --- /dev/null +++ b/mlir-compiler/src/rewrites/loop_rewrites.cpp @@ -0,0 +1,110 @@ +#include "rewrites/loop_rewrites.hpp" + +#include +#include + +namespace +{ +template +bool norm_impl2(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) +{ + if (pred != SrcPred) + { + return false; + } + if (index != lhs) + { + std::swap(lhs, rhs); + pred = DstPred; + } + return true; +} + +template +bool norm_impl(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) +{ + return norm_impl2(pred, index, lhs, rhs) || + norm_impl2(pred, index, lhs, rhs); +} + +enum EBound +{ + LowerBound, + UpperBound, +}; +template +llvm::Optional handler_impl(mlir::CmpIPredicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound) +{ + if (pred != Pred) + { + return {}; + } + auto bound = (Bound == LowerBound ? lowerBound : upperBound); + if(rhs == bound && lhs == index) + { + return Value; + } + return {}; +} +} + +mlir::LogicalResult CmpLoopBoundsSimplify::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +{ + auto index_var = op.getLoopBody().front().getArgument(0); + if (auto step_var = mlir::dyn_cast_or_null(op.step().getDefiningOp())) + { + assert(step_var.value().cast().getInt() > 0); + } + bool matched = false; + for (auto user : llvm::make_early_inc_range(index_var.getUsers())) + { + auto cmp = mlir::dyn_cast(user); + if (cmp) + { + auto pred = cmp.predicate(); + auto lhs = cmp.lhs(); + auto rhs = cmp.rhs(); + // Normalize index and predicate (index always on the left) + using norm_fptr_t = bool(*)(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs); + using Predicate = mlir::CmpIPredicate; + const norm_fptr_t norm_handlers[] = { + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + &norm_impl, + }; + + for (auto h : norm_handlers) + { + if (h(pred, index_var, lhs, rhs)) + { + break; + } + } + + using fptr_t = llvm::Optional(*)(Predicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound); + const fptr_t handlers[] = { + &handler_impl, + &handler_impl, + &handler_impl, + &handler_impl, + }; + + for (auto h : handlers) + { + if (auto c = h(pred, lhs, rhs, index_var, op.lowerBound(), op.upperBound())) + { + auto type = rewriter.getI1Type(); + auto val = rewriter.getIntegerAttr(type, *c); + auto const_val = rewriter.create(cmp.getLoc(), val); + rewriter.replaceOp(cmp, const_val.getResult()); + matched = true; + break; + } + } + } + } + return mlir::success(matched); +} diff --git a/mlir-compiler/src/rewrites/loop_rewrites.hpp b/mlir-compiler/src/rewrites/loop_rewrites.hpp new file mode 100644 index 00000000000..59b723cd562 --- /dev/null +++ b/mlir-compiler/src/rewrites/loop_rewrites.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +namespace mlir +{ +namespace scf +{ +class ForOp; +} +} + +struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; +}; diff --git a/mlir-compiler/src/transforms/cast_utils.cpp b/mlir-compiler/src/transforms/cast_utils.cpp new file mode 100644 index 00000000000..cb35b1d59b3 --- /dev/null +++ b/mlir-compiler/src/transforms/cast_utils.cpp @@ -0,0 +1,19 @@ +#include "transforms/cast_utils.hpp" + +#include + +mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type) +{ + auto src_type = src.getType(); + assert(src_type.isa() || dst_type.isa()); + if (src_type != dst_type) + { + return builder.create(loc, src, dst_type); + } + return src; +} + +mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src) +{ + return index_cast(builder, loc, src, mlir::IndexType::get(builder.getContext())); +} diff --git a/mlir-compiler/src/transforms/cast_utils.hpp b/mlir-compiler/src/transforms/cast_utils.hpp new file mode 100644 index 00000000000..bd5126d191a --- /dev/null +++ b/mlir-compiler/src/transforms/cast_utils.hpp @@ -0,0 +1,12 @@ +#pragma once + +namespace mlir +{ +class Value; +class Location; +class OpBuilder; +class Type; +} + +mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type); +mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src); diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/src/transforms/loop_utils.cpp index 2952826b89e..e09cf8f5b3c 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/src/transforms/loop_utils.cpp @@ -10,6 +10,8 @@ #include "plier/dialect.hpp" +#include "transforms/cast_utils.hpp" + namespace { template @@ -102,11 +104,7 @@ mlir::LogicalResult lower_while_to_for( auto index_cast = [&](mlir::Value val)->mlir::Value { - if (!val.getType().isa()) - { - return builder.create(loc, val, mlir::IndexType::get(val.getContext())); - } - return val; + return ::index_cast(builder, loc, val); }; auto bounds = get_bounds(builder, loc); From 7c26311ac79993edcc6bc7b53db2caede1305938 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Feb 2021 04:37:28 +0300 Subject: [PATCH 228/259] [MLIR] Separate python-agnostic parts into lib (#172) * reorganize project structure * move dialect definitions to separate lib * move file * hide includes * move transforms and rewrites to plier project * remove llvm libs from plier * add plier namespace * move pipeline to plier * add plier namespace * move compiler infra to plier lib * plier namespace --- mlir-compiler/CMakeLists.txt | 116 +----------------- mlir-compiler/mlir-compiler/CMakeLists.txt | 71 +++++++++++ mlir-compiler/{ => mlir-compiler}/readme.md | 0 .../{ => mlir-compiler}/src/lowering.cpp | 36 +++--- .../{ => mlir-compiler}/src/lowering.hpp | 0 .../{ => mlir-compiler}/src/mangle.cpp | 0 .../{ => mlir-compiler}/src/mangle.hpp | 0 .../src/pipelines/base_pipeline.cpp | 4 +- .../src/pipelines/base_pipeline.hpp | 5 +- .../src/pipelines/lower_to_llvm.cpp | 16 +-- .../src/pipelines/lower_to_llvm.hpp | 5 +- .../src/pipelines/parallel_to_tbb.cpp | 4 +- .../src/pipelines/parallel_to_tbb.hpp | 5 +- .../src/pipelines/plier_to_linalg.cpp | 48 ++++---- .../src/pipelines/plier_to_linalg.hpp | 5 +- .../src/pipelines/plier_to_std.cpp | 26 ++-- .../src/pipelines/plier_to_std.hpp | 5 +- .../src/py_func_resolver.cpp | 0 .../src/py_func_resolver.hpp | 0 .../src/py_linalg_resolver.cpp | 10 +- .../src/py_linalg_resolver.hpp | 0 .../{ => mlir-compiler}/src/py_map_types.cpp | 0 .../{ => mlir-compiler}/src/py_map_types.hpp | 0 .../{ => mlir-compiler}/src/py_module.cpp | 0 .../{ => mlir-compiler}/src/py_module.hpp | 0 mlir-compiler/{ => mlir-compiler}/test.py | 0 mlir-compiler/plier/CMakeLists.txt | 91 ++++++++++++++ .../{ => plier}/include/plier/CMakeLists.txt | 0 .../{ => plier}/include/plier/PlierOps.td | 0 .../include/plier/compiler}/compiler.hpp | 6 +- .../plier/compiler}/pipeline_registry.hpp | 4 +- .../{ => plier}/include/plier/dialect.hpp | 0 .../include/plier}/rewrites/call_lowering.hpp | 3 + .../rewrites/canonicalize_reductions.hpp | 3 + .../include/plier}/rewrites/cast_lowering.hpp | 3 + .../include/plier}/rewrites/cse.hpp | 6 +- .../include/plier}/rewrites/force_inline.hpp | 3 + .../rewrites/index_type_propagation.hpp | 3 + .../include/plier}/rewrites/loop_rewrites.hpp | 3 + .../plier}/rewrites/promote_to_parallel.hpp | 3 + .../plier}/rewrites/type_conversion.hpp | 3 + .../include/plier}/transforms/cast_utils.hpp | 3 + .../include/plier}/transforms/const_utils.hpp | 3 + .../include/plier}/transforms/func_utils.hpp | 3 + .../include/plier}/transforms/loop_utils.hpp | 3 + .../plier}/transforms/pipeline_utils.hpp | 3 + .../{src => plier/include/plier}/utils.hpp | 3 + .../{src => plier/src/compiler}/compiler.cpp | 36 +++--- .../src/compiler}/pipeline_registry.cpp | 8 +- mlir-compiler/{ => plier}/src/dialect.cpp | 0 .../src/rewrites/call_lowering.cpp | 6 +- .../src/rewrites/canonicalize_reductions.cpp | 4 +- .../src/rewrites/cast_lowering.cpp | 6 +- .../{ => plier}/src/rewrites/cse.cpp | 6 +- .../{ => plier}/src/rewrites/force_inline.cpp | 4 +- .../src/rewrites/index_type_propagation.cpp | 4 +- .../src/rewrites/loop_rewrites.cpp | 4 +- .../src/rewrites/promote_to_parallel.cpp | 4 +- .../src/rewrites/type_conversion.cpp | 6 +- .../{ => plier}/src/transforms/cast_utils.cpp | 6 +- .../src/transforms/const_utils.cpp | 6 +- .../{ => plier}/src/transforms/func_utils.cpp | 7 +- .../{ => plier}/src/transforms/loop_utils.cpp | 8 +- .../src/transforms/pipeline_utils.cpp | 8 +- mlir-compiler/{ => plier}/src/utils.cpp | 4 +- 65 files changed, 371 insertions(+), 261 deletions(-) create mode 100644 mlir-compiler/mlir-compiler/CMakeLists.txt rename mlir-compiler/{ => mlir-compiler}/readme.md (100%) rename mlir-compiler/{ => mlir-compiler}/src/lowering.cpp (93%) rename mlir-compiler/{ => mlir-compiler}/src/lowering.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/mangle.cpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/mangle.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/base_pipeline.cpp (87%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/base_pipeline.hpp (76%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/lower_to_llvm.cpp (98%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/lower_to_llvm.hpp (57%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/parallel_to_tbb.cpp (98%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/parallel_to_tbb.hpp (57%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/plier_to_linalg.cpp (95%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/plier_to_linalg.hpp (65%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/plier_to_std.cpp (98%) rename mlir-compiler/{ => mlir-compiler}/src/pipelines/plier_to_std.hpp (75%) rename mlir-compiler/{ => mlir-compiler}/src/py_func_resolver.cpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_func_resolver.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_linalg_resolver.cpp (98%) rename mlir-compiler/{ => mlir-compiler}/src/py_linalg_resolver.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_map_types.cpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_map_types.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_module.cpp (100%) rename mlir-compiler/{ => mlir-compiler}/src/py_module.hpp (100%) rename mlir-compiler/{ => mlir-compiler}/test.py (100%) create mode 100644 mlir-compiler/plier/CMakeLists.txt rename mlir-compiler/{ => plier}/include/plier/CMakeLists.txt (100%) rename mlir-compiler/{ => plier}/include/plier/PlierOps.td (100%) rename mlir-compiler/{src => plier/include/plier/compiler}/compiler.hpp (90%) rename mlir-compiler/{src => plier/include/plier/compiler}/pipeline_registry.hpp (98%) rename mlir-compiler/{ => plier}/include/plier/dialect.hpp (100%) rename mlir-compiler/{src => plier/include/plier}/rewrites/call_lowering.hpp (97%) rename mlir-compiler/{src => plier/include/plier}/rewrites/canonicalize_reductions.hpp (94%) rename mlir-compiler/{src => plier/include/plier}/rewrites/cast_lowering.hpp (97%) rename mlir-compiler/{src => plier/include/plier}/rewrites/cse.hpp (85%) rename mlir-compiler/{src => plier/include/plier}/rewrites/force_inline.hpp (94%) rename mlir-compiler/{src => plier/include/plier}/rewrites/index_type_propagation.hpp (90%) rename mlir-compiler/{src => plier/include/plier}/rewrites/loop_rewrites.hpp (94%) rename mlir-compiler/{src => plier/include/plier}/rewrites/promote_to_parallel.hpp (94%) rename mlir-compiler/{src => plier/include/plier}/rewrites/type_conversion.hpp (96%) rename mlir-compiler/{src => plier/include/plier}/transforms/cast_utils.hpp (93%) rename mlir-compiler/{src => plier/include/plier}/transforms/const_utils.hpp (95%) rename mlir-compiler/{src => plier/include/plier}/transforms/func_utils.hpp (93%) rename mlir-compiler/{src => plier/include/plier}/transforms/loop_utils.hpp (96%) rename mlir-compiler/{src => plier/include/plier}/transforms/pipeline_utils.hpp (93%) rename mlir-compiler/{src => plier/include/plier}/utils.hpp (96%) rename mlir-compiler/{src => plier/src/compiler}/compiler.cpp (83%) rename mlir-compiler/{src => plier/src/compiler}/pipeline_registry.cpp (96%) rename mlir-compiler/{ => plier}/src/dialect.cpp (100%) rename mlir-compiler/{ => plier}/src/rewrites/call_lowering.cpp (85%) rename mlir-compiler/{ => plier}/src/rewrites/canonicalize_reductions.cpp (97%) rename mlir-compiler/{ => plier}/src/rewrites/cast_lowering.cpp (85%) rename mlir-compiler/{ => plier}/src/rewrites/cse.cpp (95%) rename mlir-compiler/{ => plier}/src/rewrites/force_inline.cpp (87%) rename mlir-compiler/{ => plier}/src/rewrites/index_type_propagation.cpp (96%) rename mlir-compiler/{ => plier}/src/rewrites/loop_rewrites.cpp (95%) rename mlir-compiler/{ => plier}/src/rewrites/promote_to_parallel.cpp (96%) rename mlir-compiler/{ => plier}/src/rewrites/type_conversion.cpp (96%) rename mlir-compiler/{ => plier}/src/transforms/cast_utils.cpp (60%) rename mlir-compiler/{ => plier}/src/transforms/const_utils.cpp (66%) rename mlir-compiler/{ => plier}/src/transforms/func_utils.cpp (69%) rename mlir-compiler/{ => plier}/src/transforms/loop_utils.cpp (97%) rename mlir-compiler/{ => plier}/src/transforms/pipeline_utils.cpp (84%) rename mlir-compiler/{ => plier}/src/utils.cpp (56%) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index be672b24219..eb63fbf04bd 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -6,117 +6,5 @@ set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -if(UNIX) - add_link_options("-Wl,--exclude-libs,ALL") -endif() - - -find_package(pybind11 REQUIRED) - -find_package(LLVM REQUIRED CONFIG) -find_package(MLIR REQUIRED CONFIG) - -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -include(TableGen) -include(AddLLVM) -include(AddMLIR) -include(HandleLLVMOptions) - -add_subdirectory(include/plier) - -set(SOURCES_LIST - src/pipelines/base_pipeline.cpp - src/pipelines/lower_to_llvm.cpp - src/pipelines/parallel_to_tbb.cpp - src/pipelines/plier_to_linalg.cpp - src/pipelines/plier_to_std.cpp - src/rewrites/call_lowering.cpp - src/rewrites/canonicalize_reductions.cpp - src/rewrites/cast_lowering.cpp - src/rewrites/cse.cpp - src/rewrites/force_inline.cpp - src/rewrites/index_type_propagation.cpp - src/rewrites/loop_rewrites.cpp - src/rewrites/promote_to_parallel.cpp - src/rewrites/type_conversion.cpp - src/transforms/cast_utils.cpp - src/transforms/const_utils.cpp - src/transforms/func_utils.cpp - src/transforms/loop_utils.cpp - src/transforms/pipeline_utils.cpp - src/compiler.cpp - src/dialect.cpp - src/lowering.cpp - src/mangle.cpp - src/pipeline_registry.cpp - src/py_func_resolver.cpp - src/py_linalg_resolver.cpp - src/py_map_types.cpp - src/py_module.cpp - src/utils.cpp - ) -set(HEADERS_LIST - include/plier/dialect.hpp - include/plier/PlierOps.td - src/pipelines/base_pipeline.hpp - src/pipelines/lower_to_llvm.hpp - src/pipelines/parallel_to_tbb.hpp - src/pipelines/plier_to_linalg.hpp - src/pipelines/plier_to_std.hpp - src/rewrites/call_lowering.hpp - src/rewrites/canonicalize_reductions.hpp - src/rewrites/cast_lowering.hpp - src/rewrites/cse.hpp - src/rewrites/force_inline.hpp - src/rewrites/index_type_propagation.hpp - src/rewrites/loop_rewrites.hpp - src/rewrites/promote_to_parallel.hpp - src/rewrites/type_conversion.hpp - src/transforms/cast_utils.hpp - src/transforms/const_utils.hpp - src/transforms/func_utils.hpp - src/transforms/loop_utils.hpp - src/transforms/pipeline_utils.hpp - src/compiler.hpp - src/lowering.hpp - src/mangle.hpp - src/pipeline_registry.hpp - src/py_func_resolver.hpp - src/py_linalg_resolver.hpp - src/py_map_types.hpp - src/py_module.hpp - src/utils.hpp - ) - -pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST}) - -if (MSVC) - target_compile_options(${PROJECT_NAME} PRIVATE /EHsc) -endif () - -target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) - -target_link_libraries(${PROJECT_NAME} PRIVATE - LLVM${LLVM_NATIVE_ARCH}CodeGen - LLVM${LLVM_NATIVE_ARCH}Desc - LLVMTarget - MLIRIR - MLIRLLVMIR - MLIRTargetLLVMIR - MLIRTransforms - MLIRStandardOpsTransforms - MLIRLinalgTransforms - MLIRSCFToStandard - MLIRTensorTransforms - ) - -target_include_directories(${PROJECT_NAME} PRIVATE - ./src - ./include - ${LLVM_INCLUDE_DIRS} - ${MLIR_INCLUDE_DIRS} - ${PROJECT_BINARY_DIR}/include - ) - -add_dependencies(${PROJECT_NAME} MLIRPlierOpsIncGen) +add_subdirectory(plier) +add_subdirectory(mlir-compiler) diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt new file mode 100644 index 00000000000..411d0257d7b --- /dev/null +++ b/mlir-compiler/mlir-compiler/CMakeLists.txt @@ -0,0 +1,71 @@ +if(UNIX) + add_link_options("-Wl,--exclude-libs,ALL") +endif() + +find_package(pybind11 REQUIRED) + +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) + +set(SOURCES_LIST + src/pipelines/base_pipeline.cpp + src/pipelines/lower_to_llvm.cpp + src/pipelines/parallel_to_tbb.cpp + src/pipelines/plier_to_linalg.cpp + src/pipelines/plier_to_std.cpp + src/lowering.cpp + src/mangle.cpp + src/py_func_resolver.cpp + src/py_linalg_resolver.cpp + src/py_map_types.cpp + src/py_module.cpp + ) +set(HEADERS_LIST + src/pipelines/base_pipeline.hpp + src/pipelines/lower_to_llvm.hpp + src/pipelines/parallel_to_tbb.hpp + src/pipelines/plier_to_linalg.hpp + src/pipelines/plier_to_std.hpp + src/lowering.hpp + src/mangle.hpp + src/py_func_resolver.hpp + src/py_linalg_resolver.hpp + src/py_map_types.hpp + src/py_module.hpp + ) + +pybind11_add_module(${PROJECT_NAME} ${SOURCES_LIST} ${HEADERS_LIST}) + +if (MSVC) + target_compile_options(${PROJECT_NAME} PRIVATE /EHsc) +endif () + +target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) + +target_link_libraries(${PROJECT_NAME} PRIVATE + plier + LLVM${LLVM_NATIVE_ARCH}CodeGen + LLVM${LLVM_NATIVE_ARCH}Desc + LLVMTarget + MLIRIR + MLIRLLVMIR + MLIRTargetLLVMIR + MLIRTransforms + MLIRStandardOpsTransforms + MLIRLinalgTransforms + MLIRSCFToStandard + MLIRTensorTransforms + ) + +target_include_directories(${PROJECT_NAME} PRIVATE + ./src + ${LLVM_INCLUDE_DIRS} + ${MLIR_INCLUDE_DIRS} + ) diff --git a/mlir-compiler/readme.md b/mlir-compiler/mlir-compiler/readme.md similarity index 100% rename from mlir-compiler/readme.md rename to mlir-compiler/mlir-compiler/readme.md diff --git a/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp similarity index 93% rename from mlir-compiler/src/lowering.cpp rename to mlir-compiler/mlir-compiler/src/lowering.cpp index b3bd3bbcdc4..e30d3c3c85a 100644 --- a/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -18,9 +18,9 @@ #include "plier/dialect.hpp" -#include "compiler.hpp" -#include "pipeline_registry.hpp" -#include "utils.hpp" +#include "plier/compiler/compiler.hpp" +#include "plier/compiler/pipeline_registry.hpp" +#include "plier/utils.hpp" #include "pipelines/base_pipeline.hpp" #include "pipelines/parallel_to_tbb.hpp" @@ -253,7 +253,7 @@ struct plier_lowerer final } else { - report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); + plier::report_error(llvm::Twine("lower_inst not handled: \"") + py::str(inst.get_type()).cast() + "\""); } } @@ -286,7 +286,7 @@ struct plier_lowerer final name); } - report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); + plier::report_error(llvm::Twine("lower_assign not handled: \"") + py::str(value.get_type()).cast() + "\""); } mlir::Value lower_expr(const py::handle& expr) @@ -316,7 +316,7 @@ struct plier_lowerer final return (this->*h.second)(expr); } } - report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); + plier::report_error(llvm::Twine("lower_expr not handled: \"") + op + "\""); } template @@ -409,7 +409,7 @@ struct plier_lowerer final auto py_func_name = func_name_resolver(typemap(py_func)); if (py_func_name.is_none()) { - report_error(llvm::Twine("Can't resolve function: ") + py::str(typemap(py_func)).cast()); + plier::report_error(llvm::Twine("Can't resolve function: ") + py::str(typemap(py_func)).cast()); } auto func_name = py_func_name.cast(); @@ -459,7 +459,7 @@ struct plier_lowerer final } } - report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); + plier::report_error(llvm::Twine("resolve_op not handled: \"") + py::str(op).cast() + "\""); } mlir::Value lower_getattr(const py::handle& inst) @@ -540,7 +540,7 @@ struct plier_lowerer final { return get_val(builder.getUnitAttr()); } - report_error(llvm::Twine("get_const unhandled type \"") + py::str(val.get_type()).cast() + "\""); + plier::report_error(llvm::Twine("get_const unhandled type \"") + py::str(val.get_type()).cast() + "\""); } mlir::FunctionType get_func_type(const py::handle& fnargs, const py::handle& restype) @@ -589,7 +589,7 @@ struct plier_lowerer final auto term = bb.getTerminator(); if (nullptr == term) { - report_error("broken ir: block without terminator"); + plier::report_error("broken ir: block without terminator"); } builder.setInsertionPointToEnd(&bb); @@ -615,7 +615,7 @@ struct plier_lowerer final } else { - report_error(llvm::Twine("Unhandled terminator: ") + term->getName().getStringRef()); + plier::report_error(llvm::Twine("Unhandled terminator: ") + term->getName().getStringRef()); } } } @@ -623,9 +623,9 @@ struct plier_lowerer final }; -CompilerContext::Settings get_settings(const py::handle& settings) +plier::CompilerContext::Settings get_settings(const py::handle& settings) { - CompilerContext::Settings ret; + plier::CompilerContext::Settings ret; ret.verify = settings["verify"].cast(); ret.pass_statistics = settings["pass_statistics"].cast(); ret.pass_timings = settings["pass_timings"].cast(); @@ -646,13 +646,13 @@ py::bytes gen_ll_module(mlir::ModuleOp mod) }; llvm::LLVMContext ll_ctx; std::unique_ptr ll_mod; - scoped_diag_handler(*mod.getContext(), diag_handler, [&]() + plier::scoped_diag_handler(*mod.getContext(), diag_handler, [&]() { ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); if (nullptr == ll_mod) { err_stream.flush(); - report_error(llvm::Twine("Cannot generate LLVM module\n") + err); + plier::report_error(llvm::Twine("Cannot generate LLVM module\n") + err); } }); assert(nullptr != ll_mod); @@ -660,7 +660,7 @@ py::bytes gen_ll_module(mlir::ModuleOp mod) return serialize_mod(*ll_mod); } -void create_pipeline(PipelineRegistry& registry) +void create_pipeline(plier::PipelineRegistry& registry) { register_base_pipeline(registry); register_lower_to_llvm_pipeline(registry); @@ -672,7 +672,7 @@ void create_pipeline(PipelineRegistry& registry) struct Module { mlir::MLIRContext context; - PipelineRegistry registry; + plier::PipelineRegistry registry; mlir::ModuleOp module; Module() @@ -688,7 +688,7 @@ void run_compiler(Module& mod, const py::object& compilation_context) auto& registry = mod.registry; auto settings = get_settings(compilation_context["compiler_settings"]); - CompilerContext compiler(context, settings, registry); + plier::CompilerContext compiler(context, settings, registry); compiler.run(module); } } diff --git a/mlir-compiler/src/lowering.hpp b/mlir-compiler/mlir-compiler/src/lowering.hpp similarity index 100% rename from mlir-compiler/src/lowering.hpp rename to mlir-compiler/mlir-compiler/src/lowering.hpp diff --git a/mlir-compiler/src/mangle.cpp b/mlir-compiler/mlir-compiler/src/mangle.cpp similarity index 100% rename from mlir-compiler/src/mangle.cpp rename to mlir-compiler/mlir-compiler/src/mangle.cpp diff --git a/mlir-compiler/src/mangle.hpp b/mlir-compiler/mlir-compiler/src/mangle.hpp similarity index 100% rename from mlir-compiler/src/mangle.hpp rename to mlir-compiler/mlir-compiler/src/mangle.hpp diff --git a/mlir-compiler/src/pipelines/base_pipeline.cpp b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp similarity index 87% rename from mlir-compiler/src/pipelines/base_pipeline.cpp rename to mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp index 7e8b35b4ba4..1685305d74f 100644 --- a/mlir-compiler/src/pipelines/base_pipeline.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.cpp @@ -1,6 +1,6 @@ #include "pipelines/base_pipeline.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" namespace { @@ -13,7 +13,7 @@ const constexpr llvm::StringRef passes[] ={ void dummy_pass_func(mlir::OpPassManager&) {} } -void register_base_pipeline(PipelineRegistry& registry) +void register_base_pipeline(plier::PipelineRegistry& registry) { for (std::size_t i = 0; i < llvm::array_lengthof(passes); ++i) { diff --git a/mlir-compiler/src/pipelines/base_pipeline.hpp b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp similarity index 76% rename from mlir-compiler/src/pipelines/base_pipeline.hpp rename to mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp index 65f9747802c..a7cff3d87cc 100644 --- a/mlir-compiler/src/pipelines/base_pipeline.hpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/base_pipeline.hpp @@ -2,9 +2,12 @@ #include +namespace plier +{ class PipelineRegistry; +} -void register_base_pipeline(PipelineRegistry& registry); +void register_base_pipeline(plier::PipelineRegistry& registry); struct PipelineStage { diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp similarity index 98% rename from mlir-compiler/src/pipelines/lower_to_llvm.cpp rename to mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 12f7415e378..5d85f5ed9ba 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -22,12 +22,12 @@ #include "plier/dialect.hpp" -#include "transforms/func_utils.hpp" +#include "plier/transforms/func_utils.hpp" #include "base_pipeline.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" -#include "utils.hpp" +#include "plier/utils.hpp" namespace { @@ -41,7 +41,7 @@ const mlir::LowerToLLVMOptions &getLLVMOptions() auto target = llvm::TargetRegistry::lookupTarget(triple, err_str); if (nullptr == target) { - report_error(llvm::Twine("Unable to get target: ") + err_str); + plier::report_error(llvm::Twine("Unable to get target: ") + err_str); } llvm::TargetOptions target_opts; std::unique_ptr machine(target->createTargetMachine(triple, llvm::sys::getHostCPUName(), "", target_opts, llvm::None)); @@ -207,7 +207,7 @@ struct MemRefConversionCache auto func_name = gen_conversion_func_name(memref_type); auto func_type = mlir::FunctionType::get(builder.getContext(),src_type, dst_type); auto loc = builder.getUnknownLoc(); - auto new_func = add_function(builder, module, func_name, func_type); + auto new_func = plier::add_function(builder, module, func_name, func_type); auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); cache.insert({memref_type, new_func}); @@ -673,7 +673,7 @@ struct LowerParallel : public mlir::OpRewritePattern } }(); - auto func = add_function(rewriter, mod, func_name, func_type); + auto func = plier::add_function(rewriter, mod, func_name, func_type); copyAttrs(parent_func, func); return func; }(); @@ -736,7 +736,7 @@ struct LowerParallel : public mlir::OpRewritePattern void_ptr_type }; auto func_type = mlir::FunctionType::get(op.getContext(), args, {}); - return add_function(rewriter, mod, func_name, func_type); + return plier::add_function(rewriter, mod, func_name, func_type); }(); auto func_addr = rewriter.create(loc, func_type, rewriter.getSymbolRefAttr(outlined_func)); mlir::Value pf_args[] = { @@ -902,7 +902,7 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) } -void register_lower_to_llvm_pipeline(PipelineRegistry& registry) +void register_lower_to_llvm_pipeline(plier::PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { diff --git a/mlir-compiler/src/pipelines/lower_to_llvm.hpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp similarity index 57% rename from mlir-compiler/src/pipelines/lower_to_llvm.hpp rename to mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp index 43ec2466314..861cff36af7 100644 --- a/mlir-compiler/src/pipelines/lower_to_llvm.hpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.hpp @@ -1,12 +1,15 @@ #pragma once +namespace plier +{ class PipelineRegistry; +} namespace llvm { class StringRef; } -void register_lower_to_llvm_pipeline(PipelineRegistry& registry); +void register_lower_to_llvm_pipeline(plier::PipelineRegistry& registry); llvm::StringRef lower_to_llvm_pipeline_name(); diff --git a/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp similarity index 98% rename from mlir-compiler/src/pipelines/parallel_to_tbb.cpp rename to mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index 8634ec09779..dff93c14c24 100644 --- a/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -9,7 +9,7 @@ #include "plier/dialect.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" #include "pipelines/base_pipeline.hpp" #include "pipelines/lower_to_llvm.hpp" @@ -191,7 +191,7 @@ void populate_parallel_to_tbb_pipeline(mlir::OpPassManager& pm) } } -void register_parallel_to_tbb_pipeline(PipelineRegistry& registry) +void register_parallel_to_tbb_pipeline(plier::PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { diff --git a/mlir-compiler/src/pipelines/parallel_to_tbb.hpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp similarity index 57% rename from mlir-compiler/src/pipelines/parallel_to_tbb.hpp rename to mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp index e19f9fde47b..cca9709169e 100644 --- a/mlir-compiler/src/pipelines/parallel_to_tbb.hpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.hpp @@ -1,12 +1,15 @@ #pragma once +namespace plier +{ class PipelineRegistry; +} namespace llvm { class StringRef; } -void register_parallel_to_tbb_pipeline(PipelineRegistry& registry); +void register_parallel_to_tbb_pipeline(plier::PipelineRegistry& registry); llvm::StringRef parallel_to_tbb_pipeline_name(); diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp similarity index 95% rename from mlir-compiler/src/pipelines/plier_to_linalg.cpp rename to mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 115a078e0b6..e2675e8c089 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -21,21 +21,21 @@ #include "plier/dialect.hpp" #include "pipelines/plier_to_std.hpp" -#include "transforms/pipeline_utils.hpp" -#include "rewrites/call_lowering.hpp" -#include "rewrites/canonicalize_reductions.hpp" -#include "rewrites/cast_lowering.hpp" -#include "rewrites/cse.hpp" -#include "rewrites/promote_to_parallel.hpp" -#include "rewrites/type_conversion.hpp" -#include "rewrites/force_inline.hpp" -#include "rewrites/index_type_propagation.hpp" -#include "rewrites/loop_rewrites.hpp" - -#include "transforms/loop_utils.hpp" + +#include "plier/transforms/pipeline_utils.hpp" +#include "plier/rewrites/call_lowering.hpp" +#include "plier/rewrites/canonicalize_reductions.hpp" +#include "plier/rewrites/cast_lowering.hpp" +#include "plier/rewrites/cse.hpp" +#include "plier/rewrites/promote_to_parallel.hpp" +#include "plier/rewrites/type_conversion.hpp" +#include "plier/rewrites/force_inline.hpp" +#include "plier/rewrites/index_type_propagation.hpp" +#include "plier/rewrites/loop_rewrites.hpp" +#include "plier/transforms/loop_utils.hpp" #include "base_pipeline.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" #include "py_linalg_resolver.hpp" #include @@ -123,7 +123,7 @@ void rerun_std_pipeline(mlir::Operation* op) auto marker = mlir::StringAttr::get(plier_to_std_pipeline_name(), op->getContext()); auto mod = op->getParentOfType(); assert(nullptr != mod); - add_pipeline_jump_marker(mod, marker); + plier::add_pipeline_jump_marker(mod, marker); } bool is_int(mlir::Type type) @@ -528,22 +528,22 @@ void PlierToLinalgPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert< - FuncOpSignatureConversion, - CastOpLowering, + plier::FuncOpSignatureConversion, + plier::CastOpLowering, ArrayShape >(type_converter, context); CallLowerer callLowerer; patterns.insert< - CallOpLowering + plier::CallOpLowering >(type_converter, context, callLowerer); patterns.insert< GetitemOpLowering, GetitemOpLowering, SetitemOpLowering, - ForceInline + plier::ForceInline >(&getContext()); // range/prange lowering need dead branch pruning to properly @@ -632,14 +632,14 @@ void PostLinalgOptPass::runOnOperation() } patterns.insert< - CanonicalizeReduction, + plier::CanonicalizeReduction, // LoopInvariantCodeMotion, TODO - PromoteToParallel, - CmpLoopBoundsSimplify, - CSERewrite + plier::PromoteToParallel, + plier::CmpLoopBoundsSimplify, + plier::CSERewrite >(&context); - populate_index_propagate_patterns(context, patterns); + plier::populate_index_propagate_patterns(context, patterns); mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -669,7 +669,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) } } -void register_plier_to_linalg_pipeline(PipelineRegistry& registry) +void register_plier_to_linalg_pipeline(plier::PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { diff --git a/mlir-compiler/src/pipelines/plier_to_linalg.hpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp similarity index 65% rename from mlir-compiler/src/pipelines/plier_to_linalg.hpp rename to mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp index f18fa229470..660b0e51703 100644 --- a/mlir-compiler/src/pipelines/plier_to_linalg.hpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.hpp @@ -1,13 +1,16 @@ #pragma once +namespace plier +{ class PipelineRegistry; +} namespace llvm { class StringRef; } -void register_plier_to_linalg_pipeline(PipelineRegistry& registry); +void register_plier_to_linalg_pipeline(plier::PipelineRegistry& registry); llvm::StringRef plier_to_linalg_gen_pipeline_name(); llvm::StringRef plier_to_linalg_opt_pipeline_name(); diff --git a/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp similarity index 98% rename from mlir-compiler/src/pipelines/plier_to_std.cpp rename to mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 27faec1d3a2..bde03c75f3f 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -15,15 +15,15 @@ #include "plier/dialect.hpp" -#include "rewrites/call_lowering.hpp" -#include "rewrites/cast_lowering.hpp" -#include "rewrites/type_conversion.hpp" -#include "transforms/const_utils.hpp" -#include "transforms/func_utils.hpp" -#include "transforms/loop_utils.hpp" +#include "plier/rewrites/call_lowering.hpp" +#include "plier/rewrites/cast_lowering.hpp" +#include "plier/rewrites/type_conversion.hpp" +#include "plier/transforms/const_utils.hpp" +#include "plier/transforms/func_utils.hpp" +#include "plier/transforms/loop_utils.hpp" #include "base_pipeline.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" #include "py_func_resolver.hpp" #include "mangle.hpp" @@ -1097,7 +1097,7 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern return mlir::failure(); } - if (auto val = getConstVal(op.getOperand(1))) + if (auto val = plier::getConstVal(op.getOperand(1))) { auto index = val.getInt(); if (index >= 0 && index < build_tuple.getNumOperands()) @@ -1200,7 +1200,7 @@ mlir::FuncOp get_lib_symbol( return op; } - return add_function(rewriter, mod, name, type); + return plier::add_function(rewriter, mod, name, type); } mlir::LogicalResult lower_math_func( @@ -1317,7 +1317,7 @@ void PlierToStdPass::runOnOperation() mlir::OwningRewritePatternList patterns; patterns.insert< - FuncOpSignatureConversion, + plier::FuncOpSignatureConversion, ArgOpLowering, ReturnOpLowering, ConstOpLowering, @@ -1333,13 +1333,13 @@ void PlierToStdPass::runOnOperation() >(type_converter, context); patterns.insert< - CastOpLowering + plier::CastOpLowering >(type_converter, context, &do_cast); CallLowerer callLowerer; patterns.insert< - CallOpLowering + plier::CallOpLowering >(type_converter, context, callLowerer); mlir::populateStdExpandOpsPatterns(context, patterns); @@ -1382,7 +1382,7 @@ void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter }); } -void register_plier_to_std_pipeline(PipelineRegistry& registry) +void register_plier_to_std_pipeline(plier::PipelineRegistry& registry) { registry.register_pipeline([](auto sink) { diff --git a/mlir-compiler/src/pipelines/plier_to_std.hpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp similarity index 75% rename from mlir-compiler/src/pipelines/plier_to_std.hpp rename to mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp index 80afadac4c9..c71cb6aa07c 100644 --- a/mlir-compiler/src/pipelines/plier_to_std.hpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.hpp @@ -1,6 +1,9 @@ #pragma once +namespace plier +{ class PipelineRegistry; +} namespace llvm { @@ -15,6 +18,6 @@ class TypeConverter; void populate_std_type_converter(mlir::MLIRContext& context, mlir::TypeConverter& converter); -void register_plier_to_std_pipeline(PipelineRegistry& registry); +void register_plier_to_std_pipeline(plier::PipelineRegistry& registry); llvm::StringRef plier_to_std_pipeline_name(); diff --git a/mlir-compiler/src/py_func_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_func_resolver.cpp similarity index 100% rename from mlir-compiler/src/py_func_resolver.cpp rename to mlir-compiler/mlir-compiler/src/py_func_resolver.cpp diff --git a/mlir-compiler/src/py_func_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_func_resolver.hpp similarity index 100% rename from mlir-compiler/src/py_func_resolver.hpp rename to mlir-compiler/mlir-compiler/src/py_func_resolver.hpp diff --git a/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp similarity index 98% rename from mlir-compiler/src/py_linalg_resolver.cpp rename to mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index bda3568a4a0..295abbda78e 100644 --- a/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -13,7 +13,7 @@ #include "plier/dialect.hpp" #include "py_map_types.hpp" -#include "utils.hpp" +#include "plier/utils.hpp" namespace py = pybind11; @@ -277,7 +277,7 @@ mlir::Attribute zero_attr(mlir::Type type) { return mlir::FloatAttr::get(type, 0.0); } - report_error("zero_attr: unhandled type"); + plier::report_error("zero_attr: unhandled type"); } py::object broadcast_impl(py::capsule /*context*/, py::tuple args) @@ -400,13 +400,13 @@ py::object from_elements_impl(py::capsule context, py::handle values, py::capsul { return mlir::FloatAttr::get(type, obj.cast()); } - report_error("Invalid dtype"); + plier::report_error("Invalid dtype"); }(); vals[index] = builder.create(loc, attr); } else { - report_error("Invalid element type"); + plier::report_error("Invalid element type"); } }); auto res = builder.create(loc, vals); @@ -432,7 +432,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice } else { - report_error("Invalid element type"); + plier::report_error("Invalid element type"); } }); auto res = builder.create(loc, get_var_value(value), ind); diff --git a/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp similarity index 100% rename from mlir-compiler/src/py_linalg_resolver.hpp rename to mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp diff --git a/mlir-compiler/src/py_map_types.cpp b/mlir-compiler/mlir-compiler/src/py_map_types.cpp similarity index 100% rename from mlir-compiler/src/py_map_types.cpp rename to mlir-compiler/mlir-compiler/src/py_map_types.cpp diff --git a/mlir-compiler/src/py_map_types.hpp b/mlir-compiler/mlir-compiler/src/py_map_types.hpp similarity index 100% rename from mlir-compiler/src/py_map_types.hpp rename to mlir-compiler/mlir-compiler/src/py_map_types.hpp diff --git a/mlir-compiler/src/py_module.cpp b/mlir-compiler/mlir-compiler/src/py_module.cpp similarity index 100% rename from mlir-compiler/src/py_module.cpp rename to mlir-compiler/mlir-compiler/src/py_module.cpp diff --git a/mlir-compiler/src/py_module.hpp b/mlir-compiler/mlir-compiler/src/py_module.hpp similarity index 100% rename from mlir-compiler/src/py_module.hpp rename to mlir-compiler/mlir-compiler/src/py_module.hpp diff --git a/mlir-compiler/test.py b/mlir-compiler/mlir-compiler/test.py similarity index 100% rename from mlir-compiler/test.py rename to mlir-compiler/mlir-compiler/test.py diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt new file mode 100644 index 00000000000..52c1d6ba4de --- /dev/null +++ b/mlir-compiler/plier/CMakeLists.txt @@ -0,0 +1,91 @@ +if(UNIX) + add_link_options("-Wl,--exclude-libs,ALL") +endif() + +find_package(LLVM REQUIRED CONFIG) +find_package(MLIR REQUIRED CONFIG) + +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(TableGen) +include(AddLLVM) +include(AddMLIR) +include(HandleLLVMOptions) + +add_subdirectory(include/plier) + +set(SOURCES_LIST + src/compiler/compiler.cpp + src/compiler/pipeline_registry.cpp + src/dialect.cpp + src/rewrites/call_lowering.cpp + src/rewrites/canonicalize_reductions.cpp + src/rewrites/cast_lowering.cpp + src/rewrites/cse.cpp + src/rewrites/force_inline.cpp + src/rewrites/index_type_propagation.cpp + src/rewrites/loop_rewrites.cpp + src/rewrites/promote_to_parallel.cpp + src/rewrites/type_conversion.cpp + src/transforms/cast_utils.cpp + src/transforms/const_utils.cpp + src/transforms/func_utils.cpp + src/transforms/loop_utils.cpp + src/transforms/pipeline_utils.cpp + src/utils.cpp + ) +set(HEADERS_LIST + include/plier/compiler/compiler.hpp + include/plier/compiler/pipeline_registry.hpp + include/plier/dialect.hpp + include/plier/PlierOps.td + include/plier/rewrites/call_lowering.hpp + include/plier/rewrites/canonicalize_reductions.hpp + include/plier/rewrites/cast_lowering.hpp + include/plier/rewrites/cse.hpp + include/plier/rewrites/force_inline.hpp + include/plier/rewrites/index_type_propagation.hpp + include/plier/rewrites/loop_rewrites.hpp + include/plier/rewrites/promote_to_parallel.hpp + include/plier/rewrites/type_conversion.hpp + include/plier/transforms/cast_utils.hpp + include/plier/transforms/const_utils.hpp + include/plier/transforms/func_utils.hpp + include/plier/transforms/loop_utils.hpp + include/plier/transforms/pipeline_utils.hpp + include/plier/utils.hpp + ) + +set(PLIER_LIB "plier") + +add_library(${PLIER_LIB} STATIC ${SOURCES_LIST} ${HEADERS_LIST}) + +if (MSVC) + target_compile_options(${PLIER_LIB} PRIVATE /EHsc) +endif () + +target_compile_definitions(${PLIER_LIB} PRIVATE ${LLVM_DEFINITIONS}) + +target_link_libraries(${PLIER_LIB} PRIVATE + MLIRIR + MLIRLLVMIR + MLIRTargetLLVMIR + MLIRTransforms + MLIRStandardOpsTransforms + MLIRLinalgTransforms + MLIRSCFToStandard + MLIRTensorTransforms + ) + +target_include_directories(${PLIER_LIB} PRIVATE + ./src + ${LLVM_INCLUDE_DIRS} + ${MLIR_INCLUDE_DIRS} + ) + +target_include_directories(${PLIER_LIB} PUBLIC + ./include + ${PROJECT_BINARY_DIR}/include + ) + +add_dependencies(${PLIER_LIB} MLIRPlierOpsIncGen) diff --git a/mlir-compiler/include/plier/CMakeLists.txt b/mlir-compiler/plier/include/plier/CMakeLists.txt similarity index 100% rename from mlir-compiler/include/plier/CMakeLists.txt rename to mlir-compiler/plier/include/plier/CMakeLists.txt diff --git a/mlir-compiler/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td similarity index 100% rename from mlir-compiler/include/plier/PlierOps.td rename to mlir-compiler/plier/include/plier/PlierOps.td diff --git a/mlir-compiler/src/compiler.hpp b/mlir-compiler/plier/include/plier/compiler/compiler.hpp similarity index 90% rename from mlir-compiler/src/compiler.hpp rename to mlir-compiler/plier/include/plier/compiler/compiler.hpp index f08538c77be..7d18abbfcbf 100644 --- a/mlir-compiler/src/compiler.hpp +++ b/mlir-compiler/plier/include/plier/compiler/compiler.hpp @@ -8,7 +8,10 @@ class MLIRContext; class ModuleOp; } +namespace plier +{ class PipelineRegistry; + class CompilerContext { public: @@ -33,5 +36,4 @@ class CompilerContext private: std::unique_ptr impl; }; - -void run_compiler(CompilerContext& context, mlir::ModuleOp module); +} diff --git a/mlir-compiler/src/pipeline_registry.hpp b/mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp similarity index 98% rename from mlir-compiler/src/pipeline_registry.hpp rename to mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp index e742e3b7360..1c7e20272f5 100644 --- a/mlir-compiler/src/pipeline_registry.hpp +++ b/mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp @@ -12,7 +12,8 @@ namespace mlir class OpPassManager; } - +namespace plier +{ class PipelineRegistry { public: @@ -38,3 +39,4 @@ class PipelineRegistry private: std::vector pipelines; }; +} diff --git a/mlir-compiler/include/plier/dialect.hpp b/mlir-compiler/plier/include/plier/dialect.hpp similarity index 100% rename from mlir-compiler/include/plier/dialect.hpp rename to mlir-compiler/plier/include/plier/dialect.hpp diff --git a/mlir-compiler/src/rewrites/call_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp similarity index 97% rename from mlir-compiler/src/rewrites/call_lowering.hpp rename to mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp index 06200938f16..c686004d5e1 100644 --- a/mlir-compiler/src/rewrites/call_lowering.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp @@ -11,6 +11,8 @@ namespace mlir class TypeConverter; } +namespace plier +{ struct CallOpLowering : public mlir::OpRewritePattern { using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; @@ -25,3 +27,4 @@ struct CallOpLowering : public mlir::OpRewritePattern private: resolver_t resolver; }; +} diff --git a/mlir-compiler/src/rewrites/canonicalize_reductions.hpp b/mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp similarity index 94% rename from mlir-compiler/src/rewrites/canonicalize_reductions.hpp rename to mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp index 1936b067fb3..45f66b7590d 100644 --- a/mlir-compiler/src/rewrites/canonicalize_reductions.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp @@ -10,6 +10,8 @@ class ForOp; } } +namespace plier +{ struct CanonicalizeReduction : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -17,3 +19,4 @@ struct CanonicalizeReduction : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; }; +} diff --git a/mlir-compiler/src/rewrites/cast_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp similarity index 97% rename from mlir-compiler/src/rewrites/cast_lowering.hpp rename to mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp index bb3fcc30dc0..a003eca8568 100644 --- a/mlir-compiler/src/rewrites/cast_lowering.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp @@ -11,6 +11,8 @@ namespace mlir class TypeConverter; } +namespace plier +{ struct CastOpLowering : public mlir::OpRewritePattern { using cast_t = std::function; @@ -26,3 +28,4 @@ struct CastOpLowering : public mlir::OpRewritePattern mlir::TypeConverter& converter; cast_t cast_func; }; +} diff --git a/mlir-compiler/src/rewrites/cse.hpp b/mlir-compiler/plier/include/plier/rewrites/cse.hpp similarity index 85% rename from mlir-compiler/src/rewrites/cse.hpp rename to mlir-compiler/plier/include/plier/rewrites/cse.hpp index 9d2675160d5..fa29039d3ad 100644 --- a/mlir-compiler/src/rewrites/cse.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/cse.hpp @@ -3,13 +3,12 @@ #include #include -namespace CSE +namespace plier { namespace detail { mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter); } -} template struct CSERewrite : public mlir::OpRewritePattern @@ -20,6 +19,7 @@ struct CSERewrite : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter &rewriter) const override { - return ::CSE::detail::applyCSE(op.getRegion(), rewriter); + return ::plier::detail::applyCSE(op.getRegion(), rewriter); } }; +} diff --git a/mlir-compiler/src/rewrites/force_inline.hpp b/mlir-compiler/plier/include/plier/rewrites/force_inline.hpp similarity index 94% rename from mlir-compiler/src/rewrites/force_inline.hpp rename to mlir-compiler/plier/include/plier/rewrites/force_inline.hpp index 881ae3f493f..5c518c71211 100644 --- a/mlir-compiler/src/rewrites/force_inline.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/force_inline.hpp @@ -7,6 +7,8 @@ namespace mlir class CallOp; } +namespace plier +{ struct ForceInline : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -14,3 +16,4 @@ struct ForceInline : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( mlir::CallOp op, mlir::PatternRewriter &rewriter) const override; }; +} diff --git a/mlir-compiler/src/rewrites/index_type_propagation.hpp b/mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp similarity index 90% rename from mlir-compiler/src/rewrites/index_type_propagation.hpp rename to mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp index 142c3b6a35c..d5e1f5c14c3 100644 --- a/mlir-compiler/src/rewrites/index_type_propagation.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp @@ -6,4 +6,7 @@ class OwningRewritePatternList; class MLIRContext; } +namespace plier +{ void populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns); +} diff --git a/mlir-compiler/src/rewrites/loop_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp similarity index 94% rename from mlir-compiler/src/rewrites/loop_rewrites.hpp rename to mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp index 59b723cd562..9e2ecf684af 100644 --- a/mlir-compiler/src/rewrites/loop_rewrites.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp @@ -10,6 +10,8 @@ class ForOp; } } +namespace plier +{ struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -17,3 +19,4 @@ struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; }; +} diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.hpp b/mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp similarity index 94% rename from mlir-compiler/src/rewrites/promote_to_parallel.hpp rename to mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp index f246ce0696d..eabbd31d359 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp @@ -10,6 +10,8 @@ class ForOp; } } +namespace plier +{ struct PromoteToParallel : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -17,3 +19,4 @@ struct PromoteToParallel : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; }; +} diff --git a/mlir-compiler/src/rewrites/type_conversion.hpp b/mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp similarity index 96% rename from mlir-compiler/src/rewrites/type_conversion.hpp rename to mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp index 638cb48b5dc..80fa189d476 100644 --- a/mlir-compiler/src/rewrites/type_conversion.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp @@ -8,6 +8,8 @@ namespace mlir class TypeConverter; } +namespace plier +{ struct FuncOpSignatureConversion : public mlir::OpRewritePattern { FuncOpSignatureConversion(mlir::TypeConverter& conv, @@ -20,3 +22,4 @@ struct FuncOpSignatureConversion : public mlir::OpRewritePattern private: mlir::TypeConverter& converter; }; +} diff --git a/mlir-compiler/src/transforms/cast_utils.hpp b/mlir-compiler/plier/include/plier/transforms/cast_utils.hpp similarity index 93% rename from mlir-compiler/src/transforms/cast_utils.hpp rename to mlir-compiler/plier/include/plier/transforms/cast_utils.hpp index bd5126d191a..d36a4e95719 100644 --- a/mlir-compiler/src/transforms/cast_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/cast_utils.hpp @@ -8,5 +8,8 @@ class OpBuilder; class Type; } +namespace plier +{ mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type); mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src); +} diff --git a/mlir-compiler/src/transforms/const_utils.hpp b/mlir-compiler/plier/include/plier/transforms/const_utils.hpp similarity index 95% rename from mlir-compiler/src/transforms/const_utils.hpp rename to mlir-compiler/plier/include/plier/transforms/const_utils.hpp index f2c62e507f5..7c8c5169c9a 100644 --- a/mlir-compiler/src/transforms/const_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/const_utils.hpp @@ -8,6 +8,8 @@ class Operation; class Value; } +namespace plier +{ mlir::Attribute getConstVal(mlir::Operation* op); mlir::Attribute getConstVal(mlir::Value op); @@ -22,3 +24,4 @@ T getConstVal(mlir::Value op) { return getConstVal(op).dyn_cast_or_null(); } +} diff --git a/mlir-compiler/src/transforms/func_utils.hpp b/mlir-compiler/plier/include/plier/transforms/func_utils.hpp similarity index 93% rename from mlir-compiler/src/transforms/func_utils.hpp rename to mlir-compiler/plier/include/plier/transforms/func_utils.hpp index 242696e1115..8065ddc8ede 100644 --- a/mlir-compiler/src/transforms/func_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/func_utils.hpp @@ -13,5 +13,8 @@ namespace llvm class StringRef; } +namespace plier +{ mlir::FuncOp add_function(mlir::OpBuilder& builder, mlir::ModuleOp module, llvm::StringRef name, mlir::FunctionType type); +} diff --git a/mlir-compiler/src/transforms/loop_utils.hpp b/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp similarity index 96% rename from mlir-compiler/src/transforms/loop_utils.hpp rename to mlir-compiler/plier/include/plier/transforms/loop_utils.hpp index 389ab881db4..b5dec36ca62 100644 --- a/mlir-compiler/src/transforms/loop_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp @@ -21,7 +21,10 @@ namespace plier class GetiterOp; } +namespace plier +{ mlir::LogicalResult lower_while_to_for(plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, llvm::function_ref get_iter_val, llvm::function_ref results = nullptr); +} diff --git a/mlir-compiler/src/transforms/pipeline_utils.hpp b/mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp similarity index 93% rename from mlir-compiler/src/transforms/pipeline_utils.hpp rename to mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp index 7b53b6acd8a..0d80ebfb1eb 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp @@ -7,6 +7,9 @@ class ModuleOp; class StringAttr; } +namespace plier +{ mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module); void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); +} diff --git a/mlir-compiler/src/utils.hpp b/mlir-compiler/plier/include/plier/utils.hpp similarity index 96% rename from mlir-compiler/src/utils.hpp rename to mlir-compiler/plier/include/plier/utils.hpp index 1e026c45137..83610967d66 100644 --- a/mlir-compiler/src/utils.hpp +++ b/mlir-compiler/plier/include/plier/utils.hpp @@ -9,6 +9,8 @@ namespace llvm class Twine; } +namespace plier +{ [[noreturn]] void report_error(const llvm::Twine& msg); template @@ -22,3 +24,4 @@ void scoped_diag_handler(T& ctx, H&& diag_handler, F&& func) }); func(); } +} diff --git a/mlir-compiler/src/compiler.cpp b/mlir-compiler/plier/src/compiler/compiler.cpp similarity index 83% rename from mlir-compiler/src/compiler.cpp rename to mlir-compiler/plier/src/compiler/compiler.cpp index 9aa5dea96c1..0a0a1a3cb19 100644 --- a/mlir-compiler/src/compiler.cpp +++ b/mlir-compiler/plier/src/compiler/compiler.cpp @@ -1,4 +1,4 @@ -#include "compiler.hpp" +#include "plier/compiler/compiler.hpp" #include #include @@ -10,11 +10,11 @@ #include -#include "utils.hpp" +#include "plier/utils.hpp" -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" -#include "transforms/pipeline_utils.hpp" +#include "plier/transforms/pipeline_utils.hpp" namespace { @@ -22,7 +22,7 @@ struct PassManagerStage { template PassManagerStage(mlir::MLIRContext& ctx, - const CompilerContext::Settings& settings, + const plier::CompilerContext::Settings& settings, F&& init_func): pm(&ctx) { @@ -97,8 +97,8 @@ struct PassManagerStage struct PassManagerSchedule { PassManagerSchedule(mlir::MLIRContext& ctx, - const CompilerContext::Settings& settings, - const PipelineRegistry& registry) + const plier::CompilerContext::Settings& settings, + const plier::PipelineRegistry& registry) { auto func = [&](auto sink) { @@ -161,11 +161,11 @@ struct PassManagerSchedule { return mlir::failure(); } - auto markers = get_pipeline_jump_markers(module); + auto markers = plier::get_pipeline_jump_markers(module); auto jump_target = current->get_jump(markers); if (nullptr != jump_target.first) { - remove_pipeline_jump_marker(module, jump_target.second); + plier::remove_pipeline_jump_marker(module, jump_target.second); current = jump_target.first; } else @@ -182,12 +182,12 @@ struct PassManagerSchedule }; } -class CompilerContext::CompilerContextImpl +class plier::CompilerContext::CompilerContextImpl { public: CompilerContextImpl(mlir::MLIRContext& ctx, const CompilerContext::Settings& settings, - const PipelineRegistry& registry): + const plier::PipelineRegistry& registry): schedule(ctx, settings, registry) {} void run(mlir::ModuleOp module) @@ -202,14 +202,14 @@ class CompilerContext::CompilerContextImpl } }; - scoped_diag_handler(*module.getContext(), diag_handler, [&]() + plier::scoped_diag_handler(*module.getContext(), diag_handler, [&]() { if (mlir::failed(schedule.run(module))) { err_stream << "\n"; module.print(err_stream); err_stream.flush(); - report_error(llvm::Twine("MLIR pipeline failed\n") + err); + plier::report_error(llvm::Twine("MLIR pipeline failed\n") + err); } }); } @@ -217,20 +217,20 @@ class CompilerContext::CompilerContextImpl PassManagerSchedule schedule; }; -CompilerContext::CompilerContext(mlir::MLIRContext& ctx, - const Settings& settings, - const PipelineRegistry& registry): +plier::CompilerContext::CompilerContext(mlir::MLIRContext& ctx, + const Settings& settings, + const PipelineRegistry& registry): impl(std::make_unique(ctx, settings, registry)) { } -CompilerContext::~CompilerContext() +plier::CompilerContext::~CompilerContext() { } -void CompilerContext::run(mlir::ModuleOp module) +void plier::CompilerContext::run(mlir::ModuleOp module) { impl->run(module); } diff --git a/mlir-compiler/src/pipeline_registry.cpp b/mlir-compiler/plier/src/compiler/pipeline_registry.cpp similarity index 96% rename from mlir-compiler/src/pipeline_registry.cpp rename to mlir-compiler/plier/src/compiler/pipeline_registry.cpp index e62ae0c1a77..8a16446bc5f 100644 --- a/mlir-compiler/src/pipeline_registry.cpp +++ b/mlir-compiler/plier/src/compiler/pipeline_registry.cpp @@ -1,16 +1,16 @@ -#include "pipeline_registry.hpp" +#include "plier/compiler/pipeline_registry.hpp" #include #include #include -#include "utils.hpp" +#include "plier/utils.hpp" #include #include #include -void PipelineRegistry::register_pipeline(PipelineRegistry::registry_entry_t func) +void plier::PipelineRegistry::register_pipeline(PipelineRegistry::registry_entry_t func) { assert(nullptr != func); pipelines.push_back(std::move(func)); @@ -34,7 +34,7 @@ void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) } } -void PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink) const +void plier::PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink) const { llvm::BumpPtrAllocator allocator; llvm::UniqueStringSaver string_set(allocator); diff --git a/mlir-compiler/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp similarity index 100% rename from mlir-compiler/src/dialect.cpp rename to mlir-compiler/plier/src/dialect.cpp diff --git a/mlir-compiler/src/rewrites/call_lowering.cpp b/mlir-compiler/plier/src/rewrites/call_lowering.cpp similarity index 85% rename from mlir-compiler/src/rewrites/call_lowering.cpp rename to mlir-compiler/plier/src/rewrites/call_lowering.cpp index 36295f03ee9..53e93c30d60 100644 --- a/mlir-compiler/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/plier/src/rewrites/call_lowering.cpp @@ -1,11 +1,11 @@ -#include "call_lowering.hpp" +#include "plier/rewrites/call_lowering.hpp" -CallOpLowering::CallOpLowering( +plier::CallOpLowering::CallOpLowering( mlir::TypeConverter&, mlir::MLIRContext* context, CallOpLowering::resolver_t resolver): OpRewritePattern(context), resolver(resolver) {} -mlir::LogicalResult CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const +mlir::LogicalResult plier::CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const { auto operands = op.getOperands(); if (operands.empty()) diff --git a/mlir-compiler/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp similarity index 97% rename from mlir-compiler/src/rewrites/canonicalize_reductions.cpp rename to mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp index ffaed84b509..db48de3bc9d 100644 --- a/mlir-compiler/src/rewrites/canonicalize_reductions.cpp +++ b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp @@ -1,4 +1,4 @@ -#include "rewrites/canonicalize_reductions.hpp" +#include "plier/rewrites/canonicalize_reductions.hpp" #include #include @@ -105,7 +105,7 @@ void createScalarStore( } } -mlir::LogicalResult CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { llvm::SmallVector to_process; op.walk([&](mlir::LoadOp load) diff --git a/mlir-compiler/src/rewrites/cast_lowering.cpp b/mlir-compiler/plier/src/rewrites/cast_lowering.cpp similarity index 85% rename from mlir-compiler/src/rewrites/cast_lowering.cpp rename to mlir-compiler/plier/src/rewrites/cast_lowering.cpp index c2a4e2e0a57..b30ad1da4a1 100644 --- a/mlir-compiler/src/rewrites/cast_lowering.cpp +++ b/mlir-compiler/plier/src/rewrites/cast_lowering.cpp @@ -1,14 +1,14 @@ -#include "rewrites/cast_lowering.hpp" +#include "plier/rewrites/cast_lowering.hpp" #include -CastOpLowering::CastOpLowering( +plier::CastOpLowering::CastOpLowering( mlir::TypeConverter& typeConverter, mlir::MLIRContext* context, CastOpLowering::cast_t cast_func): OpRewritePattern(context), converter(typeConverter), cast_func(std::move(cast_func)) {} -mlir::LogicalResult CastOpLowering::matchAndRewrite( +mlir::LogicalResult plier::CastOpLowering::matchAndRewrite( plier::CastOp op, mlir::PatternRewriter& rewriter) const { auto src = op.getOperand(); diff --git a/mlir-compiler/src/rewrites/cse.cpp b/mlir-compiler/plier/src/rewrites/cse.cpp similarity index 95% rename from mlir-compiler/src/rewrites/cse.cpp rename to mlir-compiler/plier/src/rewrites/cse.cpp index ee4dc3eb203..9a39e6ba923 100644 --- a/mlir-compiler/src/rewrites/cse.cpp +++ b/mlir-compiler/plier/src/rewrites/cse.cpp @@ -1,4 +1,4 @@ -#include "rewrites/cse.hpp" +#include "plier/rewrites/cse.hpp" #include #include @@ -78,9 +78,7 @@ mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir: } } - - -mlir::LogicalResult CSE::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter) +mlir::LogicalResult plier::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter) { ScopedMapTy map; ScopedMapTy::ScopeTy scope(map); diff --git a/mlir-compiler/src/rewrites/force_inline.cpp b/mlir-compiler/plier/src/rewrites/force_inline.cpp similarity index 87% rename from mlir-compiler/src/rewrites/force_inline.cpp rename to mlir-compiler/plier/src/rewrites/force_inline.cpp index 8eb16c3f48f..a6b0b5440f2 100644 --- a/mlir-compiler/src/rewrites/force_inline.cpp +++ b/mlir-compiler/plier/src/rewrites/force_inline.cpp @@ -1,11 +1,11 @@ -#include "rewrites/force_inline.hpp" +#include "plier/rewrites/force_inline.hpp" #include #include #include "plier/dialect.hpp" -mlir::LogicalResult ForceInline::matchAndRewrite(mlir::CallOp op, mlir::PatternRewriter& rewriter) const +mlir::LogicalResult plier::ForceInline::matchAndRewrite(mlir::CallOp op, mlir::PatternRewriter& rewriter) const { auto attr_name = plier::attributes::getForceInlineName(); auto mod = op->getParentOfType(); diff --git a/mlir-compiler/src/rewrites/index_type_propagation.cpp b/mlir-compiler/plier/src/rewrites/index_type_propagation.cpp similarity index 96% rename from mlir-compiler/src/rewrites/index_type_propagation.cpp rename to mlir-compiler/plier/src/rewrites/index_type_propagation.cpp index 3dbde2f600c..9531bd57ebf 100644 --- a/mlir-compiler/src/rewrites/index_type_propagation.cpp +++ b/mlir-compiler/plier/src/rewrites/index_type_propagation.cpp @@ -1,4 +1,4 @@ -#include "rewrites/index_type_propagation.hpp" +#include "plier/rewrites/index_type_propagation.hpp" #include #include @@ -146,7 +146,7 @@ struct CmpIndexCastSimplify : public mlir::OpRewritePattern }; } -void populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) +void plier::populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) { patterns.insert< CmpIndexCastSimplify, diff --git a/mlir-compiler/src/rewrites/loop_rewrites.cpp b/mlir-compiler/plier/src/rewrites/loop_rewrites.cpp similarity index 95% rename from mlir-compiler/src/rewrites/loop_rewrites.cpp rename to mlir-compiler/plier/src/rewrites/loop_rewrites.cpp index eefa40b91cb..1637587ecf9 100644 --- a/mlir-compiler/src/rewrites/loop_rewrites.cpp +++ b/mlir-compiler/plier/src/rewrites/loop_rewrites.cpp @@ -1,4 +1,4 @@ -#include "rewrites/loop_rewrites.hpp" +#include "plier/rewrites/loop_rewrites.hpp" #include #include @@ -48,7 +48,7 @@ llvm::Optional handler_impl(mlir::CmpIPredicate pred, mlir::Value lhs, } } -mlir::LogicalResult CmpLoopBoundsSimplify::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +mlir::LogicalResult plier::CmpLoopBoundsSimplify::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { auto index_var = op.getLoopBody().front().getArgument(0); if (auto step_var = mlir::dyn_cast_or_null(op.step().getDefiningOp())) diff --git a/mlir-compiler/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp similarity index 96% rename from mlir-compiler/src/rewrites/promote_to_parallel.cpp rename to mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp index 09cef94f295..0acdae74116 100644 --- a/mlir-compiler/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp @@ -1,4 +1,4 @@ -#include "rewrites/promote_to_parallel.hpp" +#include "plier/rewrites/promote_to_parallel.hpp" #include #include @@ -32,7 +32,7 @@ bool hasSideEffects(mlir::Operation *op) } } -mlir::LogicalResult PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const +mlir::LogicalResult plier::PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { auto has_parallel_attr = op->hasAttr(plier::attributes::getParallelName()); if (!has_parallel_attr && hasSideEffects(op)) diff --git a/mlir-compiler/src/rewrites/type_conversion.cpp b/mlir-compiler/plier/src/rewrites/type_conversion.cpp similarity index 96% rename from mlir-compiler/src/rewrites/type_conversion.cpp rename to mlir-compiler/plier/src/rewrites/type_conversion.cpp index 4cd262275fe..76652d9e4ab 100644 --- a/mlir-compiler/src/rewrites/type_conversion.cpp +++ b/mlir-compiler/plier/src/rewrites/type_conversion.cpp @@ -1,4 +1,4 @@ -#include "rewrites/type_conversion.hpp" +#include "plier/rewrites/type_conversion.hpp" #include #include @@ -106,11 +106,11 @@ mlir::LogicalResult convertRegionTypes( } } -FuncOpSignatureConversion::FuncOpSignatureConversion(mlir::TypeConverter& conv, +plier::FuncOpSignatureConversion::FuncOpSignatureConversion(mlir::TypeConverter& conv, mlir::MLIRContext* ctx) : OpRewritePattern(ctx), converter(conv) {} -mlir::LogicalResult FuncOpSignatureConversion::matchAndRewrite( +mlir::LogicalResult plier::FuncOpSignatureConversion::matchAndRewrite( mlir::FuncOp funcOp, mlir::PatternRewriter& rewriter) const { auto type = funcOp.getType(); diff --git a/mlir-compiler/src/transforms/cast_utils.cpp b/mlir-compiler/plier/src/transforms/cast_utils.cpp similarity index 60% rename from mlir-compiler/src/transforms/cast_utils.cpp rename to mlir-compiler/plier/src/transforms/cast_utils.cpp index cb35b1d59b3..600a0af86ac 100644 --- a/mlir-compiler/src/transforms/cast_utils.cpp +++ b/mlir-compiler/plier/src/transforms/cast_utils.cpp @@ -1,8 +1,8 @@ -#include "transforms/cast_utils.hpp" +#include "plier/transforms/cast_utils.hpp" #include -mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type) +mlir::Value plier::index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type) { auto src_type = src.getType(); assert(src_type.isa() || dst_type.isa()); @@ -13,7 +13,7 @@ mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value return src; } -mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src) +mlir::Value plier::index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src) { return index_cast(builder, loc, src, mlir::IndexType::get(builder.getContext())); } diff --git a/mlir-compiler/src/transforms/const_utils.cpp b/mlir-compiler/plier/src/transforms/const_utils.cpp similarity index 66% rename from mlir-compiler/src/transforms/const_utils.cpp rename to mlir-compiler/plier/src/transforms/const_utils.cpp index f7377077f51..fb1b87e007e 100644 --- a/mlir-compiler/src/transforms/const_utils.cpp +++ b/mlir-compiler/plier/src/transforms/const_utils.cpp @@ -1,9 +1,9 @@ -#include "transforms/const_utils.hpp" +#include "plier/transforms/const_utils.hpp" #include #include -mlir::Attribute getConstVal(mlir::Operation* op) +mlir::Attribute plier::getConstVal(mlir::Operation* op) { if (!op->hasTrait()) { @@ -13,7 +13,7 @@ mlir::Attribute getConstVal(mlir::Operation* op) return op->getAttr("value"); } -mlir::Attribute getConstVal(mlir::Value op) +mlir::Attribute plier::getConstVal(mlir::Value op) { if (auto parent_op = op.getDefiningOp()) { diff --git a/mlir-compiler/src/transforms/func_utils.cpp b/mlir-compiler/plier/src/transforms/func_utils.cpp similarity index 69% rename from mlir-compiler/src/transforms/func_utils.cpp rename to mlir-compiler/plier/src/transforms/func_utils.cpp index cd7205dde7b..73a0edd89fb 100644 --- a/mlir-compiler/src/transforms/func_utils.cpp +++ b/mlir-compiler/plier/src/transforms/func_utils.cpp @@ -1,12 +1,13 @@ -#include "transforms/func_utils.hpp" +#include "plier/transforms/func_utils.hpp" #include #include #include -mlir::FuncOp add_function(mlir::OpBuilder& builder, mlir::ModuleOp module, - llvm::StringRef name, mlir::FunctionType type) +mlir::FuncOp plier::add_function( + mlir::OpBuilder& builder, mlir::ModuleOp module, llvm::StringRef name, + mlir::FunctionType type) { mlir::OpBuilder::InsertionGuard guard(builder); // Insert before module terminator. diff --git a/mlir-compiler/src/transforms/loop_utils.cpp b/mlir-compiler/plier/src/transforms/loop_utils.cpp similarity index 97% rename from mlir-compiler/src/transforms/loop_utils.cpp rename to mlir-compiler/plier/src/transforms/loop_utils.cpp index e09cf8f5b3c..a5898cf7d70 100644 --- a/mlir-compiler/src/transforms/loop_utils.cpp +++ b/mlir-compiler/plier/src/transforms/loop_utils.cpp @@ -1,4 +1,4 @@ -#include "transforms/loop_utils.hpp" +#include "plier/transforms/loop_utils.hpp" #include @@ -10,7 +10,7 @@ #include "plier/dialect.hpp" -#include "transforms/cast_utils.hpp" +#include "plier/transforms/cast_utils.hpp" namespace { @@ -42,7 +42,7 @@ mlir::Value get_last_iter_value( } -mlir::LogicalResult lower_while_to_for( +mlir::LogicalResult plier::lower_while_to_for( plier::GetiterOp getiter, mlir::PatternRewriter& builder, llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, llvm::function_ref get_iter_val, @@ -104,7 +104,7 @@ mlir::LogicalResult lower_while_to_for( auto index_cast = [&](mlir::Value val)->mlir::Value { - return ::index_cast(builder, loc, val); + return ::plier::index_cast(builder, loc, val); }; auto bounds = get_bounds(builder, loc); diff --git a/mlir-compiler/src/transforms/pipeline_utils.cpp b/mlir-compiler/plier/src/transforms/pipeline_utils.cpp similarity index 84% rename from mlir-compiler/src/transforms/pipeline_utils.cpp rename to mlir-compiler/plier/src/transforms/pipeline_utils.cpp index 91f59faf0b1..c390201a79c 100644 --- a/mlir-compiler/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/plier/src/transforms/pipeline_utils.cpp @@ -1,16 +1,16 @@ -#include "transforms/pipeline_utils.hpp" +#include "plier/transforms/pipeline_utils.hpp" #include #include #include "plier/dialect.hpp" -mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module) +mlir::ArrayAttr plier::get_pipeline_jump_markers(mlir::ModuleOp module) { return module->getAttrOfType(plier::attributes::getJumpMarkersName()); } -void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) +void plier::add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) { assert(name); assert(!name.getValue().empty()); @@ -38,7 +38,7 @@ void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) } -void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) +void plier::remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) { assert(name); assert(!name.getValue().empty()); diff --git a/mlir-compiler/src/utils.cpp b/mlir-compiler/plier/src/utils.cpp similarity index 56% rename from mlir-compiler/src/utils.cpp rename to mlir-compiler/plier/src/utils.cpp index 8d3b7f19d0b..36a9a414a92 100644 --- a/mlir-compiler/src/utils.cpp +++ b/mlir-compiler/plier/src/utils.cpp @@ -1,10 +1,10 @@ -#include "utils.hpp" +#include "plier/utils.hpp" #include #include "llvm/ADT/Twine.h" -void report_error(const llvm::Twine& msg) +void plier::report_error(const llvm::Twine& msg) { throw std::runtime_error(msg.str()); } From 11e0eaa5d75285b67b287670406bc061c0b676ec Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 8 Feb 2021 14:25:38 +0300 Subject: [PATCH 229/259] getZeroVal (#173) --- .../src/pipelines/parallel_to_tbb.cpp | 10 ++++------ .../mlir-compiler/src/py_linalg_resolver.cpp | 18 ++++-------------- .../include/plier/transforms/const_utils.hpp | 4 +++- .../plier/src/transforms/const_utils.cpp | 18 +++++++++++++++++- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index dff93c14c24..b6ae835edb4 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -10,6 +10,7 @@ #include "plier/dialect.hpp" #include "plier/compiler/pipeline_registry.hpp" +#include "plier/transforms/const_utils.hpp" #include "pipelines/base_pipeline.hpp" #include "pipelines/lower_to_llvm.hpp" @@ -26,13 +27,10 @@ mlir::MemRefType getReduceType(mlir::Type type, int64_t count) mlir::Value getZeroVal(mlir::OpBuilder& builder, mlir::Location loc, mlir::Type type) { - if (type.isa()) + auto const_val = plier::getZeroVal(type); + if (const_val) { - return builder.create(loc, 0, type.cast()); - } - if (type.isa()) - { - return builder.create(loc, llvm::APFloat(0.0), type.cast()); + return builder.create(loc, const_val); } llvm_unreachable("Unhandled type"); } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 295abbda78e..49c7aee5843 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -14,6 +14,7 @@ #include "plier/dialect.hpp" #include "py_map_types.hpp" #include "plier/utils.hpp" +#include "plier/transforms/const_utils.hpp" namespace py = pybind11; @@ -267,19 +268,6 @@ auto generic_op_body_result_types(mlir::ValueRange outputs) return ret; } -mlir::Attribute zero_attr(mlir::Type type) -{ - if (type.isa()) - { - return mlir::IntegerAttr::get(type, 0); - } - if (type.isa()) - { - return mlir::FloatAttr::get(type, 0.0); - } - plier::report_error("zero_attr: unhandled type"); -} - py::object broadcast_impl(py::capsule /*context*/, py::tuple args) { if (1 == args.size()) @@ -300,7 +288,9 @@ py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dty if (shape.empty()) { // TODO: undef - init = ctx.builder.create(ctx.loc, zero_attr(elem_type)); + auto zero_val = plier::getZeroVal(elem_type); + assert(zero_val); + init = ctx.builder.create(ctx.loc, zero_val); } else { diff --git a/mlir-compiler/plier/include/plier/transforms/const_utils.hpp b/mlir-compiler/plier/include/plier/transforms/const_utils.hpp index 7c8c5169c9a..ea4f22b0eec 100644 --- a/mlir-compiler/plier/include/plier/transforms/const_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/const_utils.hpp @@ -1,11 +1,11 @@ #pragma once #include +#include namespace mlir { class Operation; -class Value; } namespace plier @@ -24,4 +24,6 @@ T getConstVal(mlir::Value op) { return getConstVal(op).dyn_cast_or_null(); } + +mlir::Attribute getZeroVal(mlir::Type type); } diff --git a/mlir-compiler/plier/src/transforms/const_utils.cpp b/mlir-compiler/plier/src/transforms/const_utils.cpp index fb1b87e007e..c0bd0646799 100644 --- a/mlir-compiler/plier/src/transforms/const_utils.cpp +++ b/mlir-compiler/plier/src/transforms/const_utils.cpp @@ -1,10 +1,11 @@ #include "plier/transforms/const_utils.hpp" #include -#include +#include mlir::Attribute plier::getConstVal(mlir::Operation* op) { + assert(op); if (!op->hasTrait()) { return {}; @@ -15,9 +16,24 @@ mlir::Attribute plier::getConstVal(mlir::Operation* op) mlir::Attribute plier::getConstVal(mlir::Value op) { + assert(op); if (auto parent_op = op.getDefiningOp()) { return getConstVal(parent_op); } return {}; } + +mlir::Attribute plier::getZeroVal(mlir::Type type) +{ + assert(type); + if (type.isa()) + { + return mlir::FloatAttr::get(type, 0.0); + } + if (type.isa()) + { + return mlir::IntegerAttr::get(type, 0); + } + return {}; +} From 954852cc4bc1e849de4ffb12db151d3559a533d5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 9 Feb 2021 19:31:29 +0300 Subject: [PATCH 230/259] modulo op (#174) --- mlir-compiler/mlir-compiler/src/lowering.cpp | 1 + .../src/pipelines/plier_to_std.cpp | 24 +++++++++++++++++++ numba/mlir/tests/test_basic.py | 7 ++++-- 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp index e30d3c3c85a..294653f7342 100644 --- a/mlir-compiler/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -72,6 +72,7 @@ static const constexpr OpId inst_ops_names[] = { {"*", "mul"}, {"/", "truediv"}, {"//", "floordiv"}, + {"%", "mod"}, {">", "gt"}, {">=", "ge"}, diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index bde03c75f3f..3dbddc68638 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -511,11 +511,34 @@ void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type void replace_itruediv_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) { assert(nullptr != op); + assert(new_type.isa()); auto lhs = do_cast(new_type, operands[0], rewriter); auto rhs = do_cast(new_type, operands[1], rewriter); rewriter.replaceOpWithNewOp(op, lhs, rhs); } +void replace_imod_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + auto loc = op->getLoc(); + auto a = operands[0]; + auto b = operands[1]; + auto v1 = rewriter.create(loc, a, b).getResult(); + auto v2 = rewriter.create(loc, v1, b).getResult(); + auto res = rewriter.create(loc, v2, b).getResult(); + rewriter.replaceOp(op, res); +} + +void replace_fmod_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) +{ + auto loc = op->getLoc(); + auto a = operands[0]; + auto b = operands[1]; + auto v1 = rewriter.create(loc, a, b).getResult(); + auto v2 = rewriter.create(loc, v1, b).getResult(); + auto res = rewriter.create(loc, v2, b).getResult(); + rewriter.replaceOp(op, res); +} + template void replace_cmp_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type /*new_type*/, mlir::ValueRange operands) { @@ -577,6 +600,7 @@ struct BinOpLowering : public mlir::OpRewritePattern {"-", &replace_op, &replace_op}, {"*", &replace_op, &replace_op}, {"/", &replace_itruediv_op, &replace_op}, + {"%", &replace_imod_op, &replace_fmod_op}, {">", &replace_cmp_op(mlir::CmpIPredicate::sgt)>, &replace_cmp_op(mlir::CmpFPredicate::OGT)>}, diff --git a/numba/mlir/tests/test_basic.py b/numba/mlir/tests/test_basic.py index 01a0fae1fb0..23bdf352734 100644 --- a/numba/mlir/tests/test_basic.py +++ b/numba/mlir/tests/test_basic.py @@ -1,6 +1,6 @@ import numba from numba import njit -from math import nan, inf +from math import nan, inf, isnan from numpy.testing import assert_equal # for nans comparison from numba.tests.support import TestCase @@ -8,7 +8,9 @@ import itertools -_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5 -0.0, 0.0, 0.5, 1.0, 2.5, -inf, inf] # TODO: nans +# TODO: nans and infs not tested yet, we are not sure if want exactly follow +# interpreted python rules +_test_values = [-3,-2,-1,0,1,2,3,-2.5,-1.0,-0.5 -0.0, 0.0, 0.5, 1.0, 2.5] class TestMlirBasic(TestCase): def test_ret(self): @@ -25,6 +27,7 @@ def test_ops(self): lambda a, b: a - b, lambda a, b: a * b, lambda a, b: a / b, + lambda a, b: a % b, # TODO: floordiv ] From 45da142373266a934f93111f2a562bde6435302b Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Feb 2021 02:34:32 +0300 Subject: [PATCH 231/259] [MLIR] Return arrays from func (#183) --- .../src/pipelines/lower_to_llvm.cpp | 454 +++++++++++++++--- .../src/pipelines/plier_to_linalg.cpp | 10 +- numba/mlir/tests/test_numpy.py | 17 + 3 files changed, 408 insertions(+), 73 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 5d85f5ed9ba..c7278aa6b2f 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -176,74 +176,186 @@ mlir::Value unflatten(mlir::Type type, mlir::Location loc, mlir::OpBuilder& buil } } -std::string gen_conversion_func_name(mlir::MemRefType memref_type) +std::string gen_to_memref_conversion_func_name(mlir::MemRefType memref_type) { assert(memref_type); std::string ret; llvm::raw_string_ostream ss(ret); - ss << "__convert_memref_"; + ss << "__convert_to_memref_"; memref_type.getElementType().print(ss); ss.flush(); return ret; } -struct MemRefConversionCache +std::string gen_from_memref_conversion_func_name(mlir::MemRefType memref_type) { - mlir::FuncOp get_conversion_func( - mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, - mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) - { - assert(memref_type); - assert(src_type); - assert(dst_type); - auto it = cache.find(memref_type); - if (it != cache.end()) - { - auto func = it->second; - assert(func.getType().getNumResults() == 1); - assert(func.getType().getResult(0) == dst_type); - return func; - } - auto func_name = gen_conversion_func_name(memref_type); - auto func_type = mlir::FunctionType::get(builder.getContext(),src_type, dst_type); - auto loc = builder.getUnknownLoc(); - auto new_func = plier::add_function(builder, module, func_name, func_type); - auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); - new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); - cache.insert({memref_type, new_func}); - mlir::OpBuilder::InsertionGuard guard(builder); - auto block = new_func.addEntryBlock(); - builder.setInsertionPointToStart(block); - namespace mllvm = mlir::LLVM; - mlir::Value arg = block->getArgument(0); - auto extract = [&](unsigned index) - { - auto res_type = src_type.getBody()[index]; - auto i = builder.getI64ArrayAttr(index); - return builder.create(loc, res_type, arg, i); - }; - auto ptr = extract(4); - auto shape = extract(5); - auto strides = extract(6); - auto i64 = mlir::IntegerType::get(builder.getContext(), 64); - auto offset = builder.create(loc, i64, builder.getI64IntegerAttr(0)); - mlir::Value res = builder.create(loc, dst_type); - auto insert = [&](unsigned index, mlir::Value val) - { - auto i = builder.getI64ArrayAttr(index); - res = builder.create(loc, res, val, i); - }; - insert(0, ptr); - insert(1, ptr); - insert(2, offset); - insert(3, shape); - insert(4, strides); - builder.create(loc, res); - return new_func; + assert(memref_type); + std::string ret; + llvm::raw_string_ostream ss(ret); + ss << "__convert_from_memref_"; + memref_type.getElementType().print(ss); + ss.flush(); + return ret; +} + +mlir::Value div_strides(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value strides, mlir::Value m) +{ + auto array_type = strides.getType().cast(); + mlir::Value array = builder.create(loc, array_type); + for (unsigned i = 0 ; i < array_type.getNumElements(); ++i) + { + auto index = builder.getI64ArrayAttr(i); + auto prev = builder.create(loc, array_type.getElementType(), strides, index); + auto val = builder.create(loc, prev, m); + array = builder.create(loc, array, val, index); } -private: - llvm::DenseMap cache; -}; + return array; +} + +mlir::Value mul_strides(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value strides, mlir::Value m) +{ + auto array_type = strides.getType().cast(); + mlir::Value array = builder.create(loc, array_type); + for (unsigned i = 0 ; i < array_type.getNumElements(); ++i) + { + auto index = builder.getI64ArrayAttr(i); + auto prev = builder.create(loc, array_type.getElementType(), strides, index); + auto val = builder.create(loc, prev, m); + array = builder.create(loc, array, val, index); + } + return array; +} + +unsigned item_size(mlir::Type type) +{ + if (auto inttype = type.dyn_cast()) + { + assert((inttype.getWidth() % 8) == 0); + return inttype.getWidth() / 8; + } + if (auto floattype = type.dyn_cast()) + { + assert((floattype.getWidth() % 8) == 0); + return floattype.getWidth() / 8; + } + llvm_unreachable("item_size: invalid type"); +} + +mlir::FuncOp get_to_memref_conversion_func( + mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, + mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) +{ + assert(memref_type); + assert(src_type); + assert(dst_type); + auto func_name = gen_to_memref_conversion_func_name(memref_type); + if (auto func = module.lookupSymbol(func_name)) + { + assert(func.getType().getNumResults() == 1); + assert(func.getType().getResult(0) == dst_type); + return func; + } + auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); + auto loc = builder.getUnknownLoc(); + auto new_func = plier::add_function(builder, module, func_name, func_type); + auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); + mlir::OpBuilder::InsertionGuard guard(builder); + auto block = new_func.addEntryBlock(); + builder.setInsertionPointToStart(block); + namespace mllvm = mlir::LLVM; + mlir::Value arg = block->getArgument(0); + auto extract = [&](unsigned index) + { + auto res_type = src_type.getBody()[index]; + auto i = builder.getI64ArrayAttr(index); + return builder.create(loc, res_type, arg, i); + }; + auto meminfo = extract(0); + auto ptr = extract(4); + auto shape = extract(5); + auto strides = extract(6); + auto i64 = mlir::IntegerType::get(builder.getContext(), 64); + auto offset = builder.create(loc, i64, builder.getI64IntegerAttr(0)); + mlir::Value res = builder.create(loc, dst_type); + auto meminfo_casted = builder.create(loc, ptr.getType(), meminfo); + auto itemsize = builder.create(loc, i64, builder.getI64IntegerAttr(item_size(memref_type.getElementType()))); + auto insert = [&](unsigned index, mlir::Value val) + { + auto i = builder.getI64ArrayAttr(index); + res = builder.create(loc, res, val, i); + }; + insert(0, meminfo_casted); + insert(1, ptr); + insert(2, offset); + insert(3, shape); + insert(4, div_strides(loc, builder, strides, itemsize)); + builder.create(loc, res); + return new_func; +} + +mlir::FuncOp get_from_memref_conversion_func( + mlir::ModuleOp module, mlir::OpBuilder& builder, mlir::MemRefType memref_type, + mlir::LLVM::LLVMStructType src_type, mlir::LLVM::LLVMStructType dst_type) +{ + assert(memref_type); + assert(src_type); + assert(dst_type); + auto func_name = gen_from_memref_conversion_func_name(memref_type); + if (auto func = module.lookupSymbol(func_name)) + { + assert(func.getType().getNumResults() == 1); + assert(func.getType().getResult(0) == dst_type); + return func; + } + auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); + auto loc = builder.getUnknownLoc(); + auto new_func = plier::add_function(builder, module, func_name, func_type); + auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); + mlir::OpBuilder::InsertionGuard guard(builder); + auto block = new_func.addEntryBlock(); + builder.setInsertionPointToStart(block); + namespace mllvm = mlir::LLVM; + mlir::Value arg = block->getArgument(0); + auto i8ptr_type = mllvm::LLVMPointerType::get(builder.getIntegerType(8)); + auto i64_type = builder.getIntegerType(64); + auto extract = [&](unsigned index) + { + auto res_type = src_type.getBody()[index]; + auto i = builder.getI64ArrayAttr(index); + return builder.create(loc, res_type, arg, i); + }; + auto meminfo = builder.create(loc, i8ptr_type, extract(0)); + auto orig_ptr = extract(1); + auto offset = extract(2); + auto shape = extract(3); + auto strides = extract(4); + auto ptr = builder.create(loc, orig_ptr.getType(), orig_ptr, offset.getResult()); + mlir::Value res = builder.create(loc, dst_type); + auto null = builder.create(loc, i8ptr_type); + mlir::Value nitems = builder.create(loc, i64_type, builder.getI64IntegerAttr(1)); + for (int64_t i = 0; i < memref_type.getRank(); ++i) + { + auto dim = builder.create(loc, nitems.getType(), shape, builder.getI64ArrayAttr(i)); + nitems = builder.create(loc, nitems, dim); + } + auto itemsize = builder.create(loc, i64_type, builder.getI64IntegerAttr(item_size(memref_type.getElementType()))); + auto insert = [&](unsigned index, mlir::Value val) + { + auto i = builder.getI64ArrayAttr(index); + res = builder.create(loc, res, val, i); + }; + insert(0, meminfo); + insert(1, null); // parent + insert(2, nitems); + insert(3, itemsize); + insert(4, ptr); + insert(5, shape); + insert(6, mul_strides(loc, builder, strides, itemsize)); + builder.create(loc, res); + return new_func; +} mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) { @@ -296,8 +408,6 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) return ret; }; - MemRefConversionCache conversion_cache; - mlir::OpBuilder builder(&ctx); builder.setInsertionPointToStart(&func.getBody().front()); @@ -324,7 +434,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) auto mod = mlir::cast(func->getParentOp()); auto dst_type = type_helper.get_type_converter().convertType(memref_type); assert(dst_type); - auto conv_func = conversion_cache.get_conversion_func(mod, builder, memref_type, arr_type, dst_type.cast()); + auto conv_func = get_to_memref_conversion_func(mod, builder, memref_type, arr_type, dst_type.cast()); auto converted = builder.create(loc, conv_func, desc).getResult(0); auto casted = builder.create(loc, memref_type, converted); func.getBody().getArgument(index).replaceAllUsesWith(casted); @@ -337,7 +447,16 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) } }; - auto orig_ret_type = (old_type.getNumResults() != 0 ? old_type.getResult(0) : type_helper.ptr(type_helper.i(8))); + auto get_res_type = [&](mlir::Type type)->mlir::Type + { + if (auto memreftype = type.dyn_cast()) + { + return get_array_type(type_helper.get_type_converter(), memreftype); + } + return type; + }; + + auto orig_ret_type = (old_type.getNumResults() != 0 ? get_res_type(old_type.getResult(0)) : type_helper.ptr(type_helper.i(8))); add_arg(ptr(orig_ret_type)); add_arg(ptr(ptr(getExceptInfoType(type_helper)))); @@ -373,6 +492,7 @@ struct ReturnOpLowering : public mlir::OpRewritePattern rewriter.replaceOpWithNewOp(op, ret); }; + auto loc = op.getLoc(); rewriter.setInsertionPoint(op); auto addr = op->getParentRegion()->front().getArgument(0); if (op.getNumOperands() == 0) @@ -380,17 +500,25 @@ struct ReturnOpLowering : public mlir::OpRewritePattern assert(addr.getType().isa()); auto null_type = addr.getType().cast().getElementType(); auto ll_val = rewriter.create(op.getLoc(), null_type); - rewriter.create(op.getLoc(), ll_val, addr); + rewriter.create(loc, ll_val, addr); insert_ret(); return mlir::success(); } else if (op.getNumOperands() == 1) { - auto val = op.getOperand(0); - auto ll_ret_type = type_converter.convertType(val.getType()); - assert(static_cast(ll_ret_type)); - auto ll_val = rewriter.create(op.getLoc(), ll_ret_type, val); // TODO: hack to make verifier happy - rewriter.create(op.getLoc(), ll_val, addr); + mlir::Value val = op.getOperand(0); + auto orig_type = val.getType(); + auto ll_ret_type = type_converter.convertType(orig_type); + assert(ll_ret_type); + val = rewriter.create(loc, ll_ret_type, val); + if (auto memref_type = orig_type.dyn_cast()) + { + auto dst_type = get_array_type(type_converter, memref_type).cast(); + auto mod = op->getParentOfType(); + auto func = get_from_memref_conversion_func(mod, rewriter, memref_type, ll_ret_type.cast(), dst_type); + val = rewriter.create(loc, func, val).getResult(0); + } + rewriter.create(loc, val, addr); insert_ret(); return mlir::success(); } @@ -466,6 +594,196 @@ struct ApplyFastmathFlags : public mlir::OpRewritePattern } }; +// Copypaste from StandardToLLVM +mlir::Value createIndexAttrConstant(mlir::OpBuilder &builder, mlir::Location loc, + mlir::Type resultType, int64_t value) { + return builder.create( + loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); +} + +struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern { + using ConvertToLLVMPattern::createIndexConstant; + using ConvertToLLVMPattern::getIndexType; + using ConvertToLLVMPattern::getVoidPtrType; + + explicit AllocLikeOpLowering(mlir::StringRef opName, mlir::LLVMTypeConverter &converter) + : ConvertToLLVMPattern(opName, &converter.getContext(), converter, /*benefit*/99) {} + +protected: + // Returns 'input' aligned up to 'alignment'. Computes + // bumped = input + alignement - 1 + // aligned = bumped - bumped % alignment +// static mlir::Value createAligned(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, +// mlir::Value input, mlir::Value alignment) { +// using namespace mlir; +// Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); +// Value bump = rewriter.create(loc, alignment, one); +// Value bumped = rewriter.create(loc, input, bump); +// Value mod = rewriter.create(loc, bumped, alignment); +// return rewriter.create(loc, bumped, mod); +// } + + // Creates a call to an allocation function with params and casts the + // resulting void pointer to ptrType. + mlir::Value createAllocCall(mlir::Location loc, mlir::StringRef name, mlir::Type ptrType, + mlir::ArrayRef params, mlir::ModuleOp module, + mlir::ConversionPatternRewriter &rewriter) const { + using namespace mlir; + SmallVector paramTypes; + auto allocFuncOp = module.lookupSymbol(name); + if (!allocFuncOp) { + for (Value param : params) + paramTypes.push_back(param.getType()); + auto allocFuncType = + LLVM::LLVMFunctionType::get(getVoidPtrType(), paramTypes); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(module.getBody()); + allocFuncOp = rewriter.create(rewriter.getUnknownLoc(), + name, allocFuncType); + } + auto allocFuncSymbol = rewriter.getSymbolRefAttr(allocFuncOp); + auto allocatedPtr = rewriter + .create(loc, getVoidPtrType(), + allocFuncSymbol, params) + .getResult(0); + return rewriter.create(loc, ptrType, allocatedPtr); + } + + /// Allocates the underlying buffer. Returns the allocated pointer and the + /// aligned pointer. + virtual std::tuple + allocateBuffer(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, + mlir::Value sizeBytes, mlir::Operation *op) const = 0; + +private: + static mlir::MemRefType getMemRefResultType(mlir::Operation *op) { + return op->getResult(0).getType().cast(); + } + + mlir::LogicalResult match(mlir::Operation *op) const override { + mlir::MemRefType memRefType = getMemRefResultType(op); + return mlir::success(isConvertibleAndHasIdentityMaps(memRefType)); + } + + // An `alloc` is converted into a definition of a memref descriptor value and + // a call to `malloc` to allocate the underlying data buffer. The memref + // descriptor is of the LLVM structure type where: + // 1. the first element is a pointer to the allocated (typed) data buffer, + // 2. the second element is a pointer to the (typed) payload, aligned to the + // specified alignment, + // 3. the remaining elements serve to store all the sizes and strides of the + // memref using LLVM-converted `index` type. + // + // Alignment is performed by allocating `alignment` more bytes than + // requested and shifting the aligned pointer relative to the allocated + // memory. Note: `alignment - ` would actually be + // sufficient. If alignment is unspecified, the two pointers are equal. + + // An `alloca` is converted into a definition of a memref descriptor value and + // an llvm.alloca to allocate the underlying data buffer. + void rewrite(mlir::Operation *op, mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + mlir::MemRefType memRefType = getMemRefResultType(op); + auto loc = op->getLoc(); + + // Get actual sizes of the memref as values: static sizes are constant + // values and dynamic sizes are passed to 'alloc' as operands. In case of + // zero-dimensional memref, assume a scalar (size 1). + mlir::SmallVector sizes; + mlir::SmallVector strides; + mlir::Value sizeBytes; + this->getMemRefDescriptorSizes(loc, memRefType, operands, rewriter, sizes, + strides, sizeBytes); + + // Allocate the underlying buffer. + mlir::Value allocatedPtr; + mlir::Value alignedPtr; + std::tie(allocatedPtr, alignedPtr) = + this->allocateBuffer(rewriter, loc, sizeBytes, op); + + // Create the MemRef descriptor. + auto memRefDescriptor = this->createMemRefDescriptor( + loc, memRefType, allocatedPtr, alignedPtr, sizes, strides, rewriter); + + // Return the final value of the descriptor. + rewriter.replaceOp(op, {memRefDescriptor}); + } +}; + +struct AllocOpLowering : public AllocLikeOpLowering { + AllocOpLowering(mlir::LLVMTypeConverter &converter) + : AllocLikeOpLowering(mlir::AllocOp::getOperationName(), converter) {} + + std::tuple allocateBuffer(mlir::ConversionPatternRewriter &rewriter, + mlir::Location loc, mlir::Value sizeBytes, + mlir::Operation *op) const override { + auto allocOp = mlir::cast(op); + auto memRefType = allocOp.getType(); + mlir::Value alignment; + if (auto alignmentAttr = allocOp.alignment()) { + alignment = createIndexConstant(rewriter, loc, *alignmentAttr); + } else if (!memRefType.getElementType().isSignlessIntOrIndexOrFloat()) { + // In the case where no alignment is specified, we may want to override + // `malloc's` behavior. `malloc` typically aligns at the size of the + // biggest scalar on a target HW. For non-scalars, use the natural + // alignment of the LLVM type given by the LLVM DataLayout. + alignment = getSizeInBytes(loc, memRefType.getElementType(), rewriter); + } else { + alignment = createIndexConstant(rewriter, loc, 32/*item_size(memRefType.getElementType())*/); + } + alignment = rewriter.create(loc, rewriter.getIntegerType(32), alignment); + + auto mod = allocOp->getParentOfType(); + auto meminfo_ptr = + createAllocCall(loc, "NRT_MemInfo_alloc_safe_aligned", getVoidPtrType(), {sizeBytes, alignment}, + mod, rewriter); + auto data_ptr = createAllocCall(loc, "NRT_MemInfo_data_fast", getVoidPtrType(), {meminfo_ptr}, + mod, rewriter); + + auto elem_ptr_type = mlir::LLVM::LLVMPointerType::get(memRefType.getElementType()); + auto bitcast = [&](mlir::Value val) + { + return rewriter.create(loc, elem_ptr_type, val); + }; + + return std::make_tuple(bitcast(meminfo_ptr), bitcast(data_ptr)); + } +}; + +struct DeallocOpLowering : public mlir::ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + explicit DeallocOpLowering(mlir::LLVMTypeConverter &converter) + : ConvertOpToLLVMPattern(converter, /*benefit*/99) {} + + mlir::LogicalResult + matchAndRewrite(mlir::DeallocOp op, mlir::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1 && "dealloc takes one operand"); + mlir::DeallocOp::Adaptor transformed(operands); + + // Insert the `free` declaration if it is not already present. + auto freeFunc = + op->getParentOfType().lookupSymbol("NRT_decref"); + if (!freeFunc) { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart( + op->getParentOfType().getBody()); + freeFunc = rewriter.create( + rewriter.getUnknownLoc(), "NRT_decref", + mlir::LLVM::LLVMFunctionType::get(getVoidType(), getVoidPtrType())); + } + + mlir::MemRefDescriptor memref(transformed.memref()); + mlir::Value casted = rewriter.create( + op.getLoc(), getVoidPtrType(), + memref.allocatedPtr(rewriter, op.getLoc())); + rewriter.replaceOpWithNewOp( + op, mlir::TypeRange(), rewriter.getSymbolRefAttr(freeFunc), casted); + return mlir::success(); + } +}; + class CheckForPlierTypes : public mlir::PassWrapper> { @@ -877,6 +1195,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper(typeConverter, &getContext()); + patterns.insert(typeConverter); LLVMConversionTarget target(getContext()); if (failed(applyPartialConversion(m, target, std::move(patterns)))) @@ -896,7 +1215,6 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) // pm.addPass(std::make_unique()); pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); -// pm.addPass(mlir::createLowerToLLVMPass(getLLVMOptions())); pm.addNestedPass(std::make_unique()); } } diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index e2675e8c089..89f4f86303b 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -653,16 +654,15 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); - pm.addNestedPass(mlir::createTensorBufferizePass()); + pm.addPass(mlir::createTensorConstantBufferizePass()); + pm.addNestedPass(mlir::createSCFBufferizePass()); pm.addNestedPass(mlir::createLinalgBufferizePass()); pm.addNestedPass(mlir::createStdBufferizePass()); + pm.addNestedPass(mlir::createTensorBufferizePass()); pm.addPass(mlir::createFuncBufferizePass()); + pm.addNestedPass(mlir::createFinalizingBufferizePass()); - pm.addNestedPass(mlir::createPromoteBuffersToStackPass(1024)); - pm.addNestedPass(mlir::createBufferHoistingPass()); - pm.addNestedPass(mlir::createBufferLoopHoistingPass()); pm.addNestedPass(mlir::createBufferDeallocationPass()); - pm.addNestedPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 170e5b2e703..8c32f522650 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -40,6 +40,15 @@ def py_func(a): arr = np.asarray([1,2,3]) assert_equal(py_func(arr), jit_func(arr)) + def test_add(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func) + arr1 = np.array([1,2,3]) + arr2 = np.array([4,5,6]) + assert_equal(py_func(arr1,arr2), jit_func(arr1,arr2)) + def test_add_scalar(self): def py_func(a, b): return np.add(a, b) @@ -153,5 +162,13 @@ def py_func(a): arr = np.array([[1,2,3],[4,5,6]]) assert_equal(py_func(arr), jit_func(arr)) + def test_array_return(self): + def py_func(a): + return a + + jit_func = njit(py_func) + arr = np.array([1,2,3]) + assert_equal(py_func(arr), jit_func(arr)) + if __name__ == '__main__': unittest.main() From bcf4a35835328b37b505a1f7c489c85d7fda1f38 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 16 Feb 2021 15:19:10 +0300 Subject: [PATCH 232/259] [MLIR] array GetItem/SetItem multidim argument (#184) * test * check getitem/setitem index type * PropagateBuildTupleTypes * multidim setitem/getitem --- .../src/pipelines/plier_to_linalg.cpp | 80 ++++++++++++++----- .../src/pipelines/plier_to_std.cpp | 22 +++++ numba/mlir/tests/test_numpy.py | 11 ++- 3 files changed, 91 insertions(+), 22 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 89f4f86303b..6082dd3662f 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -223,6 +223,23 @@ struct CallLowerer PyLinalgResolver linalg_resolver; }; +mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder) +{ + if (!value.getType().isa()) + { + auto index_type = mlir::IndexType::get(value.getContext()); + auto res = builder.create(loc, index_type, value); + rerun_std_pipeline(res); + return res; + } + return value; +} + +bool isValidGetitemIndex(mlir::Type type) +{ + return type.isa(); +} + template struct GetitemOpLowering : public mlir::OpRewritePattern { @@ -241,24 +258,36 @@ struct GetitemOpLowering : public mlir::OpRewritePattern { return mlir::failure(); } - if (!index.getType().template isa() && - !index.getType().template isa()) + if (!isValidGetitemIndex(index.getType())) { return mlir::failure(); } auto loc = op.getLoc(); - if (index.getType().template isa()) + + llvm::SmallVector indices; + if (auto tuple_type = index.getType().template dyn_cast()) + { + indices.resize(tuple_type.size()); + for (auto it : llvm::enumerate(tuple_type)) + { + auto getitem_ind = rewriter.create(loc, it.index()); + auto ind = rewriter.create(loc, index, getitem_ind); + indices[it.index()] = index_cast(ind, loc, rewriter); + } + } + else { - index = rewriter.create(loc, index, mlir::IndexType::get(op.getContext())); + indices.push_back(index_cast(index, loc, rewriter)); } + mlir::Value res; if (is_memref) { - res = rewriter.create(loc, val, index); + res = rewriter.create(loc, val, indices); } else if (is_tensor) { - res = rewriter.create(loc, val, index); + res = rewriter.create(loc, val, indices); } else { @@ -328,18 +357,6 @@ bool replace_ssa_value(mlir::Value value, mlir::Value new_value, mlir::PatternRe llvm_unreachable("Unhandled parent op"); } -mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder) -{ - if (!value.getType().isa()) - { - auto index_type = mlir::IndexType::get(value.getContext()); - auto res = builder.create(loc, index_type, value); - rerun_std_pipeline(res); - return res; - } - return value; -} - template struct SetitemOpLoweringSSA : public mlir::OpRewritePattern { @@ -407,6 +424,12 @@ struct SetitemOpLowering : public mlir::OpRewritePattern return op.getOperand(0).getType(); }; + auto index = op.index(); + if (!isValidGetitemIndex(index.getType())) + { + return mlir::failure(); + } + if (auto target_type = get_target_type().template dyn_cast()) { auto target = op.getOperand(0); @@ -452,10 +475,8 @@ struct SetitemOpLowering : public mlir::OpRewritePattern return mlir::failure(); } auto target = op.getOperand(0); - auto index = op.getOperand(1); auto value = op.getOperand(2); auto loc = op.getLoc(); - auto ind = index_cast(index, loc, rewriter); auto elem_type = target.getType().template cast().getElementType(); if (value.getType() != elem_type) { @@ -463,7 +484,24 @@ struct SetitemOpLowering : public mlir::OpRewritePattern value = rewriter.create(loc, elem_type, value); rerun_std_pipeline(op); } - auto store = rewriter.create(loc, value, target, ind); + + llvm::SmallVector indices; + if (auto tuple_type = index.getType().template dyn_cast()) + { + indices.resize(tuple_type.size()); + for (auto it : llvm::enumerate(tuple_type)) + { + auto getitem_ind = rewriter.create(loc, it.index()); + auto ind = rewriter.create(loc, index, getitem_ind); + indices[it.index()] = index_cast(ind, loc, rewriter); + } + rerun_std_pipeline(op); + } + else + { + indices.push_back(index_cast(index, loc, rewriter)); + } + rewriter.create(loc, value, target, indices); rewriter.eraseOp(op); return mlir::success(); } diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 3dbddc68638..c491df9de18 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1105,6 +1105,27 @@ struct FixupWhileTypes : public mlir::OpRewritePattern } }; +struct PropagateBuildTupleTypes : public mlir::OpRewritePattern +{ + PropagateBuildTupleTypes(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context): + OpRewritePattern(context) {} + + mlir::LogicalResult matchAndRewrite( + plier::BuildTupleOp op, mlir::PatternRewriter &rewriter) const override + { + if (op.getType().isa() || + llvm::any_of(op.getOperandTypes(), [](mlir::Type type){ return type.isa(); })) + { + return mlir::failure(); + } + + auto new_type = mlir::TupleType::get(op.getContext(), op.getOperandTypes()); + rewriter.replaceOpWithNewOp(op, new_type, op.getOperands()); + return mlir::success(); + } +}; + template struct FoldTupleGetitem : public mlir::OpRewritePattern { @@ -1352,6 +1373,7 @@ void PlierToStdPass::runOnOperation() ScfIfRewrite, ScfWhileRewrite, FixupWhileTypes, + PropagateBuildTupleTypes, FoldTupleGetitem, FoldTupleGetitem >(type_converter, context); diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 8c32f522650..6e9b2afb40e 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -78,7 +78,7 @@ def py_func(a, b, c): arr3 = np.asarray([7,8,9]) assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) - def test_setitem(self): + def test_setitem1(self): def py_func(a, b): a[b] = 42 return a[b] @@ -87,6 +87,15 @@ def py_func(a, b): arr = np.asarray([1,2,3]) assert_equal(py_func(arr, 1), jit_func(arr, 1)) + def test_setitem2(self): + def py_func(a, b, c): + a[b, c] = 42 + return a[b, c] + + jit_func = njit(py_func) + arr = np.asarray([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr, 1, 2), jit_func(arr, 1, 2)) + def test_setitem_loop(self): def py_func(a): for i in range(len(a)): From f4a22a67d783f02af9080c6e3d26aacba1d81fb2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 17 Feb 2021 03:11:43 +0300 Subject: [PATCH 233/259] [MLIR] Numpy sqrt and square support (#185) * test * fix CanonicalizeReduction for nested loops * math. fixes * numpy sqrt * refactor * numpy.square --- .../src/pipelines/plier_to_std.cpp | 8 ++-- .../mlir-compiler/src/py_linalg_resolver.cpp | 18 +++++++- .../src/rewrites/canonicalize_reductions.cpp | 13 +++--- numba/mlir/numpy/funcs.py | 41 +++++++++++++++---- numba/mlir/tests/test_numpy.py | 24 +++++++---- 5 files changed, 80 insertions(+), 24 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index c491df9de18..1718da5fc94 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1255,13 +1255,15 @@ mlir::LogicalResult lower_math_func( auto ret_type = map_plier_type(op.getType()); auto valid_type = [&](mlir::Type type) { - return ret_type == type && type.isa(); + return type.isa(); }; if (ret_type && name.consume_front("math.") && args.size() == 1 && valid_type(args[0].getType())) { + auto loc = op.getLoc(); + mlir::Value arg = rewriter.create(loc, ret_type, args[0]); auto is_float = ret_type.isa(); - auto func_type = mlir::FunctionType::get(op.getContext(), args[0].getType(), ret_type); + auto func_type = mlir::FunctionType::get(op.getContext(), ret_type, ret_type); auto module = op->getParentOfType(); mlir::FuncOp func; if (is_float) @@ -1272,7 +1274,7 @@ mlir::LogicalResult lower_math_func( { func = get_lib_symbol(module, name, func_type, rewriter); } - auto call = rewriter.create(op.getLoc(), func, args); + auto call = rewriter.create(loc, func, arg); rewriter.replaceOp(op, call.getResults()); return mlir::success(); } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 49c7aee5843..181c19a9166 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -429,13 +429,27 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice return ctx.context.create_var(ctx.loc, ctx.builder, res); } -void setup_py_builder(py::handle builder) +void setup_py_builder(py::handle builder, mlir::OpBuilder& b) { py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); + + auto add_type = [&](const char* name, mlir::Type type) + { + py::setattr(builder, name, wrap_mlir(type)); + }; + + add_type("int8", b.getIntegerType(8)); + add_type("int16", b.getIntegerType(16)); + add_type("int32", b.getIntegerType(32)); + add_type("int64", b.getIntegerType(64)); + + add_type("float16", b.getF16Type()); + add_type("float32", b.getF32Type()); + add_type("float64", b.getF64Type()); } PyLinalgResolver::Values unpack_results(py::handle object) @@ -494,7 +508,7 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR PyBuilderContext py_builder_context{loc, builder, *context}; auto py_builder = context->builder(py::capsule(&py_builder_context)); - setup_py_builder(py_builder); + setup_py_builder(py_builder, builder); assert(!args.empty()); auto module = args.front().getParentRegion()->getParentOfType(); diff --git a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp index db48de3bc9d..49c4f11cf32 100644 --- a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp +++ b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp @@ -108,14 +108,17 @@ void createScalarStore( mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { llvm::SmallVector to_process; - op.walk([&](mlir::LoadOp load) + for (auto& current : op.getLoopBody().front()) { - auto memref = load.memref(); - if (checkMemref(memref, op)) + if (auto load = mlir::dyn_cast(current)) { - to_process.emplace_back(memref); + auto memref = load.memref(); + if (checkMemref(memref, op)) + { + to_process.emplace_back(memref); + } } - }); + } if (!to_process.empty()) { diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 28655d7fde3..f6ef37d15e4 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,23 +1,34 @@ from ..linalg_builder import register_func import numpy +import math -@register_func('numpy.add', numpy.add) -def add_impl(builder, arg1, arg2): - a1, a2 = builder.broadcast(arg1, arg2) - shape = a1.shape +def eltwise(builder, args, body, res_type = None): + if isinstance(args, tuple): + args = builder.broadcast(*args) + else: + args = (args,) + + if res_type is None: + res_type = args[0].dtype + + shape = args[0].shape num_dims = len(shape) iterators = ['parallel' for _ in range(num_dims)] dims = ','.join(['d%s' % i for i in range(num_dims)]) expr = f'({dims}) -> ({dims})' - maps = [expr,expr,expr] - init = builder.init_tensor(shape, a1.dtype) + maps = [expr for _ in range(len(args) + 1)] + init = builder.init_tensor(shape, res_type) + return builder.generic(args, init, iterators, maps, body) + +@register_func('numpy.add', numpy.add) +def add_impl(builder, arg1, arg2): def body(a, b, c): return a + b - return builder.generic((a1,a2), init, iterators, maps, body) + return eltwise(builder, (arg1, arg2), body) @register_func('array.sum') def sum_impl(builder, arg): @@ -36,3 +47,19 @@ def body(a, b): res = builder.generic(arg, init, iterators, maps, body) return builder.extract(res, 0) + +@register_func('numpy.sqrt', numpy.sqrt) +def sqrt_impl(builder, arg): + + def body(a, b): + return math.sqrt(a) + + return eltwise(builder, arg, body, builder.float64) + +@register_func('numpy.square', numpy.square) +def quare_impl(builder, arg): + + def body(a, b): + return a * a + + return eltwise(builder, arg, body) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 6e9b2afb40e..d1b8493d589 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -5,6 +5,11 @@ from numba.tests.support import TestCase import unittest +_arr_1d_int = [1,2,3,4,5,6,7,8] +_arr_1d_float = [1.0,2.1,3.2,4.3,5.4,6.5,7.6,8.7] +_arr_2d_int = [[1,2,3,4],[5,6,7,8]] +_arr_2d_float = [[1.0,2.1,3.2,4.3],[5.4,6.5,7.6,8.7]] +_test_arrays = [_arr_1d_int, _arr_1d_float, _arr_2d_int, _arr_2d_float] class TestMlirBasic(TestCase): def test_staticgetitem(self): @@ -32,13 +37,18 @@ def py_func(a): arr = np.asarray([5,6,7]) assert_equal(py_func(arr), jit_func(arr)) - def test_sum(self): - def py_func(a): - return a.sum() - - jit_func = njit(py_func) - arr = np.asarray([1,2,3]) - assert_equal(py_func(arr), jit_func(arr)) + def test_unary(self): + funcs = [ + lambda a: a.sum(), + lambda a: np.sqrt(a), + lambda a: np.square(a), + ] + + for py_func in funcs: + jit_func = njit(py_func) + for a in _test_arrays: + arr = np.array(a) + assert_equal(py_func(arr), jit_func(arr)) def test_add(self): def py_func(a, b): From 0d190026142cff1cfd8f5d13b7e3fde07b8e3f37 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Feb 2021 14:23:48 +0300 Subject: [PATCH 234/259] [MLIR] Static setitem and array handling fixes (#186) * some buffer opt passes * StaticSetItem * StaticSetItem test * fix to meminfo<->memref conversion * fix lowering to tbb * generate parallel loops only loops from prange --- mlir-compiler/mlir-compiler/src/lowering.cpp | 15 ++++++++++++-- .../src/pipelines/lower_to_llvm.cpp | 18 +++++++++++++++-- .../src/pipelines/parallel_to_tbb.cpp | 10 +++++++--- .../src/pipelines/plier_to_linalg.cpp | 4 ++++ .../src/rewrites/promote_to_parallel.cpp | 4 ++++ numba/mlir/tests/test_numpy.py | 20 +++++++++++++++++++ 6 files changed, 64 insertions(+), 7 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp index 294653f7342..d27743d8983 100644 --- a/mlir-compiler/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -93,6 +93,7 @@ struct inst_handles Branch = mod.attr("Branch"); Jump = mod.attr("Jump"); SetItem = mod.attr("SetItem"); + StaticSetItem = mod.attr("StaticSetItem"); Arg = mod.attr("Arg"); Expr = mod.attr("Expr"); @@ -116,6 +117,7 @@ struct inst_handles py::handle Branch; py::handle Jump; py::handle SetItem; + py::handle StaticSetItem; py::handle Arg; py::handle Expr; @@ -232,7 +234,8 @@ struct plier_lowerer final auto val = lower_assign(inst, target); storevar(val, target); } - else if (py::isinstance(inst, insts.SetItem)) + else if (py::isinstance(inst, insts.SetItem) || + py::isinstance(inst, insts.StaticSetItem)) { setitem(inst.attr("target"), inst.attr("index"), inst.attr("value")); } @@ -473,7 +476,15 @@ struct plier_lowerer final void setitem(const py::handle& target, const py::handle& index, const py::handle& value) { - builder.create(get_current_loc(), loadvar(target), loadvar(index), loadvar(value)); + auto ind = [&]()->mlir::Value + { + if (py::isinstance(index)) + { + return builder.create(get_current_loc(), index.cast()); + } + return loadvar(index); + }(); + builder.create(get_current_loc(), loadvar(target), ind, loadvar(value)); } void storevar(mlir::Value val, const py::handle& inst) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index c7278aa6b2f..000e63702f3 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -176,13 +176,27 @@ mlir::Value unflatten(mlir::Type type, mlir::Location loc, mlir::OpBuilder& buil } } +void write_memref_desc(llvm::raw_ostream& os, mlir::MemRefType memref_type) +{ + if (memref_type.hasRank()) + { + os << memref_type.getRank(); + } + else + { + os << "?"; + } + os << "x"; + memref_type.getElementType().print(os); +} + std::string gen_to_memref_conversion_func_name(mlir::MemRefType memref_type) { assert(memref_type); std::string ret; llvm::raw_string_ostream ss(ret); ss << "__convert_to_memref_"; - memref_type.getElementType().print(ss); + write_memref_desc(ss, memref_type); ss.flush(); return ret; } @@ -193,7 +207,7 @@ std::string gen_from_memref_conversion_func_name(mlir::MemRefType memref_type) std::string ret; llvm::raw_string_ostream ss(ret); ss << "__convert_from_memref_"; - memref_type.getElementType().print(ss); + write_memref_desc(ss, memref_type); ss.flush(); return ret; } diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index b6ae835edb4..6012f4cf2ca 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -50,6 +50,10 @@ struct ParallelToTbb : public mlir::OpRewritePattern { return mlir::failure(); } + if (!op->hasAttr(plier::attributes::getParallelName())) + { + return mlir::failure(); + } int64_t max_concurrency = 0; auto mod = op->getParentOfType(); @@ -96,16 +100,16 @@ struct ParallelToTbb : public mlir::OpRewritePattern auto orig_step = op.step().front(); auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index) { - mapping.map(orig_lower_bound, lower_bound); - mapping.map(orig_upper_bound, upper_bound); for (auto it : llvm::enumerate(op.initVals())) { auto reduce_var = reduce_vars[it.index()]; auto val = builder.create(loc, reduce_var, thread_index); mapping.map(it.value(), val); } - auto new_op = builder.clone(*op, mapping); + auto new_op = mlir::cast(builder.clone(*op, mapping)); assert(new_op->getNumResults() == reduce_vars.size()); + new_op.lowerBoundMutable().assign(lower_bound); + new_op.upperBoundMutable().assign(upper_bound); for (auto it : llvm::enumerate(new_op->getResults())) { auto reduce_var = reduce_vars[it.index()]; diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 6082dd3662f..38d62bd412e 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -700,6 +700,10 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createFuncBufferizePass()); pm.addNestedPass(mlir::createFinalizingBufferizePass()); + pm.addNestedPass(mlir::createBufferHoistingPass()); + pm.addNestedPass(mlir::createBufferLoopHoistingPass()); + pm.addNestedPass(mlir::createPromoteBuffersToStackPass()); + pm.addNestedPass(mlir::createBufferDeallocationPass()); pm.addPass(std::make_unique()); diff --git a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp index 0acdae74116..6ae712a285e 100644 --- a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp @@ -133,6 +133,10 @@ mlir::LogicalResult plier::PromoteToParallel::matchAndRewrite(mlir::scf::ForOp o }; auto parallel_op = rewriter.create(op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), op.initArgs(), body_builder); + if (has_parallel_attr) + { + parallel_op->setAttr(plier::attributes::getParallelName(), rewriter.getUnitAttr()); + } rewriter.replaceOp(op, parallel_op.getResults()); return mlir::success(); diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index d1b8493d589..1dbcb977317 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -88,6 +88,15 @@ def py_func(a, b, c): arr3 = np.asarray([7,8,9]) assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) + def test_static_setitem(self): + def py_func(a): + a[1] = 42 + return a[1] + + jit_func = njit(py_func) + arr = np.asarray([1,2,3]) + assert_equal(py_func(arr), jit_func(arr)) + def test_setitem1(self): def py_func(a, b): a[b] = 42 @@ -189,5 +198,16 @@ def py_func(a): arr = np.array([1,2,3]) assert_equal(py_func(arr), jit_func(arr)) + def test_array_prange_const(self): + def py_func(a, b): + a[0] = 42 + for i in numba.prange(b): + a[0] = 1 + return a[0] + + jit_func = njit(py_func, parallel=True) + arr = np.array([0.0]) + assert_equal(py_func(arr, 5), jit_func(arr, 5)) + if __name__ == '__main__': unittest.main() From c22682f145042e2c8debe054876bca0c3c36762f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 21 Feb 2021 15:28:36 +0300 Subject: [PATCH 235/259] [MLIR] Numpy empty and sum axis (#187) * store context in Var * rework shape accessor * dtype accessor rework * remove unused code * refactor shape * remove unused code * fix setitem lowering * numpy empty * numpy.sum * some kwargs support * linlag resolver kwargs support * linalg resolver some literal support * work on linalg resolver * add symbolDCE pass * numpy sum axis support --- .../src/pipelines/plier_to_linalg.cpp | 23 +- .../src/pipelines/plier_to_std.cpp | 32 +- .../mlir-compiler/src/py_linalg_resolver.cpp | 399 ++++++++++++++---- .../mlir-compiler/src/py_linalg_resolver.hpp | 4 +- .../include/plier/rewrites/call_lowering.hpp | 2 +- .../plier/src/rewrites/call_lowering.cpp | 23 +- numba/mlir/linalg_builder.py | 29 +- numba/mlir/numpy/funcs.py | 59 ++- numba/mlir/tests/test_numpy.py | 33 ++ 9 files changed, 456 insertions(+), 148 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 38d62bd412e..9fec83eef42 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -133,8 +133,12 @@ bool is_int(mlir::Type type) return type.isa(); } -mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if ((operands.size() < 1 || operands.size() > 3) || !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) { @@ -177,9 +181,9 @@ struct CallLowerer { mlir::LogicalResult operator()( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) + llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); std::pair handlers[] = { {"numba.prange", lower_prange}, }; @@ -187,11 +191,11 @@ struct CallLowerer { if (handler.first == name) { - return handler.second(op, args, rewriter); + return handler.second(op, args, kwargs, rewriter); } } - if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args)) + if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs)) { assert(result->size() == op->getNumResults()); rerun_std_pipeline(op); @@ -206,7 +210,7 @@ struct CallLowerer return mlir::success(); } - if (name == "len" && check_numpy_args(args, 1)) + if (name == "len" && check_numpy_args(args, 1) && kwargs.empty()) { auto loc = op.getLoc(); mlir::Value dim = rewriter.create(loc, args[0], 0); @@ -219,7 +223,6 @@ struct CallLowerer } private: - PyLinalgResolver linalg_resolver; }; @@ -436,7 +439,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern mlir::OpBuilder::InsertionGuard g(rewriter); if (auto parent_op = target.getDefiningOp()) { - rewriter.setInsertionPoint(parent_op); + rewriter.setInsertionPointAfter(parent_op); } else { @@ -456,6 +459,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern } else { + mlir::OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(use_op); auto new_val = rewriter.create(use_op->getLoc(), memref); rewriter.updateRootInPlace(use_op, [&]() @@ -602,6 +606,7 @@ struct LowerLinalgPass : mlir::DialectRegistry ®istry) const override { registry.insert(); + registry.insert(); registry.insert(); registry.insert(); registry.insert(); @@ -686,6 +691,7 @@ void PostLinalgOptPass::runOnOperation() void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); } void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) @@ -708,6 +714,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(std::make_unique()); pm.addPass(std::make_unique()); + pm.addPass(mlir::createSymbolDCEPass()); } } diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 1718da5fc94..6f22320bd18 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1156,8 +1156,12 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern } }; -mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if ((operands.size() < 1 || operands.size() > 3) || !llvm::all_of(operands, [](mlir::Value val) { return is_int(val.getType());})) { @@ -1191,8 +1195,12 @@ mlir::LogicalResult lower_range(plier::PyCallOp op, llvm::ArrayRef return mlir::success(); } -mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if (operands.size() != 1) { return mlir::failure(); @@ -1210,8 +1218,12 @@ mlir::LogicalResult lower_len(plier::PyCallOp op, llvm::ArrayRef op return mlir::success(); } -mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, mlir::PatternRewriter& rewriter) +mlir::LogicalResult lower_bool_cast(plier::PyCallOp op, llvm::ArrayRef operands, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } if (operands.size() != 1) { return mlir::failure(); @@ -1250,8 +1262,12 @@ mlir::FuncOp get_lib_symbol( mlir::LogicalResult lower_math_func( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - mlir::PatternRewriter& rewriter) + llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { + if (!kwargs.empty()) + { + return mlir::failure(); + } auto ret_type = map_plier_type(op.getType()); auto valid_type = [&](mlir::Type type) { @@ -1285,14 +1301,14 @@ mlir::LogicalResult lower_math_func( struct CallLowerer { mlir::LogicalResult operator()(plier::PyCallOp op, llvm::StringRef name, - llvm::ArrayRef args, mlir::PatternRewriter& rewriter) + llvm::ArrayRef args, llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) { - if (mlir::succeeded(lower_math_func(op, name, args, rewriter))) + if (mlir::succeeded(lower_math_func(op, name, args, kwargs, rewriter))) { return mlir::success(); } - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, mlir::PatternRewriter&); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); std::pair handlers[] = { {"bool", lower_bool_cast}, {"range", lower_range}, @@ -1302,7 +1318,7 @@ struct CallLowerer { if (handler.first == name) { - return handler.second(op, args, rewriter); + return handler.second(op, args, kwargs, rewriter); } } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 181c19a9166..ec2ed44bd9a 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -27,26 +27,19 @@ struct PyBuilderContext namespace { -bool is_compatible_types(mlir::TypeRange types) +bool is_compatible_type(mlir::Type type) { - return !types.empty() && llvm::all_of(types, [](mlir::Type t) + if (auto tuple_type = type.dyn_cast()) { - return t.isIntOrFloat() || t.isa(); - }); -} - -py::handle get_dim(int64_t val) -{ - if (val == -1) - { - return py::none(); + return llvm::all_of(tuple_type, &is_compatible_type); } - return py::int_(val); + return type.isIntOrFloat() || type.isa(); } -size_t py_func_arg_count(py::handle signature, py::handle func) +template +bool is_compatible_types(R&& vals) { - return py::len(signature(func).attr("parameters")); + return llvm::all_of(vals, [](auto val) { return is_compatible_type(val.getType()); }); } template @@ -66,17 +59,6 @@ auto unwrap_ssa_val(py::handle obj) return unwrap_mlir(obj.attr("_ssa_val").cast()); } -auto unwrap_shape(py::list shape) -{ - llvm::SmallVector ret; - ret.reserve(shape.size()); - for (auto elem : shape) - { - ret.push_back(unwrap_ssa_val(elem)); - } - return ret; -} - size_t container_size(py::handle obj) { if (py::isinstance(obj)) @@ -113,38 +95,50 @@ void container_iterate(py::handle obj, F&& func) func(std::size_t(0), obj); } } + +llvm::Optional make_py_literal(mlir::Value val) +{ + if (auto int_val = plier::getConstVal(val)) + { + return py::int_(int_val.getInt()); + } + if (auto float_val = plier::getConstVal(val)) + { + return py::float_(float_val.getValueAsDouble()); + } + return {}; +} + +mlir::Value do_cast(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value val, mlir::Type type) +{ + if (val.getType() != type) + { + return builder.create(loc, type, val); + } + return val; +} + +void setup_py_var(py::handle var); } struct PyLinalgResolver::Context { py::handle var; - py::handle val; py::handle builder; - py::handle signature; + py::handle inspect; py::handle types_mod; py::handle compile_func; py::handle lookup_func; - py::object create_var(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value value) + py::object create_var(py::capsule context, mlir::Value value) { - if (value.getType().isa()) + if (auto literal = make_py_literal(value)) { - auto make_dim_val = [&](auto dim, auto ssa_val) - { - return val(get_dim(dim), wrap_mlir(ssa_val)); - }; - auto mlir_type = value.getType().cast(); - auto shape = mlir_type.getShape(); - auto elem_type = mlir_type.getElementType(); - py::list py_shape(shape.size()); - for (auto it2 : llvm::enumerate(shape)) - { - mlir::Value mlir_dim = builder.create(loc, value, it2.index()); - py_shape[it2.index()] = make_dim_val(it2.value(), mlir_dim); - } - return var(wrap_mlir(value), py_shape, wrap_mlir(elem_type)); + return *literal; } - return var(wrap_mlir(value), py::list(), wrap_mlir(value.getType())); + auto ret = var(context, wrap_mlir(value)); + setup_py_var(ret); + return ret; } mlir::FuncOp compile_body(py::handle body, py::list arg_types) @@ -156,7 +150,7 @@ struct PyLinalgResolver::Context return mlir_func; } - py::object wrap_result(mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange values) + py::object wrap_result(py::capsule context, mlir::ValueRange values) { if (values.empty()) { @@ -164,28 +158,96 @@ struct PyLinalgResolver::Context } if (values.size() == 1) { - return create_var(loc, builder, values.front()); + return create_var(context, values.front()); } py::tuple ret(values.size()); for (auto it : llvm::enumerate(values)) { - ret[it.index()] = create_var(loc, builder, it.value()); + ret[it.index()] = create_var(context, it.value()); } return std::move(ret); } + + mlir::Value unwrap_val(mlir::Location loc, mlir::OpBuilder& builder, py::handle obj) + { + if (py::isinstance(obj, var)) + { + return unwrap_ssa_val(obj); + } + if (py::isinstance(obj)) + { + auto attr = builder.getI64IntegerAttr(obj.cast()); + return builder.create(loc, attr); + } + if (py::isinstance(obj)) + { + auto attr = builder.getF64FloatAttr(obj.cast()); + return builder.create(loc, attr); + } + plier::report_error("Invalid element type"); + } }; namespace { - -PyBuilderContext& get_py_context(py::capsule& ctx) +py::list get_args(py::handle inspect, py::handle func, llvm::function_ref create_var, + mlir::ValueRange args, llvm::ArrayRef> kwargs) { - return *static_cast(ctx); + auto sig_func = inspect.attr("signature"); + auto sig = sig_func(func); + auto params = sig.attr("parameters"); + auto params_list = py::list(params); + params_list = params_list[py::slice(1, static_cast(params_list.size()), 1)]; // skip builder param + auto empty = inspect.attr("Parameter").attr("empty"); + + py::list ret(py::len(params_list)); + for (auto it : llvm::enumerate(params_list)) + { + auto index = it.index(); + auto param_name = it.value(); + auto param = params[param_name]; + if (!args.empty()) + { + ret[index] = create_var(args.front()); + args = args.drop_front(); + continue; + } + if (!kwargs.empty()) + { + auto name = param_name.cast(); + auto val = [&]()->mlir::Value + { + for (auto kwarg : kwargs) + { + if (kwarg.first == name) + { + return kwarg.second; + } + } + return {}; + }(); + if (val) + { + ret[index] = create_var(val); + continue; + } + } + auto def_val = param.attr("default"); + if (!def_val.is(empty)) + { + ret[index] = def_val; + } + else + { + return py::none(); + } + } + return ret; } -mlir::Value get_var_value(py::handle var) +PyBuilderContext& get_py_context(py::capsule& ctx) { - return unwrap_mlir(var.attr("_ssa_val").cast()); + return *static_cast(ctx); } auto get_types(mlir::ValueRange values) @@ -193,21 +255,25 @@ auto get_types(mlir::ValueRange values) return values.getTypes(); } -auto get_agrs_from_tuple(py::handle args) +auto get_agrs_from_tuple(py::handle args, llvm::function_ref unpack) { llvm::SmallVector ret; + if (args.is_none()) + { + return ret; + } if (py::isinstance(args)) { auto tuple = args.cast(); ret.resize(tuple.size()); for (auto it : llvm::enumerate(tuple)) { - ret[it.index()] = get_var_value(it.value()); + ret[it.index()] = unpack(it.value()); } } else { - ret.emplace_back(get_var_value(args)); + ret.emplace_back(unpack(args)); } return ret; } @@ -280,23 +346,89 @@ py::object broadcast_impl(py::capsule /*context*/, py::tuple args) } } -py::object init_tensor_impl(py::capsule context, py::list shape, py::capsule dtype) +py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype, py::handle init_val) { auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; auto elem_type = unwrap_mlir(dtype); mlir::Value init; - if (shape.empty()) + auto count = py::len(shape); + if (count == 0) { - // TODO: undef - auto zero_val = plier::getZeroVal(elem_type); - assert(zero_val); - init = ctx.builder.create(ctx.loc, zero_val); + if (init_val.is_none()) + { + // TODO: undef + auto zero_val = plier::getZeroVal(elem_type); + assert(zero_val); + init = builder.create(loc, zero_val); + } + else + { + init = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + } } else { - init = ctx.builder.create(ctx.loc, unwrap_shape(shape), elem_type); + auto index_type = builder.getIndexType(); + llvm::SmallVector shape_val(count); + for (size_t i = 0; i < count; ++i) + { + shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, shape[py::int_(i)]), index_type); + } + + if (init_val.is_none()) + { + init = builder.create(loc, shape_val, elem_type); + } + else + { + auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) + { + builder.create(loc, val); + }; + llvm::SmallVector shape(count, -1); + auto type = mlir::RankedTensorType::get(shape, elem_type); + init = builder.create(loc, type, shape_val, body); + } } - return ctx.context.create_var(ctx.loc, ctx.builder, init); + return ctx.context.create_var(context, init); +} + +py::object fill_tensor_impl(py::capsule context, py::handle tensor, py::handle value) +{ + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + auto tensor_val = ctx.context.unwrap_val(loc, builder, tensor); + auto tensor_type = tensor_val.getType().cast(); + auto init_val = ctx.context.unwrap_val(loc, builder, value); + if (init_val.getType() != tensor_type.getElementType()) + { + init_val = builder.create(loc, tensor_type.getElementType(), init_val); + } + +// auto val = builder.create(loc, tensor_type, tensor_val, init_val); + auto rank = static_cast(tensor_type.getRank()); + mlir::AffineMap affine_maps[] = { + mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext()), + }; + llvm::SmallVector iterators(rank, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 1); + builder.create(loc, init_val); + }; + auto val = builder.create( + loc, + tensor_type, + llvm::None, + tensor_val, + affine_maps, + iterators, + body); + return ctx.context.create_var(context, val.getResult(0)); } py::object generic_impl(py::capsule context, py::handle inputs, py::handle outputs, py::list iterators, py::list maps, py::handle body) @@ -306,8 +438,13 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu auto& builder = ctx.builder; auto& mlir_context = *builder.getContext(); - auto inputs_args = get_agrs_from_tuple(inputs); - auto output_args = get_agrs_from_tuple(outputs); + auto unpack = [&](py::handle obj)->mlir::Value + { + return ctx.context.unwrap_val(loc, builder, obj); + }; + + auto inputs_args = get_agrs_from_tuple(inputs, unpack); + auto output_args = get_agrs_from_tuple(outputs, unpack); auto ret_types = get_types(output_args); auto mlir_iterators = get_iterators(iterators, mlir_context); @@ -337,7 +474,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu { inputs_args.append(output_args.begin(), output_args.end()); auto res = builder.create(loc, body_func, inputs_args); - return ctx.context.wrap_result(loc, builder, cast_values(res.getResults(), ret_types)); + return ctx.context.wrap_result(context, cast_values(res.getResults(), ret_types)); } else { @@ -359,7 +496,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu affine_maps, mlir_iterators, body_builder); - return ctx.context.wrap_result(loc, builder, generic_op.getResults()); + return ctx.context.wrap_result(context, generic_op.getResults()); } } @@ -400,7 +537,7 @@ py::object from_elements_impl(py::capsule context, py::handle values, py::capsul } }); auto res = builder.create(loc, vals); - return ctx.context.create_var(ctx.loc, ctx.builder, res); + return ctx.context.create_var(context, res); } py::object extract_impl(py::capsule context, py::handle value, py::handle indices) @@ -425,14 +562,15 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice plier::report_error("Invalid element type"); } }); - auto res = builder.create(loc, get_var_value(value), ind); - return ctx.context.create_var(ctx.loc, ctx.builder, res); + auto res = builder.create(loc, ctx.context.unwrap_val(loc, builder, value), ind); + return ctx.context.create_var(context, res); } void setup_py_builder(py::handle builder, mlir::OpBuilder& b) { py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); + py::setattr(builder, "_fill_tensor", py::cpp_function(&fill_tensor_impl)); py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); @@ -452,6 +590,90 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b) add_type("float64", b.getF64Type()); } +py::object shape_impl(py::capsule context, py::capsule ssa_val) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + if (value.getType().isa()) + { + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto mlir_type = value.getType().cast(); + auto shape = mlir_type.getShape(); + llvm::SmallVector shape_vals(shape.size()); + for (auto it : llvm::enumerate(shape)) + { + auto i = it.index(); + mlir::Value mlir_dim = builder.create(loc, value, i); + shape_vals[i] = mlir_dim; + } + llvm::SmallVector shape_types(shape.size(), builder.getIndexType()); + auto shape_type = mlir::TupleType::get(builder.getContext(), shape_types); + auto shape_var = builder.create(loc, shape_type, shape_vals); + return ctx.context.create_var(context, shape_var.getResult()); + } + return py::list(); +} + +py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val) +{ + auto value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tensor_type = type.dyn_cast()) + { + return wrap_mlir(tensor_type.getElementType()); + } + return wrap_mlir(type); +} + +py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) +{ + auto value = unwrap_mlir(ssa_val); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + return py::int_(tuple_type.size()); + } + return py::int_(1); +} + +py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle index) +{ + auto& ctx = get_py_context(context); + auto value = unwrap_mlir(ssa_val); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto index_val = index.cast(); + auto type = value.getType(); + if (auto tuple_type = type.dyn_cast()) + { + if (index_val < 0 || index_val >= static_cast(tuple_type.size())) + { + plier::report_error("Invelid getitem index"); + } + auto elem_type = tuple_type.getType(static_cast(index_val)); + auto ind = builder.create(loc, index_val); + auto item = builder.create(loc, elem_type, value, ind); + return ctx.context.create_var(context, item.getResult()); + } + else + { + if (0 != index_val) + { + plier::report_error("Invelid getitem index"); + } + return ctx.context.create_var(context, value); + } +} + +void setup_py_var(pybind11::handle var) +{ + py::setattr(var, "_shape", py::cpp_function(&shape_impl)); + py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); + py::setattr(var, "_len", py::cpp_function(&len_impl)); + py::setattr(var, "_getitem", py::cpp_function(&getitem_impl)); +} + PyLinalgResolver::Values unpack_results(py::handle object) { PyLinalgResolver::Values ret; @@ -479,9 +701,8 @@ PyLinalgResolver::PyLinalgResolver(): { auto builder_mod = py::module::import("numba.mlir.linalg_builder"); context->var = builder_mod.attr("Var"); - context->val = builder_mod.attr("Val"); context->builder = builder_mod.attr("Builder"); - context->signature = py::module::import("inspect").attr("signature"); + context->inspect = py::module::import("inspect"); context->types_mod = py::module::import("numba.core.types"); context->compile_func = builder_mod.attr("compile_func"); context->lookup_func = builder_mod.attr("lookup_func"); @@ -492,36 +713,40 @@ PyLinalgResolver::~PyLinalgResolver() } -llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args) +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef> kwargs) { assert(!name.empty()); - if (!is_compatible_types(args.getTypes())) + if (!is_compatible_types(args) || + !is_compatible_types(llvm::make_second_range(kwargs))) { return {}; } auto builder_func = context->lookup_func(py::str(name.data(), name.size())); - if (builder_func.is_none() || py_func_arg_count(context->signature, builder_func) != (args.size() + 1)) + if (builder_func.is_none()) { return {}; } PyBuilderContext py_builder_context{loc, builder, *context}; - auto py_builder = context->builder(py::capsule(&py_builder_context)); - setup_py_builder(py_builder, builder); - - assert(!args.empty()); - auto module = args.front().getParentRegion()->getParentOfType(); - assert(module); - - py::list py_args(args.size()); - for (auto it : llvm::enumerate(args)) + auto py_context = py::capsule(&py_builder_context); + auto py_args = get_args( + context->inspect, + builder_func, + [&](auto val){ return context->create_var(py_context, val);}, + args, + kwargs); + if (py_args.is_none()) { - auto index = static_cast(it.index()); - auto mlir_arg = it.value(); - py_args[index] = context->create_var(loc, builder, mlir_arg); + return {}; } + auto py_builder = context->builder(py_context); + setup_py_builder(py_builder, builder); auto result = builder_func(py_builder, *py_args); + if (result.is_none()) + { + return {}; + } return unpack_results(result); } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp index 80ca93da3d0..66769084b50 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp @@ -4,6 +4,7 @@ #include #include +#include namespace llvm { @@ -27,7 +28,8 @@ class PyLinalgResolver using Values = llvm::SmallVector; - llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args); + llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + llvm::ArrayRef> kwargs); private: friend struct PyBuilderContext; diff --git a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp index c686004d5e1..48d58f9958d 100644 --- a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp @@ -15,7 +15,7 @@ namespace plier { struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = llvm::function_ref, mlir::PatternRewriter&)>; + using resolver_t = llvm::function_ref, llvm::ArrayRef> , mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, diff --git a/mlir-compiler/plier/src/rewrites/call_lowering.cpp b/mlir-compiler/plier/src/rewrites/call_lowering.cpp index 53e93c30d60..6de918fa317 100644 --- a/mlir-compiler/plier/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/plier/src/rewrites/call_lowering.cpp @@ -18,23 +18,22 @@ mlir::LogicalResult plier::CallOpLowering::matchAndRewrite(plier::PyCallOp op, m return mlir::failure(); } - llvm::SmallVector arg_types; llvm::SmallVector args; + llvm::SmallVector, 8> kwargs; auto getattr = mlir::dyn_cast_or_null(operands[0].getDefiningOp()); - if (!getattr) + if (getattr) { - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - // TODO kwargs + args.push_back(getattr.getOperand()); } - else + auto kw_start = op.kw_start(); + operands = operands.drop_front(); + llvm::copy(operands.take_front(kw_start), std::back_inserter(args)); + for (auto it : llvm::zip(operands.drop_front(kw_start), op.kw_names())) { - arg_types.push_back(getattr.getOperand().getType()); - args.push_back(getattr.getOperand()); - llvm::copy(llvm::drop_begin(op.getOperandTypes(), 1), std::back_inserter(arg_types)); - llvm::copy(llvm::drop_begin(op.getOperands(), 1), std::back_inserter(args)); - // TODO kwargs + auto arg = std::get<0>(it); + auto name = std::get<1>(it).cast(); + kwargs.emplace_back(name.getValue(), arg); } - return resolver(op, op.func_name(), args, rewriter); + return resolver(op, op.func_name(), args, kwargs, rewriter); } diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 0d3c0bb8bfa..7167a7211de 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -1,28 +1,26 @@ from .func_registry import add_func class Var: - def __init__(self, ssa_val, shape, dtype): + def __init__(self, context, ssa_val): + self._context = context self._ssa_val = ssa_val - self._shape = shape - self._dtype = dtype @property def shape(self): - return self._shape + return self._shape(self._context, self._ssa_val) @property def dtype(self): - return self._dtype - + return self._dtype(self._context, self._ssa_val) + def __len__(self): + return self._len(self._context, self._ssa_val) -class Val: - def __init__(self, const_val, ssa_val): - self._const_val = const_val - self._ssa_val = ssa_val + def __getitem__(self, index): + return self._getitem(self._context, self._ssa_val, index) - def is_const(self): - return not _const_val is None +def is_literal(val): + return not isinstance(val, Var) class Builder: def __init__(self, context): @@ -31,8 +29,11 @@ def __init__(self, context): def broadcast(self, *args): return self._broadcast(self._context, args) - def init_tensor(self, shape, dtype): - return self._init_tensor(self._context, shape, dtype) + def init_tensor(self, shape, dtype, init_val=None): + return self._init_tensor(self._context, shape, dtype, init_val) + + def fill_tensor(self, tensor, value): + return self._fill_tensor(self._context, tensor, value) def generic(self, inputs, outputs, iterators, maps, body): return self._generic(self._context, inputs, outputs, iterators, maps, body) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index f6ef37d15e4..24050d43437 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,4 +1,4 @@ -from ..linalg_builder import register_func +from ..linalg_builder import register_func, is_literal import numpy import math @@ -31,22 +31,42 @@ def body(a, b, c): return eltwise(builder, (arg1, arg2), body) @register_func('array.sum') -def sum_impl(builder, arg): - shape = arg.shape +@register_func('numpy.sum', numpy.sum) +def sum_impl(builder, arg, axis=None): + if axis is None: + shape = arg.shape + num_dims = len(shape) + iterators = ['reduction' for _ in range(num_dims)] + dims = ','.join(['d%s' % i for i in range(num_dims)]) + expr1 = f'({dims}) -> ({dims})' + expr2 = f'({dims}) -> (0)' + maps = [expr1,expr2] + init = builder.from_elements(0, arg.dtype) + + def body(a, b): + return a + b + + res = builder.generic(arg, init, iterators, maps, body) + return builder.extract(res, 0) + elif isinstance(axis, int): + shape = arg.shape + num_dims = len(shape) + iterators = [('reduction' if i == axis else 'parallel') for i in range(num_dims)] + dims1 = ','.join(['d%s' % i for i in range(num_dims)]) + dims2 = ','.join(['d%s' % i for i in range(num_dims) if i != axis]) + expr1 = f'({dims1}) -> ({dims1})' + expr2 = f'({dims1}) -> ({dims2})' + maps = [expr1,expr2] + res_shape = tuple(shape[i] for i in range(len(shape)) if i != axis) + + init = builder.init_tensor(res_shape, builder.int64, 0) #TODO: type + # val = builder.fill_tensor(init, 0) + + def body(a, b): + return a + b + + return builder.generic(arg, init, iterators, maps, body) - num_dims = len(shape) - iterators = ['reduction' for _ in range(num_dims)] - dims = ','.join(['d%s' % i for i in range(num_dims)]) - expr1 = f'({dims}) -> ({dims})' - expr2 = f'({dims}) -> (0)' - maps = [expr1,expr2] - init = builder.from_elements(0, arg.dtype) - - def body(a, b): - return a + b - - res = builder.generic(arg, init, iterators, maps, body) - return builder.extract(res, 0) @register_func('numpy.sqrt', numpy.sqrt) def sqrt_impl(builder, arg): @@ -57,9 +77,14 @@ def body(a, b): return eltwise(builder, arg, body, builder.float64) @register_func('numpy.square', numpy.square) -def quare_impl(builder, arg): +def square_impl(builder, arg): def body(a, b): return a * a return eltwise(builder, arg, body) + +@register_func('numpy.empty', numpy.empty) +def empty_impl(builder, shape): + # TODO: dtype + return builder.init_tensor(shape, builder.float64) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 1dbcb977317..405d5cdc019 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -40,6 +40,7 @@ def py_func(a): def test_unary(self): funcs = [ lambda a: a.sum(), + lambda a: np.sum(a), lambda a: np.sqrt(a), lambda a: np.square(a), ] @@ -50,6 +51,17 @@ def test_unary(self): arr = np.array(a) assert_equal(py_func(arr), jit_func(arr)) + def test_sum_axis(self): + funcs = [ + lambda a: np.sum(a, axis=0), + lambda a: np.sum(a, axis=1), + ] + + for py_func in funcs: + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + assert_equal(py_func(arr), jit_func(arr)) + def test_add(self): def py_func(a, b): return np.add(a, b) @@ -209,5 +221,26 @@ def py_func(a, b): arr = np.array([0.0]) assert_equal(py_func(arr, 5), jit_func(arr, 5)) + def test_empty1(self): + def py_func(d): + a = np.empty(d) + for i in range(d): + a[i] = i + return a + + jit_func = njit(py_func) + assert_equal(py_func(5), jit_func(5)) + + def test_empty2(self): + def py_func(d1, d2): + a = np.empty((d1, d2)) + for i in range(d1): + for j in range(d2): + a[i, j] = i + j * 10 + return a + + jit_func = njit(py_func) + assert_equal(py_func(5, 7), jit_func(5, 7)) + if __name__ == '__main__': unittest.main() From ee7dfd71a577e2aeca28f60e4950bb421a35d37e Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 22 Feb 2021 18:01:01 +0300 Subject: [PATCH 236/259] [MLIR] Numpy dot, sum with axis, array.size, array.T (#188) * refac * proper types for numpy.sum * move test * numpy dot 1D * numpy dot 2D * adapt linalg resolver to attrs * array size support * refac * array transpose --- .../src/pipelines/plier_to_linalg.cpp | 109 ++++++++++++++++-- .../mlir-compiler/src/py_linalg_resolver.cpp | 101 ++++++++++++++-- .../mlir-compiler/src/py_linalg_resolver.hpp | 10 +- .../mlir-compiler/src/py_map_types.cpp | 4 +- numba/mlir/linalg_builder.py | 24 +++- numba/mlir/numpy/funcs.py | 73 +++++++++++- numba/mlir/tests/test_numpy.py | 39 +++++-- 7 files changed, 317 insertions(+), 43 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 9fec83eef42..1966cbd8cb9 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -43,9 +43,25 @@ namespace { -bool parse_layout(llvm::StringRef& name) +enum class ArrayLayout { - return name.consume_back("C"); // TODO + C, + F +}; + +bool parse_layout(llvm::StringRef& name, ArrayLayout& layout) +{ + if (name.consume_back("C")) + { + layout = ArrayLayout::C; + return true; + } + if (name.consume_back("F")) + { + layout = ArrayLayout::F; + return true; + } + return false; } template @@ -67,24 +83,43 @@ bool consume_int_back(llvm::StringRef& name, T& result) return false; } -mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, - llvm::StringRef& name) +struct ArrayDesc +{ + unsigned dims = 0; + ArrayLayout layout = {}; + llvm::StringRef name; +}; + +llvm::Optional parse_array_desc(llvm::StringRef& name) { unsigned num_dims = 0; + ArrayLayout layout = {}; if (name.consume_front("array(") && name.consume_back(")") && - parse_layout(name) && + parse_layout(name, layout) && name.consume_back(", ") && name.consume_back("d") && consume_int_back(name, num_dims) && name.consume_back(", ") && !name.empty()) { - if (auto type = conveter.convertType(plier::PyType::get(&ctx, name))) + return ArrayDesc{num_dims, layout, name}; + } + return {}; +} + +mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, + llvm::StringRef& name) +{ + if (auto desc = parse_array_desc(name)) + { + if (desc->layout == ArrayLayout::C) { - llvm::SmallVector shape(num_dims, -1); -// return mlir::MemRefType::get(shape, type); - return mlir::RankedTensorType::get(shape, type); + if (auto type = conveter.convertType(plier::PyType::get(&ctx, desc->name))) + { + llvm::SmallVector shape(desc->dims, -1); + return mlir::RankedTensorType::get(shape, type); + } } } return nullptr; @@ -181,7 +216,8 @@ struct CallLowerer { mlir::LogicalResult operator()( plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - llvm::ArrayRef> kwargs, mlir::PatternRewriter& rewriter) + llvm::ArrayRef> kwargs, + mlir::PatternRewriter& rewriter) { using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); std::pair handlers[] = { @@ -195,7 +231,7 @@ struct CallLowerer } } - if (auto result = linalg_resolver.rewrite(name, op.getLoc(), rewriter, args, kwargs)) + if (auto result = linalg_resolver.rewrite_func(name, op.getLoc(), rewriter, args, kwargs)) { assert(result->size() == op->getNumResults()); rerun_std_pipeline(op); @@ -222,6 +258,32 @@ struct CallLowerer return mlir::failure(); } + mlir::LogicalResult operator()( + plier::GetattrOp op, llvm::StringRef name, mlir::Value arg, + mlir::PatternRewriter& rewriter) + { + if (!arg.getType().isa()) + { + return mlir::failure(); + } + auto full_name = (llvm::Twine("array.") + name).str(); + if (auto result = linalg_resolver.rewrite_attr(full_name, op.getLoc(), rewriter, arg)) + { + assert(result->size() == op->getNumResults()); + rerun_std_pipeline(op); + if (result->empty()) + { + rewriter.eraseOp(op); + } + else + { + rewriter.replaceOp(op, *result); + } + return mlir::success(); + } + return mlir::failure(); + } + private: PyLinalgResolver linalg_resolver; }; @@ -550,6 +612,28 @@ struct ArrayShape : public mlir::OpRewritePattern mlir::TypeConverter& converter; }; +struct GetattrRewriter : public mlir::OpRewritePattern +{ + using resolver_t = llvm::function_ref; + + GetattrRewriter(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context, + resolver_t resolver): + OpRewritePattern(context), + resolver(resolver) + {} + + mlir::LogicalResult matchAndRewrite( + plier::GetattrOp op, mlir::PatternRewriter &rewriter) const override + { + return resolver(op, op.name(), op.value(), rewriter); + } + +private: + resolver_t resolver; +}; + void PlierToLinalgPass::runOnOperation() { @@ -579,7 +663,8 @@ void PlierToLinalgPass::runOnOperation() CallLowerer callLowerer; patterns.insert< - plier::CallOpLowering + plier::CallOpLowering, + GetattrRewriter >(type_converter, context, callLowerer); patterns.insert< diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index ec2ed44bd9a..76b706317ba 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -59,6 +59,11 @@ auto unwrap_ssa_val(py::handle obj) return unwrap_mlir(obj.attr("_ssa_val").cast()); } +auto unwrap_type(py::handle obj) +{ + return unwrap_mlir(obj.attr("_mlir_type").cast()); +} + size_t container_size(py::handle obj) { if (py::isinstance(obj)) @@ -118,12 +123,18 @@ mlir::Value do_cast(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value va return val; } +bool cmp_capsule(py::capsule a1, py::capsule a2) +{ + return static_cast(a1) == static_cast(a2); +} + void setup_py_var(py::handle var); } struct PyLinalgResolver::Context { py::handle var; + py::handle type; py::handle builder; py::handle inspect; py::handle types_mod; @@ -141,6 +152,11 @@ struct PyLinalgResolver::Context return ret; } + py::object create_type(mlir::Type t) + { + return type(wrap_mlir(t), py::cpp_function(&cmp_capsule)); + } + mlir::FuncOp compile_body(py::handle body, py::list arg_types) { auto func = compile_func(body, arg_types).cast(); @@ -346,12 +362,12 @@ py::object broadcast_impl(py::capsule /*context*/, py::tuple args) } } -py::object init_tensor_impl(py::capsule context, py::handle shape, py::capsule dtype, py::handle init_val) +py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dtype, py::handle init_val) { auto& ctx = get_py_context(context); auto loc = ctx.loc; auto& builder = ctx.builder; - auto elem_type = unwrap_mlir(dtype); + auto elem_type = unwrap_type(dtype); mlir::Value init; auto count = py::len(shape); if (count == 0) @@ -500,12 +516,12 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu } } -py::object from_elements_impl(py::capsule context, py::handle values, py::capsule dtype) +py::object from_elements_impl(py::capsule context, py::handle values, py::handle dtype) { auto& ctx = get_py_context(context); auto& builder = ctx.builder; auto loc = ctx.loc; - auto type = unwrap_mlir(dtype); + auto type = unwrap_type(dtype); llvm::SmallVector vals(container_size(values)); container_iterate(values, [&](auto index, py::handle obj) @@ -566,7 +582,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice return ctx.context.create_var(context, res); } -void setup_py_builder(py::handle builder, mlir::OpBuilder& b) +void setup_py_builder(py::handle builder, mlir::OpBuilder& b, llvm::function_ref create_type) { py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); py::setattr(builder, "_init_tensor", py::cpp_function(&init_tensor_impl)); @@ -577,13 +593,14 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b) auto add_type = [&](const char* name, mlir::Type type) { - py::setattr(builder, name, wrap_mlir(type)); + py::setattr(builder, name, create_type(type)); }; add_type("int8", b.getIntegerType(8)); add_type("int16", b.getIntegerType(16)); add_type("int32", b.getIntegerType(32)); add_type("int64", b.getIntegerType(64)); + add_type("index", b.getIndexType()); add_type("float16", b.getF16Type()); add_type("float32", b.getF32Type()); @@ -615,15 +632,16 @@ py::object shape_impl(py::capsule context, py::capsule ssa_val) return py::list(); } -py::object dtype_impl(py::capsule /*context*/, py::capsule ssa_val) +py::object dtype_impl(py::capsule context, py::capsule ssa_val) { + auto& ctx = get_py_context(context); auto value = unwrap_mlir(ssa_val); auto type = value.getType(); if (auto tensor_type = type.dyn_cast()) { - return wrap_mlir(tensor_type.getElementType()); + return ctx.context.create_type(tensor_type.getElementType()); } - return wrap_mlir(type); + return ctx.context.create_type(type); } py::object len_impl(py::capsule /*context*/, py::capsule ssa_val) @@ -666,12 +684,61 @@ py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle ind } } +template +mlir::Value binop_func(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs) +{ + return builder.create(loc, lhs, rhs); +} + +py::object binop_impl(py::capsule context, py::capsule ssa_val, py::handle rhs, py::str op) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + auto lhs = unwrap_mlir(ssa_val); + + auto type = lhs.getType(); + if (!type.isa()) + { + plier::report_error("Invalid binop arg type"); + } + + auto is_float = [&]()->bool + { + if (auto shaped_type = type.dyn_cast()) + { + return shaped_type.getElementType().isa(); + } + return type.isa(); + }(); + + using binop_func_t = mlir::Value(*)(mlir::Location loc, mlir::OpBuilder& builder, mlir::Value lhs, mlir::Value rhs); + const std::tuple funcs[] = { + {"*", &binop_func, &binop_func}, + }; + + auto op_name = static_cast(op); + for (auto f : funcs) + { + auto name = std::get<0>(f); + auto func = (is_float ? std::get<2>(f) : std::get<1>(f)); + if (name == op_name) + { + auto rhs_var = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, rhs), type); + auto res = func(loc, builder, lhs, rhs_var); + return ctx.context.create_var(context, res); + } + } + plier::report_error("Unhandled binop type"); +} + void setup_py_var(pybind11::handle var) { py::setattr(var, "_shape", py::cpp_function(&shape_impl)); py::setattr(var, "_dtype", py::cpp_function(&dtype_impl)); py::setattr(var, "_len", py::cpp_function(&len_impl)); py::setattr(var, "_getitem", py::cpp_function(&getitem_impl)); + py::setattr(var, "_binop", py::cpp_function(&binop_impl)); } PyLinalgResolver::Values unpack_results(py::handle object) @@ -701,6 +768,7 @@ PyLinalgResolver::PyLinalgResolver(): { auto builder_mod = py::module::import("numba.mlir.linalg_builder"); context->var = builder_mod.attr("Var"); + context->type = builder_mod.attr("Type"); context->builder = builder_mod.attr("Builder"); context->inspect = py::module::import("inspect"); context->types_mod = py::module::import("numba.core.types"); @@ -713,7 +781,18 @@ PyLinalgResolver::~PyLinalgResolver() } -llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, llvm::ArrayRef> kwargs) +llvm::Optional PyLinalgResolver::rewrite_func(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs) +{ + auto mangled_name = (llvm::Twine(name) + "()").str(); + return rewrite(mangled_name, loc, builder, args, kwargs); +} + +llvm::Optional PyLinalgResolver::rewrite_attr(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::Value arg) +{ + return rewrite(name, loc, builder, arg, {}); +} + +llvm::Optional PyLinalgResolver::rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, KWArgs kwargs) { assert(!name.empty()); if (!is_compatible_types(args) || @@ -741,7 +820,7 @@ llvm::Optional PyLinalgResolver::rewrite(llvm::StringR return {}; } auto py_builder = context->builder(py_context); - setup_py_builder(py_builder, builder); + setup_py_builder(py_builder, builder, [&](auto type){ return context->create_type(type);}); auto result = builder_func(py_builder, *py_args); if (result.is_none()) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp index 66769084b50..156a5f3acad 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.hpp @@ -27,12 +27,18 @@ class PyLinalgResolver ~PyLinalgResolver(); using Values = llvm::SmallVector; + using KWArgs = llvm::ArrayRef>; - llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, - llvm::ArrayRef> kwargs); + llvm::Optional rewrite_func(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + KWArgs kwargs); + + llvm::Optional rewrite_attr(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::Value arg); private: friend struct PyBuilderContext; struct Context; std::unique_ptr context; + + llvm::Optional rewrite(llvm::StringRef name, mlir::Location loc, mlir::OpBuilder& builder, mlir::ValueRange args, + KWArgs kwargs); }; diff --git a/mlir-compiler/mlir-compiler/src/py_map_types.cpp b/mlir-compiler/mlir-compiler/src/py_map_types.cpp index 2e85ee500aa..64f68ba72ae 100644 --- a/mlir-compiler/mlir-compiler/src/py_map_types.cpp +++ b/mlir-compiler/mlir-compiler/src/py_map_types.cpp @@ -59,8 +59,8 @@ py::object map_type(const py::handle& types_mod, mlir::Type type) {&is_int<64, mlir::IntegerType::Signless>, "int64"}, {&is_int<64, mlir::IntegerType::Unsigned>, "uint64"}, - {&is_float<32>, "float"}, - {&is_float<64>, "double"}, + {&is_float<32>, "float32"}, + {&is_float<64>, "float64"}, }; for (auto h : primitive_types) diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 7167a7211de..630f3a9f12d 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -19,6 +19,17 @@ def __len__(self): def __getitem__(self, index): return self._getitem(self._context, self._ssa_val, index) + def __mul__(self, o): return self._binop(self._context, self._ssa_val, o, '*') + def __rmul__(self, o): return self._binop(self._context, self._ssa_val, o, '*') + +class Type: + def __init__(self, mlir_type, eq): + self._mlir_type = mlir_type + self._eq = eq + + def __eq__(self, other): + return self._eq(self._mlir_type, other._mlir_type) + def is_literal(val): return not isinstance(val, Var) @@ -53,13 +64,22 @@ def compile_func(*args, **kwargs): def register_func(name, orig_func = None): def _decorator(func): global _func_registry - assert not name in _func_registry - _func_registry[name] = func + mangled_name = name + '()' + assert not mangled_name in _func_registry + _func_registry[mangled_name] = func if not orig_func is None: add_func(orig_func, name) return func return _decorator +def register_attr(name): + def _decorator(func): + global _func_registry + assert not name in _func_registry + _func_registry[name] = func + return func + return _decorator + def lookup_func(name): global _func_registry return _func_registry.get(name) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 24050d43437..702520eaca7 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,8 +1,14 @@ -from ..linalg_builder import register_func, is_literal +from ..linalg_builder import register_func, register_attr import numpy import math +def is_int(t, b): + return t == b.int8 or t == b.int16 or t == b.int32 or t == b.int64 + +def is_float(t, b): + return t == b.float16 or t == b.float32 or t == b.float64 + def eltwise(builder, args, body, res_type = None): if isinstance(args, tuple): args = builder.broadcast(*args) @@ -59,8 +65,12 @@ def body(a, b): maps = [expr1,expr2] res_shape = tuple(shape[i] for i in range(len(shape)) if i != axis) - init = builder.init_tensor(res_shape, builder.int64, 0) #TODO: type - # val = builder.fill_tensor(init, 0) + orig_type = arg.dtype + if is_int(orig_type, builder): + res_type = builder.int64 + else: + res_type = orig_type + init = builder.init_tensor(res_shape, res_type, 0) def body(a, b): return a + b @@ -88,3 +98,60 @@ def body(a, b): def empty_impl(builder, shape): # TODO: dtype return builder.init_tensor(shape, builder.float64) + +@register_func('numpy.dot', numpy.dot) +def empty_impl(builder, a, b): + shape1 = a.shape + shape2 = b.shape + if len(shape1) == 1 and len(shape2) == 1: + iterators = ['reduction'] + expr1 = '(d0) -> (d0)' + expr2 = '(d0) -> (0)' + maps = [expr1,expr1,expr2] + init = builder.from_elements(0, a.dtype) + + def body(a, b, c): + return a * b + c + + res = builder.generic((a,b), init, iterators, maps, body) + return builder.extract(res, 0) + if len(shape1) == 2 and len(shape2) == 2: + iterators = ['parallel','parallel','reduction'] + expr1 = '(d0,d1,d2) -> (d0,d2)' + expr2 = '(d0,d1,d2) -> (d2,d1)' + expr3 = '(d0,d1,d2) -> (d0,d1)' + maps = [expr1,expr2,expr3] + res_shape = (shape1[0], shape2[1]) + init = builder.init_tensor(res_shape, a.dtype, 0) + + def body(a, b, c): + return a * b + c + + return builder.generic((a,b), init, iterators, maps, body) + +@register_attr('array.size') +def size_impl(builder, arg): + shape = arg.shape + res = builder.init_tensor([], builder.index, 1) + for i in range(len(shape)): + res = res * shape[i] + return res + +@register_attr('array.T') +def size_impl(builder, arg): + shape = arg.shape + dims = len(shape) + if dims == 1: + return arg + if dims == 2: + iterators = ['parallel','parallel'] + expr1 = '(d0,d1) -> (d0,d1)' + expr2 = '(d0,d1) -> (d1,d0)' + maps = [expr1,expr2] + res_shape = (shape[1], shape[0]) + init = builder.init_tensor(res_shape, arg.dtype) + + def body(a, b): + return a + + return builder.generic(arg, init, iterators, maps, body) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 405d5cdc019..4c0f8a56e8f 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -43,6 +43,9 @@ def test_unary(self): lambda a: np.sum(a), lambda a: np.sqrt(a), lambda a: np.square(a), + lambda a: a.size, + # lambda a: a.T, TODO: need fortran layout support + lambda a: a.T.T, ] for py_func in funcs: @@ -51,17 +54,6 @@ def test_unary(self): arr = np.array(a) assert_equal(py_func(arr), jit_func(arr)) - def test_sum_axis(self): - funcs = [ - lambda a: np.sum(a, axis=0), - lambda a: np.sum(a, axis=1), - ] - - for py_func in funcs: - jit_func = njit(py_func) - arr = np.array([[1,2,3],[4,5,6]]) - assert_equal(py_func(arr), jit_func(arr)) - def test_add(self): def py_func(a, b): return np.add(a, b) @@ -80,6 +72,18 @@ def py_func(a, b): arr2 = 2 assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + def test_sum_axis(self): + funcs = [ + lambda a: np.sum(a, axis=0), + lambda a: np.sum(a, axis=1), + ] + + for py_func in funcs: + jit_func = njit(py_func) + arr = np.array([[1,2,3],[4,5,6]]) + for a in [arr, arr.astype(np.float32)]: + assert_equal(py_func(a), jit_func(a)) + def test_sum_add(self): def py_func(a, b): return np.add(a, b).sum() @@ -100,6 +104,19 @@ def py_func(a, b, c): arr3 = np.asarray([7,8,9]) assert_equal(py_func(arr1, arr2, arr3), jit_func(arr1, arr2, arr3)) + def test_dot(self): + def py_func(a, b): + return np.dot(a, b) + + jit_func = njit(py_func) + arr1 = np.asarray([1,2,3], np.float32) + arr2 = np.asarray([4,5,6], np.float32) + arr3 = np.asarray([[1,2,3],[4,5,6]], np.float32) + arr4 = np.asarray([[1,2],[3,4],[5,6]], np.float32) + + for a, b in [(arr1,arr2), (arr3,arr4)]: + assert_equal(py_func(a, b), jit_func(a, b)) + def test_static_setitem(self): def py_func(a): a[1] = 42 From d22d464c9c7b168cb39ecde451d00dcf0b2f73ca Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 23 Feb 2021 14:55:57 +0300 Subject: [PATCH 237/259] [MLIR] tbb fixes (#189) * fix plier parallel * use loop to init reduce vars * transfor to tbb parallel if have parallel attr or outermost loop * we dont need fix_tls_observer, also do not recreate task arena each time --- .../src/pipelines/parallel_to_tbb.cpp | 37 +++++++++++++------ numba/np/ufunc/tbbpool.cpp | 27 ++++++++++++-- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index 6012f4cf2ca..984fe92aa47 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -50,7 +50,9 @@ struct ParallelToTbb : public mlir::OpRewritePattern { return mlir::failure(); } - if (!op->hasAttr(plier::attributes::getParallelName())) + bool need_parallel = op->hasAttr(plier::attributes::getParallelName()) || + !op->getParentOfType(); + if (!need_parallel) { return mlir::failure(); } @@ -85,14 +87,26 @@ struct ParallelToTbb : public mlir::OpRewritePattern auto reduce = rewriter.create(loc, reduce_type); auto index = static_cast(it.index()); reduce_vars[index] = reduce; - auto zero = getZeroVal(rewriter, loc, type); - mapping.map(op.initVals()[index], zero); - for (unsigned i = 0; i < max_concurrency; ++i) + } + + auto reduce_init_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args) + { + assert(args.empty()); + (void)args; + for (auto it : llvm::enumerate(reduce_vars)) { - mlir::Value index = rewriter.create(loc, i); - rewriter.create(loc, zero, reduce, index); + auto reduce = it.value(); + auto type = op.getResultTypes()[it.index()]; + auto zero = getZeroVal(rewriter, loc, type); + builder.create(loc, zero, reduce, index); } - } + builder.create(loc); + }; + + auto reduce_lower_bound = rewriter.create(loc, 0); + auto reduce_upper_bound = rewriter.create(loc, max_concurrency); + auto reduce_step = rewriter.create(loc, 1); + rewriter.create(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder); auto& old_body = op.getLoopBody().front(); auto orig_lower_bound = op.lowerBound().front(); @@ -100,16 +114,19 @@ struct ParallelToTbb : public mlir::OpRewritePattern auto orig_step = op.step().front(); auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index) { + llvm::SmallVector initVals(op.initVals().size()); for (auto it : llvm::enumerate(op.initVals())) { auto reduce_var = reduce_vars[it.index()]; auto val = builder.create(loc, reduce_var, thread_index); - mapping.map(it.value(), val); + initVals[it.index()] = val; } auto new_op = mlir::cast(builder.clone(*op, mapping)); + new_op->removeAttr(plier::attributes::getParallelName()); assert(new_op->getNumResults() == reduce_vars.size()); new_op.lowerBoundMutable().assign(lower_bound); new_op.upperBoundMutable().assign(upper_bound); + new_op.initValsMutable().assign(initVals); for (auto it : llvm::enumerate(new_op->getResults())) { auto reduce_var = reduce_vars[it.index()]; @@ -119,10 +136,6 @@ struct ParallelToTbb : public mlir::OpRewritePattern rewriter.create(loc, orig_lower_bound, orig_upper_bound, orig_step, body_builder); - auto reduce_lower_bound = rewriter.create(loc, 0); - auto reduce_upper_bound = rewriter.create(loc, max_concurrency); - auto reduce_step = rewriter.create(loc, 1); - auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value index, mlir::ValueRange args) { assert(args.size() == reduce_vars.size()); diff --git a/numba/np/ufunc/tbbpool.cpp b/numba/np/ufunc/tbbpool.cpp index d5ade473509..74a9bd8a211 100644 --- a/numba/np/ufunc/tbbpool.cpp +++ b/numba/np/ufunc/tbbpool.cpp @@ -42,6 +42,21 @@ Implement parallel vectorize workqueue on top of Intel TBB. static tbb::task_group *tg = NULL; static tbb::task_scheduler_init *tsi = NULL; + +namespace +{ +struct ThreadContext +{ + ThreadContext(int n_threads): + num_threads(n_threads), + arena(n_threads) {} + + int num_threads = 0; + tbb::task_arena arena; +}; +static ThreadContext* thread_context = nullptr; +} + static int tsi_count = 0; #ifdef _MSC_VER @@ -209,15 +224,15 @@ parallel_for(void *fn, char **args, size_t *dimensions, size_t *steps, void *dat using parallel_for2_fptr = void(*)(size_t, size_t, size_t, void*); static void parallel_for2(size_t lower_bound, size_t upper_bound, size_t step, parallel_for2_fptr func, void* ctx) { - auto num_threads = get_num_threads(); + auto context = thread_context; + assert(nullptr != context); + auto num_threads = context->num_threads; if(_DEBUG) { printf("parallel_for2 %d %d %d %d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)num_threads); } - tbb::task_arena limited(num_threads); - fix_tls_observer observer(limited, num_threads); - limited.execute([&] + context->arena.execute([&] { size_t count = (upper_bound - lower_bound - 1) / step + 1; size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64))); @@ -284,6 +299,8 @@ static void unload_tbb(void) tbb::set_assertion_handler(orig); delete tsi; tsi = NULL; + delete thread_context; + thread_context = nullptr; } } #endif @@ -300,6 +317,8 @@ static void launch_threads(int count) tg = new tbb::task_group; tg->run([] {}); // start creating threads asynchronously + thread_context = new ThreadContext(count); + _INIT_NUM_THREADS = count; #ifndef _MSC_VER From e5d1d72702a9b7b0855893b16e32d2e08bf0e4a2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 26 Feb 2021 19:23:16 +0300 Subject: [PATCH 238/259] some reshape support (#190) --- .../src/pipelines/plier_to_linalg.cpp | 46 +++++++++++++++ .../mlir-compiler/src/py_linalg_resolver.cpp | 47 +++++++++++++-- numba/mlir/linalg_builder.py | 3 + numba/mlir/numpy/funcs.py | 58 ++++++++++++++++++- numba/mlir/tests/test_numpy.py | 20 ++++++- 5 files changed, 164 insertions(+), 10 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 1966cbd8cb9..5218e688c43 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -612,6 +612,51 @@ struct ArrayShape : public mlir::OpRewritePattern mlir::TypeConverter& converter; }; +template +bool has_compatibale_shape(T&& a1, T&& a2) +{ + if (!a1.hasRank() || !a2.hasRank() || a1.getRank() != a2.getRank()) + { + return false; + } + for (auto it : llvm::zip(a1.getShape(), a2.getShape())) + { + auto s1 = std::get<0>(it); + auto s2 = std::get<1>(it); + if (s1 >= 0 && s2 >= 0 && s1 != s2) + { + return false; + } + } + return true; +} + +struct RankedTypesCasts : public mlir::OpRewritePattern +{ + RankedTypesCasts(mlir::TypeConverter& /*type_converter*/, + mlir::MLIRContext* context): + OpRewritePattern(context){} + + mlir::LogicalResult matchAndRewrite( + plier::CastOp op, mlir::PatternRewriter &rewriter) const override + { + auto src_type = op.value().getType(); + auto dst_type = op.getType(); + if (src_type.isa() && dst_type.isa()) + { + auto src = src_type.cast(); + auto dst = dst_type.cast(); + if (!has_compatibale_shape(src,dst)) + { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp(op, dst, op.value()); + return mlir::success(); + } + return mlir::failure(); + } +}; + struct GetattrRewriter : public mlir::OpRewritePattern { using resolver_t = llvm::function_ref(type_converter, context); diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 76b706317ba..2686ba820b9 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -33,7 +33,7 @@ bool is_compatible_type(mlir::Type type) { return llvm::all_of(tuple_type, &is_compatible_type); } - return type.isIntOrFloat() || type.isa(); + return type.isa(); } template @@ -304,14 +304,18 @@ auto get_iterators(py::list iterators, mlir::MLIRContext& ctx) return ret; } +mlir::AffineMapAttr get_affine_map_attr(py::handle obj, mlir::MLIRContext& ctx) +{ + auto str = (llvm::Twine("affine_map<") + obj.cast() + ">").str(); + return mlir::parseAttribute(str, &ctx).cast(); +} + auto get_affine_maps(py::list maps, mlir::MLIRContext& ctx) { llvm::SmallVector ret(maps.size()); for (auto it : llvm::enumerate(maps)) { - auto str = (llvm::Twine("affine_map<") + it.value().cast() + ">").str(); - auto attr = mlir::parseAttribute(str, &ctx); - ret[it.index()] = attr.cast().getValue(); + ret[it.index()] = get_affine_map_attr(it.value(), ctx).getValue(); } return ret; } @@ -582,6 +586,32 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice return ctx.context.create_var(context, res); } +py::object reshape_impl(py::capsule context, py::handle tensor, py::int_ out_dims, py::list maps) +{ + auto& ctx = get_py_context(context); + auto& builder = ctx.builder; + auto loc = ctx.loc; + + auto tensor_val = ctx.context.unwrap_val(loc, builder, tensor); + if (!tensor_val.getType().isa()) + { + plier::report_error("Invalid reshapa argument"); + } + auto elem_type = tensor_val.getType().cast().getElementType(); + auto new_dims = out_dims.cast(); + llvm::SmallVector dims(new_dims, -1); + auto new_type = mlir::RankedTensorType::get(dims, elem_type); + + llvm::SmallVector affine_maps(container_size(maps)); + container_iterate(maps, [&](auto index, py::handle obj) + { + affine_maps[index] = get_affine_map_attr(obj, *builder.getContext()); + }); + auto affine_maps_attr = mlir::ArrayAttr::get(affine_maps, builder.getContext()); + auto reshape = builder.create(loc, new_type, tensor_val, affine_maps_attr); + return ctx.context.create_var(context, reshape); +} + void setup_py_builder(py::handle builder, mlir::OpBuilder& b, llvm::function_ref create_type) { py::setattr(builder, "_broadcast", py::cpp_function(&broadcast_impl)); @@ -590,6 +620,7 @@ void setup_py_builder(py::handle builder, mlir::OpBuilder& b, llvm::function_ref py::setattr(builder, "_generic", py::cpp_function(&generic_impl)); py::setattr(builder, "_from_elements", py::cpp_function(&from_elements_impl)); py::setattr(builder, "_extract", py::cpp_function(&extract_impl)); + py::setattr(builder, "_reshape", py::cpp_function(&reshape_impl)); auto add_type = [&](const char* name, mlir::Type type) { @@ -667,7 +698,11 @@ py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle ind { if (index_val < 0 || index_val >= static_cast(tuple_type.size())) { - plier::report_error("Invelid getitem index"); + plier::report_error("Invalid getitem index"); + } + if (auto parent_op = value.getDefiningOp()) + { + return ctx.context.create_var(context, parent_op.getOperand(static_cast(index_val))); } auto elem_type = tuple_type.getType(static_cast(index_val)); auto ind = builder.create(loc, index_val); @@ -678,7 +713,7 @@ py::object getitem_impl(py::capsule context, py::capsule ssa_val, py::handle ind { if (0 != index_val) { - plier::report_error("Invelid getitem index"); + plier::report_error("Invalid getitem index"); } return ctx.context.create_var(context, value); } diff --git a/numba/mlir/linalg_builder.py b/numba/mlir/linalg_builder.py index 630f3a9f12d..ddcc95c8fb6 100644 --- a/numba/mlir/linalg_builder.py +++ b/numba/mlir/linalg_builder.py @@ -55,6 +55,9 @@ def from_elements(self, values, dtype): def extract(self, value, indices): return self._extract(self._context, value, indices) + def reshape(self, src, num_dims, affine_maps): + return self._reshape(self._context, src, num_dims, affine_maps) + def compile_func(*args, **kwargs): import numba.mlir.inner_compiler return numba.mlir.inner_compiler.compile_func(*args, **kwargs) diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index 702520eaca7..a325bb090c2 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -1,4 +1,4 @@ -from ..linalg_builder import register_func, register_attr +from ..linalg_builder import register_func, register_attr, is_literal import numpy import math @@ -100,7 +100,7 @@ def empty_impl(builder, shape): return builder.init_tensor(shape, builder.float64) @register_func('numpy.dot', numpy.dot) -def empty_impl(builder, a, b): +def dot_impl(builder, a, b): shape1 = a.shape shape2 = b.shape if len(shape1) == 1 and len(shape2) == 1: @@ -138,7 +138,7 @@ def size_impl(builder, arg): return res @register_attr('array.T') -def size_impl(builder, arg): +def transpose_impl(builder, arg): shape = arg.shape dims = len(shape) if dims == 1: @@ -155,3 +155,55 @@ def body(a, b): return a return builder.generic(arg, init, iterators, maps, body) + +def flatten(builder, arg, src_dims_count): + if 1 == src_dims_count: + return arg + dims = ','.join(['d%s' % i for i in range(src_dims_count)]) + expr = f'({dims}) -> ({dims})' + maps = [ + expr + ] + return builder.reshape(arg, 1, maps) + +def find_size_index(shape): + size_index = -1 + for i in range(len(shape)): + d = shape[i] + if is_literal(d): + if 1 != d: + return -1 + else: + if size_index != -1: + return -1 + size_index = i + return size_index + +@register_func('array.reshape') +def reshape_impl(builder, arg, new_shape): + shape = arg.shape + src_count = len(shape) + count = len(new_shape) + if count == 1: + return flatten(builder, arg, src_count) + else: + size_index = find_size_index(new_shape) + if size_index < 0: + return + + flat = flatten(builder, arg, src_count) + init = builder.init_tensor(new_shape, arg.dtype) + + iterators = ['parallel'] + # dims1 = ','.join(['d%s' % i for i in range(count)]) + # dims2 = ','.join(['d%s' % i if i == size_index else '0' for i in range(count)]) + dims3 = ','.join(['d0' if i == size_index else '0' for i in range(count)]) + expr1 = f'(d0) -> (d0)' + expr2 = f'(d0) -> ({dims3})' + maps = [expr1, expr2] + + def body(a, b): + return a + + return builder.generic(flat, init, iterators, maps, body) + diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 4c0f8a56e8f..f552bef3992 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -213,7 +213,7 @@ def py_func(a): def test_array_shape(self): def py_func(a): shape = a.shape - return shape[0] + shape[1] + return shape[0] + shape[1] * 10 jit_func = njit(py_func) arr = np.array([[1,2,3],[4,5,6]]) @@ -259,5 +259,23 @@ def py_func(d1, d2): jit_func = njit(py_func) assert_equal(py_func(5, 7), jit_func(5, 7)) + def test_reshape(self): + funcs = [ + lambda a: a.reshape(a.size), + lambda a: a.reshape((a.size,)), + lambda a: a.reshape((a.size,1)), + lambda a: a.reshape((1, a.size)), + lambda a: a.reshape((1, a.size, 1)), + ] + + arr1 = np.array([1,2,3,4,5,6,7,8,9,10,11,12]) + # arr2 = arr1.reshape((2,6)) + # arr3 = arr1.reshape((2,3,2)) + for py_func in funcs: + jit_func = njit(py_func) + # for a in [arr1,arr2,arr3]: TODO: flatten support + for a in [arr1]: + assert_equal(py_func(a), jit_func(a)) + if __name__ == '__main__': unittest.main() From 8ad48230b3bc8c97ff882b75fd19ec070011dd06 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 27 Feb 2021 00:16:51 +0300 Subject: [PATCH 239/259] some broadcasting support (#191) --- .../mlir-compiler/src/py_linalg_resolver.cpp | 171 +++++++++++++++++- numba/mlir/tests/test_numpy.py | 2 + 2 files changed, 171 insertions(+), 2 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 2686ba820b9..c6dac247f91 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -354,16 +354,183 @@ auto generic_op_body_result_types(mlir::ValueRange outputs) return ret; } -py::object broadcast_impl(py::capsule /*context*/, py::tuple args) +bool is_int(mlir::Type type) +{ + return type.isa(); +} + +unsigned get_int_bit_width(mlir::Type type) +{ + if (type.isa()) + { + return type.cast().getWidth(); + } + if (type.isa()) + { + return 64; // TODO + } + llvm_unreachable("No an integer type"); +} + +bool is_float(mlir::Type type) +{ + return type.isa(); +} + +unsigned get_float_bit_width(mlir::Type type) +{ + return type.cast().getWidth(); +} + +mlir::Type broadcast_type(mlir::Type type1, mlir::Type type2) +{ + if (type1 == type2) + { + return type1; + } + // TODO + if (is_int(type1) && is_int(type2)) + { + auto width = std::max(get_int_bit_width(type1), get_int_bit_width(type2)); + return mlir::IntegerType::get(type1.getContext(), width); + } + if (is_float(type1) && is_float(type2)) + { + return (get_float_bit_width(type1) > get_float_bit_width(type2) ? type1 : type2); + } + if (is_float(type1) && is_int(type2)) + { + return type1; + } + if (is_int(type1) && is_float(type2)) + { + return type2; + } + llvm_unreachable("Unable to broadcast type"); +} + +py::object broadcast_impl(py::capsule context, py::tuple args) { if (1 == args.size()) { return args[0]; } + auto& ctx = get_py_context(context); + auto loc = ctx.loc; + auto& builder = ctx.builder; + llvm::SmallVector mlir_args(args.size()); + for (auto it : llvm::enumerate(args)) + { + mlir_args[it.index()] = ctx.context.unwrap_val(loc, builder, it.value()); + } + using shape_t = llvm::SmallVector; + auto get_shape = [&](mlir::Value val)->llvm::Optional> + { + auto type = val.getType(); + if (auto shaped = type.dyn_cast()) + { + if (!shaped.hasRank()) + { + return {}; + } + shape_t ret(static_cast(shaped.getRank())); + for (auto it : llvm::enumerate(ret)) + { + auto dim = builder.create(loc, val, it.index()); + ret[it.index()] = dim; + } + return std::make_pair(ret, shaped.getElementType()); + } + if (type.isa()) + { + return std::make_pair(shape_t{}, type); + } + return {}; + }; + mlir::Type res_type; + mlir::SmallVector shape_vals; + if (auto shape_and_type = get_shape(mlir_args.front())) + { + res_type = shape_and_type->second; + shape_vals = shape_and_type->first; + } else { - return std::move(args); + return py::none(); + } + + for (auto arg : llvm::drop_begin(mlir_args)) + { + auto shape_and_type = get_shape(arg); + if (!shape_and_type) + { + py::none(); + } + res_type = broadcast_type(res_type, shape_and_type->second); + if (shape_and_type->first.size() > shape_vals.size()) + { + shape_vals = shape_and_type->first; // TODO + } + } + + llvm::SmallVector shape(static_cast(shape_vals.size()), -1); + py::tuple ret(args.size()); + if (shape_vals.empty()) + { + for (auto it : llvm::enumerate(mlir_args)) + { + mlir::Value val = it.value(); + if (val.getType() != res_type) + { + val = builder.create(loc, res_type, val); + } + ret[it.index()] = ctx.context.create_var(context, val); + } + return std::move(ret); + } + + auto tensor_type = mlir::RankedTensorType::get(shape, res_type); + for (auto it : llvm::enumerate(mlir_args)) + { + mlir::Value val = it.value(); + auto type = val.getType(); + if (type != tensor_type) + { + if (auto src_type = type.dyn_cast()) + { + assert(src_type.hasRank()); + auto num_dims = static_cast(src_type.getRank()); + auto init = builder.create(loc, shape_vals, tensor_type.getElementType()).getResult(); + llvm::SmallVector iterators(num_dims, "parallel"); + auto map = mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()); + mlir::AffineMap maps[] = { + map, + map, + }; + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 2); + auto res = builder.create(loc, tensor_type.getElementType(), values[0]); + builder.create(loc, res.getResult()); + }; + val = builder.create(loc, tensor_type, val, init, maps, iterators, body).getResult(0); + } + else + { + if (tensor_type.getElementType() != type) + { + val = builder.create(loc, tensor_type.getElementType(), val); + } + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) + { + builder.create(loc, val); + }; + val = builder.create(loc, tensor_type, shape_vals, body); + } + } + ret[it.index()] = ctx.context.create_var(context, val); } + return std::move(ret); } py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dtype, py::handle init_val) diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index f552bef3992..918d46ddc93 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -46,6 +46,8 @@ def test_unary(self): lambda a: a.size, # lambda a: a.T, TODO: need fortran layout support lambda a: a.T.T, + lambda a: np.add(a, 1), + lambda a: np.add(a, 2.5), ] for py_func in funcs: From 22f3d4330cd9d1d5124d543c5112c4c97bf9a882 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 27 Feb 2021 15:06:12 +0300 Subject: [PATCH 240/259] [MLIR] Some numpy operators support (#192) * use std::function * refac test * numpy operator add * fix plier binop * numpy subtract * numpy multiply --- .../src/pipelines/plier_to_linalg.cpp | 87 ++++++++++++++----- .../src/pipelines/plier_to_std.cpp | 2 +- mlir-compiler/plier/CMakeLists.txt | 2 +- mlir-compiler/plier/include/plier/PlierOps.td | 4 +- .../include/plier/rewrites/call_lowering.hpp | 2 +- numba/mlir/numpy/funcs.py | 17 ++++ numba/mlir/tests/test_numpy.py | 33 ++++--- 7 files changed, 102 insertions(+), 45 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 5218e688c43..b3444a8a1d0 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -214,12 +214,14 @@ mlir::LogicalResult lower_prange(plier::PyCallOp op, llvm::ArrayRef struct CallLowerer { + using args_t = llvm::ArrayRef; + using kwargs_t = llvm::ArrayRef>; mlir::LogicalResult operator()( - plier::PyCallOp op, llvm::StringRef name, llvm::ArrayRef args, - llvm::ArrayRef> kwargs, + plier::PyCallOp op, llvm::StringRef name, args_t args, + kwargs_t kwargs, mlir::PatternRewriter& rewriter) { - using func_t = mlir::LogicalResult(*)(plier::PyCallOp, llvm::ArrayRef, llvm::ArrayRef>, mlir::PatternRewriter&); + using func_t = mlir::LogicalResult(*)(plier::PyCallOp, args_t, kwargs_t, mlir::PatternRewriter&); std::pair handlers[] = { {"numba.prange", lower_prange}, }; @@ -231,18 +233,8 @@ struct CallLowerer } } - if (auto result = linalg_resolver.rewrite_func(name, op.getLoc(), rewriter, args, kwargs)) + if (mlir::succeeded(applyRewrite(op, rewriter, linalg_resolver.rewrite_func(name, op.getLoc(), rewriter, args, kwargs)))) { - assert(result->size() == op->getNumResults()); - rerun_std_pipeline(op); - if (result->empty()) - { - rewriter.eraseOp(op); - } - else - { - rewriter.replaceOp(op, *result); - } return mlir::success(); } @@ -267,7 +259,39 @@ struct CallLowerer return mlir::failure(); } auto full_name = (llvm::Twine("array.") + name).str(); - if (auto result = linalg_resolver.rewrite_attr(full_name, op.getLoc(), rewriter, arg)) + return applyRewrite(op, rewriter, linalg_resolver.rewrite_attr(full_name, op.getLoc(), rewriter, arg)); + } + + mlir::LogicalResult operator()( + plier::BinOp op, llvm::StringRef name, mlir::Value lhs, mlir::Value rhs, + mlir::PatternRewriter& rewriter) + { + if (!lhs.getType().isa() && + !rhs.getType().isa()) + { + return mlir::failure(); + } + const std::pair names[] = { + {"+", "operator.add"}, + {"-", "operator.sub"}, + {"*", "operator.mul"}, + }; + for (auto it : names) + { + if (it.first == name) + { + return applyRewrite(op, rewriter, linalg_resolver.rewrite_func(it.second, op.getLoc(), rewriter, {lhs, rhs}, {})); + } + } + return mlir::failure(); + } + +private: + PyLinalgResolver linalg_resolver; + + mlir::LogicalResult applyRewrite(mlir::Operation* op, mlir::PatternRewriter& rewriter, llvm::Optional result) + { + if (result) { assert(result->size() == op->getNumResults()); rerun_std_pipeline(op); @@ -283,9 +307,6 @@ struct CallLowerer } return mlir::failure(); } - -private: - PyLinalgResolver linalg_resolver; }; mlir::Value index_cast(mlir::Value value, mlir::Location loc, mlir::OpBuilder& builder) @@ -659,8 +680,8 @@ struct RankedTypesCasts : public mlir::OpRewritePattern struct GetattrRewriter : public mlir::OpRewritePattern { - using resolver_t = llvm::function_ref; + using resolver_t = std::function; GetattrRewriter(mlir::TypeConverter &/*typeConverter*/, mlir::MLIRContext *context, @@ -679,6 +700,27 @@ struct GetattrRewriter : public mlir::OpRewritePattern resolver_t resolver; }; +struct BinopRewriter : public mlir::OpRewritePattern +{ + using resolver_t = std::function; + + BinopRewriter(mlir::TypeConverter &/*typeConverter*/, + mlir::MLIRContext *context, + resolver_t resolver): + OpRewritePattern(context), + resolver(resolver) + {} + + mlir::LogicalResult matchAndRewrite( + plier::BinOp op, mlir::PatternRewriter &rewriter) const override + { + return resolver(op, op.op(), op.lhs(), op.rhs(), rewriter); + } + +private: + resolver_t resolver; +}; void PlierToLinalgPass::runOnOperation() { @@ -710,8 +752,9 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< plier::CallOpLowering, - GetattrRewriter - >(type_converter, context, callLowerer); + GetattrRewriter, + BinopRewriter + >(type_converter, context, std::ref(callLowerer)); patterns.insert< GetitemOpLowering, diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 6f22320bd18..27190ae9e39 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1404,7 +1404,7 @@ void PlierToStdPass::runOnOperation() patterns.insert< plier::CallOpLowering - >(type_converter, context, callLowerer); + >(type_converter, context, std::ref(callLowerer)); mlir::populateStdExpandOpsPatterns(context, patterns); diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt index 52c1d6ba4de..6de69b1d317 100644 --- a/mlir-compiler/plier/CMakeLists.txt +++ b/mlir-compiler/plier/CMakeLists.txt @@ -85,7 +85,7 @@ target_include_directories(${PLIER_LIB} PRIVATE target_include_directories(${PLIER_LIB} PUBLIC ./include - ${PROJECT_BINARY_DIR}/include + ${PROJECT_BINARY_DIR}/plier/include ) add_dependencies(${PLIER_LIB} MLIRPlierOpsIncGen) diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td index 9324098fb71..084ffea2574 100644 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ b/mlir-compiler/plier/include/plier/PlierOps.td @@ -56,14 +56,14 @@ def GlobalOp : Plier_Op<"global", [NoSideEffect]> { def BinOp : Plier_Op<"binop", []> { let arguments = (ins - AnyType:$rhs, AnyType:$lhs, + AnyType:$rhs, StrAttr:$op); let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$rhs, "::mlir::Value":$lhs, "::mlir::StringRef ":$op)> + OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs, "::mlir::StringRef ":$op)> ]; } diff --git a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp index 48d58f9958d..b87c11b0ad3 100644 --- a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp @@ -15,7 +15,7 @@ namespace plier { struct CallOpLowering : public mlir::OpRewritePattern { - using resolver_t = llvm::function_ref, llvm::ArrayRef> , mlir::PatternRewriter&)>; + using resolver_t = std::function, llvm::ArrayRef> , mlir::PatternRewriter&)>; CallOpLowering(mlir::TypeConverter &typeConverter, mlir::MLIRContext *context, diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index a325bb090c2..e3094418d14 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -30,12 +30,29 @@ def eltwise(builder, args, body, res_type = None): return builder.generic(args, init, iterators, maps, body) @register_func('numpy.add', numpy.add) +@register_func('operator.add') def add_impl(builder, arg1, arg2): def body(a, b, c): return a + b return eltwise(builder, (arg1, arg2), body) +@register_func('numpy.subtract', numpy.subtract) +@register_func('operator.sub') +def sub_impl(builder, arg1, arg2): + def body(a, b, c): + return a - b + + return eltwise(builder, (arg1, arg2), body) + +@register_func('numpy.multiply', numpy.multiply) +@register_func('operator.mul') +def mul_impl(builder, arg1, arg2): + def body(a, b, c): + return a * b + + return eltwise(builder, (arg1, arg2), body) + @register_func('array.sum') @register_func('numpy.sum', numpy.sum) def sum_impl(builder, arg, axis=None): diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 918d46ddc93..32320b34827 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -4,6 +4,7 @@ import numpy as np from numba.tests.support import TestCase import unittest +import itertools _arr_1d_int = [1,2,3,4,5,6,7,8] _arr_1d_float = [1.0,2.1,3.2,4.3,5.4,6.5,7.6,8.7] @@ -46,8 +47,6 @@ def test_unary(self): lambda a: a.size, # lambda a: a.T, TODO: need fortran layout support lambda a: a.T.T, - lambda a: np.add(a, 1), - lambda a: np.add(a, 2.5), ] for py_func in funcs: @@ -56,23 +55,21 @@ def test_unary(self): arr = np.array(a) assert_equal(py_func(arr), jit_func(arr)) - def test_add(self): - def py_func(a, b): - return np.add(a, b) - - jit_func = njit(py_func) - arr1 = np.array([1,2,3]) - arr2 = np.array([4,5,6]) - assert_equal(py_func(arr1,arr2), jit_func(arr1,arr2)) - - def test_add_scalar(self): - def py_func(a, b): - return np.add(a, b) + def test_binary(self): + funcs = [ + lambda a, b: np.add(a, b), + lambda a, b: a + b, + lambda a, b: np.subtract(a, b), + lambda a, b: a - b, + lambda a, b: np.multiply(a, b), + lambda a, b: a * b, + ] - jit_func = njit(py_func) - arr1 = 1 - arr2 = 2 - assert_equal(py_func(arr1, arr2), jit_func(arr1, arr2)) + test_data = [1, 2.5, np.array([1,2,3]), np.array([4.4,5.5,6.6])] + for py_func in funcs: + jit_func = njit(py_func) + for a1, a2 in itertools.product(test_data, test_data): + assert_equal(py_func(a1,a2), jit_func(a1,a2)) def test_sum_axis(self): funcs = [ From 6d295aa74a9a77cb3ac4c8cb44db3384c6de3839 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 28 Feb 2021 16:04:22 +0300 Subject: [PATCH 241/259] [MLIR] Numpy broadcasting (#193) * fixes to broadcasting * fix * work on broadcast * some work on broadcast * work on broadcast * broadcasting * broadcast fix * PostFusionOptPass --- .../src/pipelines/plier_to_linalg.cpp | 40 ++++- .../mlir-compiler/src/py_linalg_resolver.cpp | 152 ++++++++++++++++-- numba/mlir/tests/test_numpy.py | 19 +++ 3 files changed, 194 insertions(+), 17 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index b3444a8a1d0..7114e2584f4 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -802,8 +802,8 @@ void LowerLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -struct PostLinalgOptPass : - public mlir::PassWrapper> +struct PostFusionOptPass : + public mlir::PassWrapper> { virtual void getDependentDialects( mlir::DialectRegistry ®istry) const override @@ -817,6 +817,26 @@ struct PostLinalgOptPass : void runOnOperation() override; }; +void PostFusionOptPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + for (auto *op : context.getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, &context); + } + + patterns.insert< + // LoopInvariantCodeMotion, TODO + plier::CSERewrite + >(&context); + + plier::populate_index_propagate_patterns(context, patterns); + + mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + struct LoopInvariantCodeMotion : public mlir::OpRewritePattern { using mlir::OpRewritePattern::OpRewritePattern; @@ -839,6 +859,21 @@ struct LoopInvariantCodeMotion : public mlir::OpRewritePattern } }; +struct PostLinalgOptPass : + public mlir::PassWrapper> +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + void PostLinalgOptPass::runOnOperation() { mlir::OwningRewritePatternList patterns; @@ -871,6 +906,7 @@ void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createTensorConstantBufferizePass()); pm.addNestedPass(mlir::createSCFBufferizePass()); diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index c6dac247f91..42320cee92a 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -206,8 +207,8 @@ struct PyLinalgResolver::Context namespace { -py::list get_args(py::handle inspect, py::handle func, llvm::function_ref create_var, - mlir::ValueRange args, llvm::ArrayRef> kwargs) +py::object get_args(py::handle inspect, py::handle func, llvm::function_ref create_var, + mlir::ValueRange args, llvm::ArrayRef> kwargs) { auto sig_func = inspect.attr("signature"); auto sig = sig_func(func); @@ -258,7 +259,11 @@ py::list get_args(py::handle inspect, py::handle func, llvm::function_ref()); + assert(val2.getType().isa()); + auto one = builder.create(loc, 1); + auto cond = builder.create(loc, mlir::CmpIPredicate::eq, val1, one); + return builder.create(loc, cond, val2, val1); +} + +mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape) +{ + auto context = builder.getContext(); + auto src_type = src.getType().cast(); + auto num_dims = static_cast(src_type.getRank()); + auto shape = llvm::to_vector<8>(src_type.getShape()); + shape[dim] = -1; + mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); + auto dim_val = builder.create(loc, src, dim); + auto one = builder.create(loc, 1); + mlir::Value cond = builder.create(loc, mlir::CmpIPredicate::eq, one, dim_val); + mlir::Value cond2 = builder.create(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val); + cond = builder.create(loc, cond, cond2); + llvm::SmallVector new_shape(num_dims); + for (unsigned i = 0 ; i < num_dims; ++i) + { + if (i == dim) + { + new_shape[i] = target_shape[i]; + } + else + { + new_shape[i] = builder.create(loc, src, i); + } + } + auto true_body = [&](mlir::OpBuilder &builder, mlir::Location loc) + { + assert(dim < shape.size()); + shape[dim] = 1; + mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); + auto casted = builder.create(loc, casted_type, src).getResult(); + auto init = builder.create(loc, new_shape, src_type.getElementType()).getResult(); + llvm::SmallVector exprs(num_dims); + for (unsigned i = 0; i < num_dims; ++i) + { + if (i == dim) + { + exprs[i] = mlir::getAffineConstantExpr(0, context); + } + else + { + exprs[i] = mlir::getAffineDimExpr(i, context); + } + } + const mlir::AffineMap maps[] = { + mlir::AffineMap::get(num_dims, 0, exprs, context), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, context), + }; + llvm::SmallVector iterators(num_dims, "parallel"); + + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) + { + assert(values.size() == 2); + builder.create(loc, values[0]); + }; + + auto expanded = builder.create(loc, target_type, casted, init, maps, iterators, body); + auto res = builder.create(loc, target_type, expanded.getResult(0)); + builder.create(loc, res.getResult()); + }; + auto false_body = [&](mlir::OpBuilder &builder, mlir::Location loc) + { + auto res = builder.create(loc, target_type, src); + builder.create(loc, res.getResult()); + }; + return builder.create(loc, target_type, cond, true_body, false_body).getResult(0); +} + +mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value val, unsigned num_dims, mlir::ValueRange target_shape) +{ + assert(num_dims <= target_shape.size()); + if (num_dims < target_shape.size()) + { + target_shape = target_shape.drop_front(target_shape.size() - num_dims); + } + for (unsigned i = 0; i < num_dims; ++i) + { + val = expand_dim(builder, loc, val, i, target_shape); + } + return val; +} + py::object broadcast_impl(py::capsule context, py::tuple args) { if (1 == args.size()) @@ -467,14 +563,22 @@ py::object broadcast_impl(py::capsule context, py::tuple args) py::none(); } res_type = broadcast_type(res_type, shape_and_type->second); - if (shape_and_type->first.size() > shape_vals.size()) + auto new_shape_vals = shape_and_type->first; + for (auto it : llvm::zip(llvm::reverse(shape_vals), llvm::reverse(new_shape_vals))) { - shape_vals = shape_and_type->first; // TODO + auto& old_val = std::get<0>(it); + auto new_val = std::get<1>(it); + old_val = broadcast_dim(builder, loc, old_val, new_val); + } + if (new_shape_vals.size() > shape_vals.size()) + { + auto front = llvm::makeArrayRef(new_shape_vals).drop_back(shape_vals.size()); + assert(!front.empty()); + shape_vals.insert(shape_vals.begin(), front.begin(), front.end()); } } - llvm::SmallVector shape(static_cast(shape_vals.size()), -1); - py::tuple ret(args.size()); + py::tuple ret(mlir_args.size()); if (shape_vals.empty()) { for (auto it : llvm::enumerate(mlir_args)) @@ -489,24 +593,31 @@ py::object broadcast_impl(py::capsule context, py::tuple args) return std::move(ret); } + llvm::SmallVector shape(static_cast(shape_vals.size()), -1); auto tensor_type = mlir::RankedTensorType::get(shape, res_type); for (auto it : llvm::enumerate(mlir_args)) { mlir::Value val = it.value(); - auto type = val.getType(); - if (type != tensor_type) + if (auto src_type = val.getType().dyn_cast()) + { + assert(src_type.hasRank()); + val = expand_dims(builder, loc, val, static_cast(src_type.getRank()), shape_vals); + } + if (val.getType() != tensor_type) { + auto type = val.getType(); if (auto src_type = type.dyn_cast()) { assert(src_type.hasRank()); - auto num_dims = static_cast(src_type.getRank()); + auto src_num_dims = static_cast(src_type.getRank()); + auto num_dims = static_cast(tensor_type.getRank()); auto init = builder.create(loc, shape_vals, tensor_type.getElementType()).getResult(); - llvm::SmallVector iterators(num_dims, "parallel"); - auto map = mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()); mlir::AffineMap maps[] = { - map, - map, + mlir::AffineMap::getMinorIdentityMap(num_dims, src_num_dims, builder.getContext()), +// mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()).getMajorSubMap(src_num_dims), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()), }; + llvm::SmallVector iterators(num_dims, "parallel"); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) { assert(values.size() == 2); @@ -559,9 +670,15 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt { auto index_type = builder.getIndexType(); llvm::SmallVector shape_val(count); + llvm::SmallVector static_shape(count, -1); for (size_t i = 0; i < count; ++i) { - shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, shape[py::int_(i)]), index_type); + auto elem = shape[py::int_(i)]; + if (py::isinstance(elem)) + { + static_shape[i] = elem.cast(); + } + shape_val[i] = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, elem), index_type); } if (init_val.is_none()) @@ -579,6 +696,11 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt auto type = mlir::RankedTensorType::get(shape, elem_type); init = builder.create(loc, type, shape_val, body); } + if (llvm::any_of(static_shape, [](auto val){ return val >= 0;})) + { + auto new_type = mlir::RankedTensorType::get(static_shape, elem_type); + init = builder.create(loc, new_type, init); + } } return ctx.context.create_var(context, init); } diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 32320b34827..8c770a1fc65 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -276,5 +276,24 @@ def test_reshape(self): for a in [arr1]: assert_equal(py_func(a), jit_func(a)) + def test_broadcast(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func) + + test_data = [ + 1, + np.array([1]), + np.array([[1]]), + np.array([[1,2],[3,4]]), + np.array([5,6]), + np.array([[5],[6]]), + np.array([[5,6]]), + ] + + for a, b in itertools.product(test_data, test_data): + assert_equal(py_func(a,b), jit_func(a,b)) + if __name__ == '__main__': unittest.main() From b3fe4a3d60ae7473364a897417cf37d47f5f0f03 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 1 Mar 2021 02:39:44 +0300 Subject: [PATCH 242/259] [MLIR] Use default SmallVec size, move readme file (#194) * Use default values for SmallVector N * Move readme file, remove old file --- mlir-compiler/mlir-compiler/src/lowering.cpp | 14 +- mlir-compiler/mlir-compiler/src/mangle.cpp | 2 +- .../src/pipelines/lower_to_llvm.cpp | 10 +- .../src/pipelines/parallel_to_tbb.cpp | 6 +- .../src/pipelines/plier_to_linalg.cpp | 8 +- .../src/pipelines/plier_to_std.cpp | 12 +- .../mlir-compiler/src/py_linalg_resolver.cpp | 46 +++--- mlir-compiler/mlir-compiler/test.py | 136 ------------------ mlir-compiler/plier/src/dialect.cpp | 2 +- .../plier/src/rewrites/call_lowering.cpp | 4 +- .../src/rewrites/canonicalize_reductions.cpp | 6 +- .../src/rewrites/promote_to_parallel.cpp | 2 +- mlir-compiler/{mlir-compiler => }/readme.md | 0 13 files changed, 56 insertions(+), 192 deletions(-) delete mode 100644 mlir-compiler/mlir-compiler/test.py rename mlir-compiler/{mlir-compiler => }/readme.md (100%) diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp index d27743d8983..befeb7f34eb 100644 --- a/mlir-compiler/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -356,7 +356,7 @@ struct plier_lowerer final mlir::Value lower_build_tuple(const py::handle& inst) { auto items = inst.attr("items").cast(); - mlir::SmallVector args; + mlir::SmallVector args; for (auto item : items) { args.push_back(loadvar(item)); @@ -396,8 +396,8 @@ struct plier_lowerer final auto kws = expr.attr("kws").cast(); auto vararg = expr.attr("vararg"); - mlir::SmallVector args_list; - mlir::SmallVector, 8> kwargs_list; + mlir::SmallVector args_list; + mlir::SmallVector> kwargs_list; for (auto a : args) { args_list.push_back(loadvar(a)); @@ -558,7 +558,7 @@ struct plier_lowerer final mlir::FunctionType get_func_type(const py::handle& fnargs, const py::handle& restype) { auto ret = get_obj_type(restype()); - llvm::SmallVector args; + llvm::SmallVector args; for (auto arg : fnargs()) { args.push_back(get_obj_type(arg)); @@ -608,7 +608,7 @@ struct plier_lowerer final if (auto op = mlir::dyn_cast(term)) { auto dest = op.getDest(); - mlir::SmallVector args; + mlir::SmallVector args; build_arg_list(dest, info.outgoing_phi_nodes, args); op.erase(); builder.create(builder.getUnknownLoc(), dest, args); @@ -618,8 +618,8 @@ struct plier_lowerer final auto true_dest = op.trueDest(); auto false_dest = op.falseDest(); auto cond = op.getCondition(); - mlir::SmallVector true_args; - mlir::SmallVector false_args; + mlir::SmallVector true_args; + mlir::SmallVector false_args; build_arg_list(true_dest, info.outgoing_phi_nodes, true_args); build_arg_list(false_dest, info.outgoing_phi_nodes, false_args); op.erase(); diff --git a/mlir-compiler/mlir-compiler/src/mangle.cpp b/mlir-compiler/mlir-compiler/src/mangle.cpp index 826194c7d57..cc06a538323 100644 --- a/mlir-compiler/mlir-compiler/src/mangle.cpp +++ b/mlir-compiler/mlir-compiler/src/mangle.cpp @@ -132,7 +132,7 @@ template void mangle_ident_impl(llvm::raw_ostream& res, llvm::StringRef ident, F&& template_params) { assert(!ident.empty()); - llvm::SmallVector parts; + llvm::SmallVector parts; ident.split(parts, '.'); assert(!parts.empty()); auto write_part = [&](auto part) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 000e63702f3..0e9c3062a16 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -406,7 +406,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) auto old_type = func.getType(); assert(old_type.getNumResults() <= 1); auto& ctx = *old_type.getContext(); - llvm::SmallVector args; + llvm::SmallVector args; auto ptr = [&](auto arg) { @@ -426,7 +426,7 @@ void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) builder.setInsertionPointToStart(&func.getBody().front()); auto loc = builder.getUnknownLoc(); - llvm::SmallVector new_args; + llvm::SmallVector new_args; auto process_arg = [&](mlir::Type type) { if (auto memref_type = type.dyn_cast()) @@ -871,8 +871,8 @@ struct LowerParallel : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite(plier::ParallelOp op, mlir::PatternRewriter &rewriter) const override { - llvm::SmallVector context_vars; - llvm::SmallVector context_constants; + llvm::SmallVector context_vars; + llvm::SmallVector context_constants; llvm::DenseSet context_vars_set; auto add_context_var = [&](mlir::Value value) { @@ -930,7 +930,7 @@ struct LowerParallel : public mlir::OpRewritePattern auto context_type = [&]()->mlir::LLVM::LLVMStructType { - llvm::SmallVector fields; + llvm::SmallVector fields; fields.reserve(context_vars.size()); for (auto var : context_vars) { diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index 984fe92aa47..4d19587a223 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -78,7 +78,7 @@ struct ParallelToTbb : public mlir::OpRewritePattern auto loc = op.getLoc(); mlir::BlockAndValueMapping mapping; - llvm::SmallVector reduce_vars(op.getNumResults()); + llvm::SmallVector reduce_vars(op.getNumResults()); for (auto it : llvm::enumerate(op.getResultTypes())) { auto type = it.value(); @@ -114,7 +114,7 @@ struct ParallelToTbb : public mlir::OpRewritePattern auto orig_step = op.step().front(); auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index) { - llvm::SmallVector initVals(op.initVals().size()); + llvm::SmallVector initVals(op.initVals().size()); for (auto it : llvm::enumerate(op.initVals())) { auto reduce_var = reduce_vars[it.index()]; @@ -144,7 +144,7 @@ struct ParallelToTbb : public mlir::OpRewritePattern { return mlir::isa(op); }); - llvm::SmallVector yield_args; + llvm::SmallVector yield_args; yield_args.reserve(args.size()); for (auto it : llvm::enumerate(reduce_ops)) { diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 7114e2584f4..00a3aaa5add 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -117,7 +117,7 @@ mlir::Type map_array_type(mlir::MLIRContext& ctx, mlir::TypeConverter& conveter, { if (auto type = conveter.convertType(plier::PyType::get(&ctx, desc->name))) { - llvm::SmallVector shape(desc->dims, -1); + llvm::SmallVector shape(desc->dims, -1); return mlir::RankedTensorType::get(shape, type); } } @@ -350,7 +350,7 @@ struct GetitemOpLowering : public mlir::OpRewritePattern } auto loc = op.getLoc(); - llvm::SmallVector indices; + llvm::SmallVector indices; if (auto tuple_type = index.getType().template dyn_cast()) { indices.resize(tuple_type.size()); @@ -572,7 +572,7 @@ struct SetitemOpLowering : public mlir::OpRewritePattern rerun_std_pipeline(op); } - llvm::SmallVector indices; + llvm::SmallVector indices; if (auto tuple_type = index.getType().template dyn_cast()) { indices.resize(tuple_type.size()); @@ -617,7 +617,7 @@ struct ArrayShape : public mlir::OpRewritePattern return mlir::failure(); } - llvm::SmallVector dims(rank); + llvm::SmallVector dims(rank); for (size_t i = 0; i < rank; ++i) { auto dim = rewriter.create(op.getLoc(), op.value(), i); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 27190ae9e39..5786b2f40fd 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -113,7 +113,7 @@ mlir::Type map_unituple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) !name.consumeInteger(10, count) && name.consume_front(")")) { - llvm::SmallVector types(count, type); + llvm::SmallVector types(count, type); return mlir::TupleType::get(&ctx, types); } return nullptr; @@ -125,7 +125,7 @@ mlir::Type map_tuple_type(mlir::MLIRContext& ctx, llvm::StringRef& name) { return nullptr; } - llvm::SmallVector types; + llvm::SmallVector types; while (true) { if (name.consume_front(")")) @@ -500,7 +500,7 @@ template void replace_op(mlir::Operation* op, mlir::PatternRewriter& rewriter, mlir::Type new_type, mlir::ValueRange operands) { assert(nullptr != op); - llvm::SmallVector new_operands(operands.size()); + llvm::SmallVector new_operands(operands.size()); for (auto it : llvm::enumerate(operands)) { new_operands[it.index()] = do_cast(new_type, it.value(), rewriter); @@ -776,7 +776,7 @@ struct ScfIfRewrite : public mlir::OpRewritePattern } mlir::BlockAndValueMapping mapper; - llvm::SmallVector yield_vals; + llvm::SmallVector yield_vals; auto copy_block = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::Block& block) { mapper.clear(); @@ -954,9 +954,9 @@ struct ScfWhileRewrite : public mlir::OpRewritePattern } mlir::BlockAndValueMapping mapper; - llvm::SmallVector yield_vars; + llvm::SmallVector yield_vars; auto before_block_args = before_block->getArguments(); - llvm::SmallVector orig_vars(before_block_args.begin(), before_block_args.end()); + llvm::SmallVector orig_vars(before_block_args.begin(), before_block_args.end()); auto before_body = [&](mlir::OpBuilder& builder, mlir::Location loc, mlir::ValueRange iterargs) { diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 42320cee92a..3b5d742c45a 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -278,7 +278,7 @@ auto get_types(mlir::ValueRange values) auto get_agrs_from_tuple(py::handle args, llvm::function_ref unpack) { - llvm::SmallVector ret; + llvm::SmallVector ret; if (args.is_none()) { return ret; @@ -301,7 +301,7 @@ auto get_agrs_from_tuple(py::handle args, llvm::function_ref ret(iterators.size()); + llvm::SmallVector ret(iterators.size()); for (auto it : llvm::enumerate(iterators)) { ret[it.index()] = mlir::StringAttr::get(it.value().cast(), &ctx).getValue(); @@ -317,7 +317,7 @@ mlir::AffineMapAttr get_affine_map_attr(py::handle obj, mlir::MLIRContext& ctx) auto get_affine_maps(py::list maps, mlir::MLIRContext& ctx) { - llvm::SmallVector ret(maps.size()); + llvm::SmallVector ret(maps.size()); for (auto it : llvm::enumerate(maps)) { ret[it.index()] = get_affine_map_attr(it.value(), ctx).getValue(); @@ -327,7 +327,7 @@ auto get_affine_maps(py::list maps, mlir::MLIRContext& ctx) auto get_generic_op_body_types(mlir::ValueRange inputs, mlir::ValueRange outputs) { - llvm::SmallVector ret; + llvm::SmallVector ret; ret.reserve(inputs.size() + outputs.size()); for (auto r : {inputs, outputs}) { @@ -349,7 +349,7 @@ auto get_generic_op_body_types(mlir::ValueRange inputs, mlir::ValueRange outputs auto generic_op_body_result_types(mlir::ValueRange outputs) { - llvm::SmallVector ret; + llvm::SmallVector ret; ret.reserve(outputs.size()); for (auto type : outputs.getTypes()) { @@ -436,7 +436,7 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value mlir::Value cond = builder.create(loc, mlir::CmpIPredicate::eq, one, dim_val); mlir::Value cond2 = builder.create(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val); cond = builder.create(loc, cond, cond2); - llvm::SmallVector new_shape(num_dims); + llvm::SmallVector new_shape(num_dims); for (unsigned i = 0 ; i < num_dims; ++i) { if (i == dim) @@ -455,7 +455,7 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); auto casted = builder.create(loc, casted_type, src).getResult(); auto init = builder.create(loc, new_shape, src_type.getElementType()).getResult(); - llvm::SmallVector exprs(num_dims); + llvm::SmallVector exprs(num_dims); for (unsigned i = 0; i < num_dims; ++i) { if (i == dim) @@ -471,7 +471,7 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value mlir::AffineMap::get(num_dims, 0, exprs, context), mlir::AffineMap::getMultiDimIdentityMap(num_dims, context), }; - llvm::SmallVector iterators(num_dims, "parallel"); + llvm::SmallVector iterators(num_dims, "parallel"); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) { @@ -514,12 +514,12 @@ py::object broadcast_impl(py::capsule context, py::tuple args) auto& ctx = get_py_context(context); auto loc = ctx.loc; auto& builder = ctx.builder; - llvm::SmallVector mlir_args(args.size()); + llvm::SmallVector mlir_args(args.size()); for (auto it : llvm::enumerate(args)) { mlir_args[it.index()] = ctx.context.unwrap_val(loc, builder, it.value()); } - using shape_t = llvm::SmallVector; + using shape_t = llvm::SmallVector; auto get_shape = [&](mlir::Value val)->llvm::Optional> { auto type = val.getType(); @@ -544,7 +544,7 @@ py::object broadcast_impl(py::capsule context, py::tuple args) return {}; }; mlir::Type res_type; - mlir::SmallVector shape_vals; + mlir::SmallVector shape_vals; if (auto shape_and_type = get_shape(mlir_args.front())) { res_type = shape_and_type->second; @@ -593,7 +593,7 @@ py::object broadcast_impl(py::capsule context, py::tuple args) return std::move(ret); } - llvm::SmallVector shape(static_cast(shape_vals.size()), -1); + llvm::SmallVector shape(static_cast(shape_vals.size()), -1); auto tensor_type = mlir::RankedTensorType::get(shape, res_type); for (auto it : llvm::enumerate(mlir_args)) { @@ -617,7 +617,7 @@ py::object broadcast_impl(py::capsule context, py::tuple args) // mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()).getMajorSubMap(src_num_dims), mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()), }; - llvm::SmallVector iterators(num_dims, "parallel"); + llvm::SmallVector iterators(num_dims, "parallel"); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) { assert(values.size() == 2); @@ -669,7 +669,7 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt else { auto index_type = builder.getIndexType(); - llvm::SmallVector shape_val(count); + llvm::SmallVector shape_val(count); llvm::SmallVector static_shape(count, -1); for (size_t i = 0; i < count; ++i) { @@ -692,7 +692,7 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt { builder.create(loc, val); }; - llvm::SmallVector shape(count, -1); + llvm::SmallVector shape(count, -1); auto type = mlir::RankedTensorType::get(shape, elem_type); init = builder.create(loc, type, shape_val, body); } @@ -723,7 +723,7 @@ py::object fill_tensor_impl(py::capsule context, py::handle tensor, py::handle v mlir::AffineMap affine_maps[] = { mlir::AffineMap::getMultiDimIdentityMap(rank, builder.getContext()), }; - llvm::SmallVector iterators(rank, "parallel"); + llvm::SmallVector iterators(rank, "parallel"); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) { assert(values.size() == 1); @@ -763,7 +763,7 @@ py::object generic_impl(py::capsule context, py::handle inputs, py::handle outpu auto cast_values = [&](mlir::ValueRange vals, mlir::TypeRange types) { assert(vals.size() == types.size()); - llvm::SmallVector ret(vals.size()); + llvm::SmallVector ret(vals.size()); auto do_cast = [&](mlir::Value val, mlir::Type type) { if (val.getType() == type) @@ -816,7 +816,7 @@ py::object from_elements_impl(py::capsule context, py::handle values, py::handle auto loc = ctx.loc; auto type = unwrap_type(dtype); - llvm::SmallVector vals(container_size(values)); + llvm::SmallVector vals(container_size(values)); container_iterate(values, [&](auto index, py::handle obj) { if (py::isinstance(obj, ctx.context.var)) @@ -855,7 +855,7 @@ py::object extract_impl(py::capsule context, py::handle value, py::handle indice auto& builder = ctx.builder; auto loc = ctx.loc; - llvm::SmallVector ind(container_size(indices)); + llvm::SmallVector ind(container_size(indices)); container_iterate(indices, [&](auto index, py::handle obj) { if (py::isinstance(obj, ctx.context.var)) @@ -888,10 +888,10 @@ py::object reshape_impl(py::capsule context, py::handle tensor, py::int_ out_dim } auto elem_type = tensor_val.getType().cast().getElementType(); auto new_dims = out_dims.cast(); - llvm::SmallVector dims(new_dims, -1); + llvm::SmallVector dims(new_dims, -1); auto new_type = mlir::RankedTensorType::get(dims, elem_type); - llvm::SmallVector affine_maps(container_size(maps)); + llvm::SmallVector affine_maps(container_size(maps)); container_iterate(maps, [&](auto index, py::handle obj) { affine_maps[index] = get_affine_map_attr(obj, *builder.getContext()); @@ -937,14 +937,14 @@ py::object shape_impl(py::capsule context, py::capsule ssa_val) auto loc = ctx.loc; auto mlir_type = value.getType().cast(); auto shape = mlir_type.getShape(); - llvm::SmallVector shape_vals(shape.size()); + llvm::SmallVector shape_vals(shape.size()); for (auto it : llvm::enumerate(shape)) { auto i = it.index(); mlir::Value mlir_dim = builder.create(loc, value, i); shape_vals[i] = mlir_dim; } - llvm::SmallVector shape_types(shape.size(), builder.getIndexType()); + llvm::SmallVector shape_types(shape.size(), builder.getIndexType()); auto shape_type = mlir::TupleType::get(builder.getContext(), shape_types); auto shape_var = builder.create(loc, shape_type, shape_vals); return ctx.context.create_var(context, shape_var.getResult()); diff --git a/mlir-compiler/mlir-compiler/test.py b/mlir-compiler/mlir-compiler/test.py deleted file mode 100644 index 9ef63842f92..00000000000 --- a/mlir-compiler/mlir-compiler/test.py +++ /dev/null @@ -1,136 +0,0 @@ -import numba -import numpy as np - -_tests_total = 0 -_tests_passes = 0 -_failed_tests = [] - -def ret(a): - return a - -def const(): - return 42 - -def sum1(a): - return a + 42 - -def sum2(a, b): - return a + b - -def cond(a, b): - if a > b: - return a - else: - return b - -def var(a): - c = 1 - c = c + a - return c - -def jump(a, b): - c = 3 - if a > 5: - c = c + a - c = c + b - return c - -sum2_jit = numba.njit()(sum2) - -def call(a, b, c): - return sum2_jit(a, sum2_jit(b, c)) - -def tuple(a,b,c): - t = (a,b,c) - return t[0] + t[1] + t[2] - -def arr_loop(): - res = 0 - arr = [1,2,3] - for i in arr: - res = res + i - return res - -def range_loop(n): - res = 0 - res1 = 2 - for i in range(n): - if i > 5: - res = res + i - else: - res1 = res1 + i * 2 - return res + res1 - -def range_loop_nested(a, b, c): - res = 0 - for i in range(a): - for j in range(b): - for k in range(c): - res = res + i + j * 10 + k * 100 - return res - -def np_getitem(a, b): - return a[b] - -def np_getitem2(a, b, c): - return b[c] - -def np_sum(a): - return a.sum() - -def np_add(a, b): - return np.add(a, b).sum() - -def np_add2(a, b, c): - t = np.add(a, b) - return np.add(t, c).sum() - -def test(func, params): - global _tests_total - global _tests_passes - global _failed_tests - _tests_total += 1 - test_name = f'{func.__name__} {params}' - print('test', test_name, '... ', end='') - result = func(*params) - wrapped = numba.njit()(func) - try: - res = wrapped(*params) - if (res != result): - raise Exception(f'Invalid value "{res}", expected "{result}"') - print('SUCCESS') - _tests_passes += 1 - except Exception as e: - print(e) - print('FAILED') - _failed_tests.append(test_name) - -print('=========================================================') - -test(ret, (7,)) -test(const, ()) -test(sum1, (5,)) -test(sum2, (3,4)) -test(cond, (5,6)) -test(cond, (8,7)) -test(var, (8,)) -test(jump, (1,8)) -test(jump, (7,8)) -test(call, (1,2,3)) -test(tuple, (1,2,3)) -test(tuple, (1,2.0,3)) -test(arr_loop, ()) -test(range_loop, (8,)) -test(range_loop_nested, (8,9,10)) -test(sum2, (np.asarray([1,2,3]),np.asarray([4,5,6]))) -test(np_getitem, (np.asarray([1,2,3]),1)) -test(np_getitem2, (np.asarray([1,2,3]),np.asarray([4,5,6]),1)) -test(np_sum, (np.asarray([1,2,3]),)) -test(np_add, (np.asarray([1,2,3]),np.asarray([4,5,6]))) -test(np_add2, (np.asarray([1,2,3]),np.asarray([4,5,6]),np.asarray([7,8,9]))) - -print(f'Tests passed: {_tests_passes}/{_tests_total}') -if (len(_failed_tests) != 0): - print('Failed:') - for t in _failed_tests: - print(t) diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp index 32ae0ff85fb..d25187809db 100644 --- a/mlir-compiler/plier/src/dialect.cpp +++ b/mlir-compiler/plier/src/dialect.cpp @@ -169,7 +169,7 @@ void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir all_args.reserve(args.size() + kwargs.size()); std::copy(args.begin(), args.end(), std::back_inserter(all_args)); auto kw_start = static_cast(all_args.size()); - mlir::SmallVector kw_names; + mlir::SmallVector kw_names; kw_names.reserve(kwargs.size()); for (auto& a : kwargs) { diff --git a/mlir-compiler/plier/src/rewrites/call_lowering.cpp b/mlir-compiler/plier/src/rewrites/call_lowering.cpp index 6de918fa317..29bd46e1627 100644 --- a/mlir-compiler/plier/src/rewrites/call_lowering.cpp +++ b/mlir-compiler/plier/src/rewrites/call_lowering.cpp @@ -18,8 +18,8 @@ mlir::LogicalResult plier::CallOpLowering::matchAndRewrite(plier::PyCallOp op, m return mlir::failure(); } - llvm::SmallVector args; - llvm::SmallVector, 8> kwargs; + llvm::SmallVector args; + llvm::SmallVector> kwargs; auto getattr = mlir::dyn_cast_or_null(operands[0].getDefiningOp()); if (getattr) { diff --git a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp index 49c4f11cf32..37c0f31edf9 100644 --- a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp +++ b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp @@ -74,7 +74,7 @@ mlir::Value createScalarLoad( else if (llvm::all_of(shape, [](auto s) { return s == 1; })) { auto index = builder.create(loc, 0); - llvm::SmallVector indices(shape.size(), index); + llvm::SmallVector indices(shape.size(), index); return builder.create(loc, memref, indices); } else @@ -95,7 +95,7 @@ void createScalarStore( else if (llvm::all_of(shape, [](auto s) { return s == 1; })) { auto index = builder.create(loc, 0); - llvm::SmallVector indices(shape.size(), index); + llvm::SmallVector indices(shape.size(), index); builder.create(loc, val, memref, indices); } else @@ -107,7 +107,7 @@ void createScalarStore( mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { - llvm::SmallVector to_process; + llvm::SmallVector to_process; for (auto& current : op.getLoopBody().front()) { if (auto load = mlir::dyn_cast(current)) diff --git a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp index 6ae712a285e..2c1b84391eb 100644 --- a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp +++ b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp @@ -43,7 +43,7 @@ mlir::LogicalResult plier::PromoteToParallel::matchAndRewrite(mlir::scf::ForOp o auto& old_body = op.getLoopBody().front(); auto old_yield = mlir::cast(old_body.getTerminator()); auto reduce_args = old_body.getArguments().drop_front(); - llvm::SmallVector, 8> reduce_bodies(reduce_args.size()); + llvm::SmallVector> reduce_bodies(reduce_args.size()); llvm::DenseSet reduce_ops; for (auto it : llvm::enumerate(reduce_args)) { diff --git a/mlir-compiler/mlir-compiler/readme.md b/mlir-compiler/readme.md similarity index 100% rename from mlir-compiler/mlir-compiler/readme.md rename to mlir-compiler/readme.md From b0900a5bced0ba36932b18cb6278f54816b8c7c4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 2 Mar 2021 02:55:05 +0300 Subject: [PATCH 243/259] [MLIR] some multiindex ParallelOp supports (#195) * change builder func * accept list of bounds * plier::ParallelOp nested loops support * fix to nested parallel loops --- .../src/pipelines/lower_to_llvm.cpp | 111 ++++++++++++++---- .../src/pipelines/parallel_to_tbb.cpp | 12 +- mlir-compiler/plier/include/plier/PlierOps.td | 17 ++- mlir-compiler/plier/src/dialect.cpp | 36 ++++-- numba/mlir/tests/test_numpy.py | 9 ++ numba/np/ufunc/tbbpool.cpp | 107 ++++++++++++++--- 6 files changed, 227 insertions(+), 65 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 0e9c3062a16..0fd8e6e65a3 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include @@ -871,6 +872,7 @@ struct LowerParallel : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite(plier::ParallelOp op, mlir::PatternRewriter &rewriter) const override { + auto num_loops = op.getNumLoops(); llvm::SmallVector context_vars; llvm::SmallVector context_constants; llvm::DenseSet context_vars_set; @@ -951,6 +953,24 @@ struct LowerParallel : public mlir::OpRewritePattern auto context_ptr_type = mlir::LLVM::LLVMPointerType::get(context_type); auto loc = op.getLoc(); + auto index_type = rewriter.getIndexType(); + auto llvm_index_type = mlir::IntegerType::get(op.getContext(), 64); // TODO + auto to_llvm_index = [&](mlir::Value val)->mlir::Value + { + if (val.getType() != llvm_index_type) + { + return rewriter.create(loc, llvm_index_type, val); + } + return val; + }; + auto from_llvm_index = [&](mlir::Value val)->mlir::Value + { + if (val.getType() != index_type) + { + return rewriter.create(loc, index_type, val); + } + return val; + }; auto llvm_i32_type = mlir::IntegerType::get(op.getContext(), 32); auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); auto one = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(1)); @@ -971,12 +991,29 @@ struct LowerParallel : public mlir::OpRewritePattern auto void_ptr_type = mlir::LLVM::LLVMPointerType::get(mlir::IntegerType::get(op.getContext(), 8)); auto context_abstract = rewriter.create(loc, void_ptr_type, context); - auto index_type = rewriter.getIndexType(); + auto input_range_type = [&]() + { + const mlir::Type members[] = { + llvm_index_type, // lower_bound + llvm_index_type, // upper_bound + llvm_index_type, // step + }; + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members); + }(); + auto input_range_ptr = mlir::LLVM::LLVMPointerType::get(input_range_type); + auto range_type = [&]() + { + const mlir::Type members[] = { + llvm_index_type, // lower_bound + llvm_index_type, // upper_bound + }; + return mlir::LLVM::LLVMStructType::getLiteral(op.getContext(), members); + }(); + auto range_ptr = mlir::LLVM::LLVMPointerType::get(range_type); auto func_type = [&]() { - mlir::Type args[] = { - index_type, // lower_bound - index_type, // upper_bound + const mlir::Type args[] = { + range_ptr, // bounds index_type, // thread index void_ptr_type // context }; @@ -1014,21 +1051,34 @@ struct LowerParallel : public mlir::OpRewritePattern auto entry = func.addEntryBlock(); auto loc = rewriter.getUnknownLoc(); mlir::OpBuilder::InsertionGuard guard(rewriter); - mapping.map(old_entry.getArgument(0), entry->getArgument(0)); - mapping.map(old_entry.getArgument(1), entry->getArgument(1)); - mapping.map(old_entry.getArgument(2), entry->getArgument(2)); rewriter.setInsertionPointToStart(entry); + auto pos0 = rewriter.getI64ArrayAttr(0); + auto pos1 = rewriter.getI64ArrayAttr(1); + for (unsigned i = 0; i < num_loops; ++i) + { + auto arg = entry->getArgument(0); + const mlir::Value indices[] = { + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(i))) + }; + auto ptr = rewriter.create(loc, range_ptr, arg, indices); + auto dims = rewriter.create(loc, ptr); + auto lower = rewriter.create(loc, llvm_index_type, dims, pos0); + auto upper = rewriter.create(loc, llvm_index_type, dims, pos1); + mapping.map(old_entry.getArgument(i), from_llvm_index(lower)); + mapping.map(old_entry.getArgument(i + num_loops), from_llvm_index(upper)); + } + mapping.map(old_entry.getArgument(2 * num_loops), entry->getArgument(1)); // thread index for (auto arg : context_constants) { rewriter.clone(*arg, mapping); } - auto context_ptr = rewriter.create(loc, context_ptr_type, entry->getArgument(3)); + auto context_ptr = rewriter.create(loc, context_ptr_type, entry->getArgument(2)); auto zero = rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(0)); for (auto it : llvm::enumerate(context_vars)) { auto index = it.index(); auto old_val = it.value(); - mlir::Value indices[] = { + const mlir::Value indices[] = { zero, rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(index))) }; @@ -1060,21 +1110,39 @@ struct LowerParallel : public mlir::OpRewritePattern { return sym; } - mlir::Type args[] = { - index_type, // lower bound - index_type, // upper bound - index_type, // step - func_type, - void_ptr_type + const mlir::Type args[] = { + input_range_ptr, // bounds + index_type, // num_loops + func_type, // func + void_ptr_type // context }; - auto func_type = mlir::FunctionType::get(op.getContext(), args, {}); - return plier::add_function(rewriter, mod, func_name, func_type); + auto parallel_func_type = mlir::FunctionType::get(op.getContext(), args, {}); + return plier::add_function(rewriter, mod, func_name, parallel_func_type); }(); auto func_addr = rewriter.create(loc, func_type, rewriter.getSymbolRefAttr(outlined_func)); - mlir::Value pf_args[] = { - op.lowerBound(), - op.upperBound(), - op.step(), + + auto num_loops_var = rewriter.create(loc, num_loops); + auto input_ranges = rewriter.create(loc, input_range_ptr, to_llvm_index(num_loops_var), 0); + for (unsigned i = 0; i < num_loops; ++i) + { + mlir::Value input_range = rewriter.create(loc, input_range_type); + auto insert = [&](mlir::Value val, unsigned index) + { + input_range = rewriter.create(loc, input_range, val, rewriter.getI64ArrayAttr(index)); + }; + insert(to_llvm_index(op.lowerBounds()[i]), 0); + insert(to_llvm_index(op.upperBounds()[i]), 1); + insert(to_llvm_index(op.steps()[i]), 2); + const mlir::Value indices[] = { + rewriter.create(loc, llvm_i32_type, rewriter.getI32IntegerAttr(static_cast(i))) + }; + auto ptr = rewriter.create(loc, input_range_ptr, input_ranges, indices); + rewriter.create(loc, input_range, ptr); + } + + const mlir::Value pf_args[] = { + input_ranges, + num_loops_var, func_addr, context_abstract }; @@ -1226,6 +1294,7 @@ void populate_lower_to_llvm_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); pm.addPass(mlir::createLowerToCFGPass()); + pm.addPass(mlir::createCanonicalizerPass()); // pm.addPass(std::make_unique()); pm.addNestedPass(std::make_unique()); pm.addPass(std::make_unique(getLLVMOptions())); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index 4d19587a223..3f708381340 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -46,10 +46,6 @@ struct ParallelToTbb : public mlir::OpRewritePattern { return mlir::failure(); } - if (op.getNumLoops() != 1) - { - return mlir::failure(); - } bool need_parallel = op->hasAttr(plier::attributes::getParallelName()) || !op->getParentOfType(); if (!need_parallel) @@ -109,10 +105,10 @@ struct ParallelToTbb : public mlir::OpRewritePattern rewriter.create(loc, reduce_lower_bound, reduce_upper_bound, reduce_step, llvm::None, reduce_init_body_builder); auto& old_body = op.getLoopBody().front(); - auto orig_lower_bound = op.lowerBound().front(); - auto orig_upper_bound = op.upperBound().front(); - auto orig_step = op.step().front(); - auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value thread_index) + auto orig_lower_bound = op.lowerBound(); + auto orig_upper_bound = op.upperBound(); + auto orig_step = op.step(); + auto body_builder = [&](mlir::OpBuilder &builder, ::mlir::Location loc, mlir::ValueRange lower_bound, mlir::ValueRange upper_bound, mlir::Value thread_index) { llvm::SmallVector initVals(op.initVals().size()); for (auto it : llvm::enumerate(op.initVals())) diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td index 084ffea2574..e8aee0d66b1 100644 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ b/mlir-compiler/plier/include/plier/PlierOps.td @@ -215,21 +215,26 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { } def ParallelOp : Plier_Op<"parallel", - [DeclareOpInterfaceMethods, + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"plier::YieldOp">, RecursiveSideEffects]> { - let arguments = (ins Index:$lowerBound, - Index:$upperBound, - Index:$step); + let arguments = (ins Variadic:$lowerBounds, + Variadic:$upperBounds, + Variadic:$steps); let regions = (region SizedRegion<1>:$region); let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$lowerBound, "::mlir::Value":$upperBound, "::mlir::Value":$step, - CArg<"::mlir::function_ref", + OpBuilderDAG<(ins "::mlir::ValueRange":$lowerBounds, "::mlir::ValueRange":$upperBounds, "::mlir::ValueRange":$steps, + CArg<"::mlir::function_ref", "nullptr">)> ]; + + let extraClassDeclaration = [{ + unsigned getNumLoops() { return steps().size(); } + }]; } def YieldOp : Plier_Op<"yield", [NoSideEffect, ReturnLike, Terminator, diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp index d25187809db..eae67beb7fb 100644 --- a/mlir-compiler/plier/src/dialect.cpp +++ b/mlir-compiler/plier/src/dialect.cpp @@ -354,23 +354,33 @@ bool ParallelOp::isDefinedOutsideOfLoop(mlir::Value value) void ParallelOp::build( mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, - mlir::Value lowerBound, mlir::Value upperBound, mlir::Value step, - mlir::function_ref bodyBuilder) { - odsState.addOperands({lowerBound, upperBound, step}); + mlir::ValueRange lowerBounds, mlir::ValueRange upperBounds, mlir::ValueRange steps, + mlir::function_ref bodyBuilder) { + assert(lowerBounds.size() == upperBounds.size()); + assert(lowerBounds.size() == steps.size()); + odsState.addOperands(lowerBounds); + odsState.addOperands(upperBounds); + odsState.addOperands(steps); + odsState.addAttribute( + ParallelOp::getOperandSegmentSizeAttr(), + odsBuilder.getI32VectorAttr({static_cast(lowerBounds.size()), + static_cast(upperBounds.size()), + static_cast(steps.size())})); auto bodyRegion = odsState.addRegion(); - bodyRegion->push_back(new mlir::Block); - auto& bodyBlock = bodyRegion->front(); - bodyBlock.addArgument(odsBuilder.getIndexType()); // lower bound - bodyBlock.addArgument(odsBuilder.getIndexType()); // upper bound - bodyBlock.addArgument(odsBuilder.getIndexType()); // thread index + auto count = lowerBounds.size(); + mlir::OpBuilder::InsertionGuard guard(odsBuilder); + llvm::SmallVector argTypes(count * 2 + 1, odsBuilder.getIndexType()); + auto *bodyBlock = odsBuilder.createBlock(bodyRegion, {}, argTypes); if (bodyBuilder) { - mlir::OpBuilder::InsertionGuard guard(odsBuilder); - odsBuilder.setInsertionPointToStart(&bodyBlock); - bodyBuilder(odsBuilder, odsState.location, bodyBlock.getArgument(0), - bodyBlock.getArgument(1), bodyBlock.getArgument(2)); + odsBuilder.setInsertionPointToStart(bodyBlock); + auto args = bodyBlock->getArguments(); + bodyBuilder(odsBuilder, odsState.location, + args.take_front(count), + args.drop_front(count).take_front(count), + args.back()); ParallelOp::ensureTerminator(*bodyRegion, odsBuilder, odsState.location); } } diff --git a/numba/mlir/tests/test_numpy.py b/numba/mlir/tests/test_numpy.py index 8c770a1fc65..d28d1651753 100644 --- a/numba/mlir/tests/test_numpy.py +++ b/numba/mlir/tests/test_numpy.py @@ -295,5 +295,14 @@ def py_func(a, b): for a, b in itertools.product(test_data, test_data): assert_equal(py_func(a,b), jit_func(a,b)) + def test_parallel(self): + def py_func(a, b): + return np.add(a, b) + + jit_func = njit(py_func, parallel=True) + arr = np.asarray([[[1,2,3],[4,5,6]], + [[1,2,3],[4,5,6]]]) + assert_equal(py_func(arr,arr), jit_func(arr,arr)) + if __name__ == '__main__': unittest.main() diff --git a/numba/np/ufunc/tbbpool.cpp b/numba/np/ufunc/tbbpool.cpp index 74a9bd8a211..1f6f44b0bf8 100644 --- a/numba/np/ufunc/tbbpool.cpp +++ b/numba/np/ufunc/tbbpool.cpp @@ -15,7 +15,9 @@ Implement parallel vectorize workqueue on top of Intel TBB. #include #include #include +#include #include +#include #include "workqueue.h" #include "gufunc_scheduler.h" @@ -221,33 +223,104 @@ parallel_for(void *fn, char **args, size_t *dimensions, size_t *steps, void *dat }); } -using parallel_for2_fptr = void(*)(size_t, size_t, size_t, void*); -static void parallel_for2(size_t lower_bound, size_t upper_bound, size_t step, parallel_for2_fptr func, void* ctx) +struct InputRange +{ + size_t lower; + size_t upper; + size_t step; +}; + +struct Range +{ + size_t lower; + size_t upper; +}; + +struct Dim +{ + Range val; + Dim* prev; +}; + +using parallel_for2_fptr = void(*)(const Range*, size_t, void*); + +static void parallel_for2_nested(const InputRange* input_ranges, Range* ranges, size_t depth, size_t num_threads, size_t num_loops, Dim* prev_dim, parallel_for2_fptr func, void* ctx) +{ + auto input = input_ranges[depth]; + auto lower_bound = input.lower; + auto upper_bound = input.upper; + auto step = input.step; + + if(_DEBUG) + { + printf("parallel_for2_nested: lower_bound=%d, upper_bound=%d, step=%d, depth=%d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)depth); + } + + size_t count = (upper_bound - lower_bound + step - 1) / step; + size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64))); + tbb::parallel_for(tbb::blocked_range(0, count, grain), + [&](const tbb::blocked_range& r) + { + auto begin = lower_bound + r.begin() * step; + auto end = lower_bound + r.end() * step; + if(_DEBUG) + { + printf("parallel_for2_nested body: begin=%d, end=%d, depth=%d\n\n", (int)begin, (int)end, (int)depth); + } + auto next = depth + 1; + Dim dim{Range{begin, end}, prev_dim}; + if (next == num_loops) + { + auto thread_index = static_cast(tbb::this_task_arena::current_thread_index()); + auto range_ptr = &ranges[thread_index * num_loops]; + + Dim* current = &dim; + for (size_t i = 0; i < num_loops; ++i) + { + range_ptr[num_loops - i - 1] = current->val; + current = current->prev; + } + func(range_ptr, thread_index, ctx); + } + else + { + parallel_for2_nested(input_ranges, ranges, next, num_threads, num_loops, &dim, func, ctx); + } + }, tbb::auto_partitioner()); +} + +static void parallel_for2(const InputRange* input_ranges, size_t num_loops, parallel_for2_fptr func, void* ctx) { auto context = thread_context; assert(nullptr != context); auto num_threads = context->num_threads; if(_DEBUG) { - printf("parallel_for2 %d %d %d %d\n", (int)lower_bound, (int)upper_bound, (int)step, (int)num_threads); + printf("parallel_for2 num_loops=%d: ", (int)num_loops); + for (size_t i = 0; i < num_loops; ++i) + { + auto r = input_ranges[i]; + printf("(%d, %d, %d) ", (int)r.lower, (int)r.upper, (int)r.step); + } + puts("\n"); } + std::array static_ranges; + std::unique_ptr dyn_ranges; + auto* ranges = [&]()->Range* + { + auto count = num_loops * num_threads; + if (count <= static_ranges.size()) + { + return static_ranges.data(); + } + dyn_ranges.reset(new Range[count]); + return dyn_ranges.get(); + }(); + context->arena.execute([&] { - size_t count = (upper_bound - lower_bound - 1) / step + 1; - size_t grain = std::max(size_t(1), std::min(count / num_threads / 2, size_t(64))); - tbb::parallel_for(tbb::blocked_range(0, count, grain), - [&](const tbb::blocked_range& r) - { - auto thread_index = static_cast(tbb::this_task_arena::current_thread_index()); - auto begin = lower_bound + r.begin() * step; - auto end = lower_bound + r.end() * step; - if(_DEBUG) - { - printf("parallel_for2 body %d %d %d\n", (int)begin, (int)end, (int)thread_index); - } - func(begin, end, thread_index, ctx); - }, tbb::auto_partitioner()); + parallel_for2_nested(input_ranges, ranges, 0, num_threads, num_loops, nullptr, func, ctx); }); } From c3061dcee6c143b99b9be5768adcd2a16a2836d4 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 2 Mar 2021 03:01:48 +0300 Subject: [PATCH 244/259] fix (#196) --- numba/np/ufunc/tbbpool.cpp | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/numba/np/ufunc/tbbpool.cpp b/numba/np/ufunc/tbbpool.cpp index 1f6f44b0bf8..87a829873aa 100644 --- a/numba/np/ufunc/tbbpool.cpp +++ b/numba/np/ufunc/tbbpool.cpp @@ -244,7 +244,7 @@ struct Dim using parallel_for2_fptr = void(*)(const Range*, size_t, void*); -static void parallel_for2_nested(const InputRange* input_ranges, Range* ranges, size_t depth, size_t num_threads, size_t num_loops, Dim* prev_dim, parallel_for2_fptr func, void* ctx) +static void parallel_for2_nested(const InputRange* input_ranges, size_t depth, size_t num_threads, size_t num_loops, Dim* prev_dim, parallel_for2_fptr func, void* ctx) { auto input = input_ranges[depth]; auto lower_bound = input.lower; @@ -272,7 +272,17 @@ static void parallel_for2_nested(const InputRange* input_ranges, Range* ranges, if (next == num_loops) { auto thread_index = static_cast(tbb::this_task_arena::current_thread_index()); - auto range_ptr = &ranges[thread_index * num_loops]; + std::array static_ranges; + std::unique_ptr dyn_ranges; + auto* range_ptr = [&]()->Range* + { + if (num_loops <= static_ranges.size()) + { + return static_ranges.data(); + } + dyn_ranges.reset(new Range[num_loops]); + return dyn_ranges.get(); + }(); Dim* current = &dim; for (size_t i = 0; i < num_loops; ++i) @@ -284,7 +294,7 @@ static void parallel_for2_nested(const InputRange* input_ranges, Range* ranges, } else { - parallel_for2_nested(input_ranges, ranges, next, num_threads, num_loops, &dim, func, ctx); + parallel_for2_nested(input_ranges, next, num_threads, num_loops, &dim, func, ctx); } }, tbb::auto_partitioner()); } @@ -305,22 +315,9 @@ static void parallel_for2(const InputRange* input_ranges, size_t num_loops, para puts("\n"); } - std::array static_ranges; - std::unique_ptr dyn_ranges; - auto* ranges = [&]()->Range* - { - auto count = num_loops * num_threads; - if (count <= static_ranges.size()) - { - return static_ranges.data(); - } - dyn_ranges.reset(new Range[count]); - return dyn_ranges.get(); - }(); - context->arena.execute([&] { - parallel_for2_nested(input_ranges, ranges, 0, num_threads, num_loops, nullptr, func, ctx); + parallel_for2_nested(input_ranges, 0, num_threads, num_loops, nullptr, func, ctx); }); } From 899647fcf6ee5536a1989700431b8b73dc84f3f6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 2 Mar 2021 12:56:10 +0300 Subject: [PATCH 245/259] update to llvm master (#197) --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/mlir-compiler/src/lowering.cpp | 2 ++ .../src/pipelines/lower_to_llvm.cpp | 25 ++++++++++--------- .../src/pipelines/plier_to_linalg.cpp | 6 ++--- .../mlir-compiler/src/py_linalg_resolver.cpp | 4 +-- mlir-compiler/plier/src/compiler/compiler.cpp | 2 +- mlir-compiler/plier/src/dialect.cpp | 4 +-- mlir-compiler/plier/src/rewrites/cse.cpp | 2 +- .../plier/src/transforms/pipeline_utils.cpp | 4 +-- 9 files changed, 27 insertions(+), 24 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index 6ededfdfa6c..e0624e4fe90 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -ba87f99168c93461b28a4aa2d05e238ff774d57a +7d09e1d7cf27ce781e83f9d388a7a3e1e6487ead diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp index befeb7f34eb..96c0bcf92cc 100644 --- a/mlir-compiler/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -13,6 +13,7 @@ #include #include +#include #include @@ -660,6 +661,7 @@ py::bytes gen_ll_module(mlir::ModuleOp mod) std::unique_ptr ll_mod; plier::scoped_diag_handler(*mod.getContext(), diag_handler, [&]() { + mlir::registerLLVMDialectTranslation(*mod.getContext()); ll_mod = mlir::translateModuleToLLVMIR(mod, ll_ctx); if (nullptr == ll_mod) { diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 0fd8e6e65a3..6825a55e5b3 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -273,8 +273,8 @@ mlir::FuncOp get_to_memref_conversion_func( auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); auto loc = builder.getUnknownLoc(); auto new_func = plier::add_function(builder, module, func_name, func_type); - auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); - new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); + auto alwaysinline = mlir::StringAttr::get(builder.getContext(), "alwaysinline"); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(builder.getContext(), alwaysinline)); mlir::OpBuilder::InsertionGuard guard(builder); auto block = new_func.addEntryBlock(); builder.setInsertionPointToStart(block); @@ -326,8 +326,8 @@ mlir::FuncOp get_from_memref_conversion_func( auto func_type = mlir::FunctionType::get(builder.getContext(), src_type, dst_type); auto loc = builder.getUnknownLoc(); auto new_func = plier::add_function(builder, module, func_name, func_type); - auto alwaysinline = mlir::StringAttr::get("alwaysinline", builder.getContext()); - new_func->setAttr("passthrough", mlir::ArrayAttr::get(alwaysinline, builder.getContext())); + auto alwaysinline = mlir::StringAttr::get(builder.getContext(), "alwaysinline"); + new_func->setAttr("passthrough", mlir::ArrayAttr::get(builder.getContext(), alwaysinline)); mlir::OpBuilder::InsertionGuard guard(builder); auto block = new_func.addEntryBlock(); builder.setInsertionPointToStart(block); @@ -377,10 +377,10 @@ mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) auto add_pair = [&](auto name, auto val) { const mlir::Attribute attrs[] = { - mlir::StringAttr::get(name, &ctx), - mlir::StringAttr::get(val, &ctx) + mlir::StringAttr::get(&ctx, name), + mlir::StringAttr::get(&ctx, val) }; - return mlir::ArrayAttr::get(attrs, &ctx); + return mlir::ArrayAttr::get(&ctx, attrs); }; const mlir::Attribute attrs[] = { add_pair("denormal-fp-math", "preserve-sign,preserve-sign"), @@ -391,7 +391,7 @@ mlir::Attribute get_fastmath_attrs(mlir::MLIRContext& ctx) add_pair("unsafe-fp-math", "true"), add_pair(plier::attributes::getFastmathName(), "1"), }; - return mlir::ArrayAttr::get(attrs, &ctx); + return mlir::ArrayAttr::get(&ctx, attrs); } void fix_func_sig(LLVMTypeHelper& type_helper, mlir::FuncOp func) @@ -805,9 +805,10 @@ class CheckForPlierTypes : void runOnOperation() override { markAllAnalysesPreserved(); + auto plier_dialect = getContext().getOrLoadDialect(); getOperation()->walk([&](mlir::Operation* op) { - if (op->getName().getDialect() == plier::PlierDialect::getDialectNamespace()) + if (op->getName().getDialect() == plier_dialect) { op->emitOpError(": not all plier ops were translated\n"); signalPassFailure(); @@ -1170,7 +1171,7 @@ struct LowerParallelToCFGPass : mlir::OwningRewritePatternList patterns; patterns.insert(&getContext()); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -1194,7 +1195,7 @@ struct PreLLVMLowering : public mlir::PassWrapper(&getContext(), type_helper.get_type_converter()); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; @@ -1283,7 +1284,7 @@ struct LLVMLoweringPass : public mlir::PassWrappersetAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), - StringAttr::get(options.dataLayout.getStringRepresentation(), m.getContext())); + StringAttr::get(m.getContext(), options.dataLayout.getStringRepresentation())); } private: diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 00a3aaa5add..df6a55c3ad6 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -156,7 +156,7 @@ bool check_numpy_args(llvm::ArrayRef args, unsigned expected_count) void rerun_std_pipeline(mlir::Operation* op) { assert(nullptr != op); - auto marker = mlir::StringAttr::get(plier_to_std_pipeline_name(), op->getContext()); + auto marker = mlir::StringAttr::get(op->getContext(), plier_to_std_pipeline_name()); auto mod = op->getParentOfType(); assert(nullptr != mod); plier::add_pipeline_jump_marker(mod, marker); @@ -834,7 +834,7 @@ void PostFusionOptPass::runOnOperation() plier::populate_index_propagate_patterns(context, patterns); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } struct LoopInvariantCodeMotion : public mlir::OpRewritePattern @@ -894,7 +894,7 @@ void PostLinalgOptPass::runOnOperation() plier::populate_index_propagate_patterns(context, patterns); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index 3b5d742c45a..c2585238349 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -304,7 +304,7 @@ auto get_iterators(py::list iterators, mlir::MLIRContext& ctx) llvm::SmallVector ret(iterators.size()); for (auto it : llvm::enumerate(iterators)) { - ret[it.index()] = mlir::StringAttr::get(it.value().cast(), &ctx).getValue(); + ret[it.index()] = mlir::StringAttr::get(&ctx, it.value().cast()).getValue(); } return ret; } @@ -896,7 +896,7 @@ py::object reshape_impl(py::capsule context, py::handle tensor, py::int_ out_dim { affine_maps[index] = get_affine_map_attr(obj, *builder.getContext()); }); - auto affine_maps_attr = mlir::ArrayAttr::get(affine_maps, builder.getContext()); + auto affine_maps_attr = mlir::ArrayAttr::get(builder.getContext(), affine_maps); auto reshape = builder.create(loc, new_type, tensor_val, affine_maps_attr); return ctx.context.create_var(context, reshape); } diff --git a/mlir-compiler/plier/src/compiler/compiler.cpp b/mlir-compiler/plier/src/compiler/compiler.cpp index 0a0a1a3cb19..d469fa79547 100644 --- a/mlir-compiler/plier/src/compiler/compiler.cpp +++ b/mlir-compiler/plier/src/compiler/compiler.cpp @@ -136,7 +136,7 @@ struct PassManagerSchedule auto it = stages_map.find(jump.data()); assert(it != stages_map.end()); assert(nullptr != it->second); - auto name = mlir::StringAttr::get(jump, &ctx); + auto name = mlir::StringAttr::get(&ctx, jump); stage.stage->add_jump(name, it->second); } } diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp index eae67beb7fb..ca7e9111931 100644 --- a/mlir-compiler/plier/src/dialect.cpp +++ b/mlir-compiler/plier/src/dialect.cpp @@ -173,11 +173,11 @@ void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir kw_names.reserve(kwargs.size()); for (auto& a : kwargs) { - kw_names.push_back(mlir::StringAttr::get(a.first, ctx)); + kw_names.push_back(mlir::StringAttr::get(ctx, a.first)); all_args.push_back(a.second); } PyCallOp::build(builder, state, PyType::getUndefined(state.getContext()), - func, all_args, func_name, kw_start, mlir::ArrayAttr::get(kw_names, ctx)); + func, all_args, func_name, kw_start, mlir::ArrayAttr::get(ctx, kw_names)); } void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, diff --git a/mlir-compiler/plier/src/rewrites/cse.cpp b/mlir-compiler/plier/src/rewrites/cse.cpp index 9a39e6ba923..963b338b55e 100644 --- a/mlir-compiler/plier/src/rewrites/cse.cpp +++ b/mlir-compiler/plier/src/rewrites/cse.cpp @@ -42,7 +42,7 @@ mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir: bool success = false; for (auto &inst : llvm::make_early_inc_range(region.front())) { - if (inst.isKnownTerminator()) + if (inst.hasTrait()) { break; } diff --git a/mlir-compiler/plier/src/transforms/pipeline_utils.cpp b/mlir-compiler/plier/src/transforms/pipeline_utils.cpp index c390201a79c..c7127522bb9 100644 --- a/mlir-compiler/plier/src/transforms/pipeline_utils.cpp +++ b/mlir-compiler/plier/src/transforms/pipeline_utils.cpp @@ -34,7 +34,7 @@ void plier::add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr nam { name_list.insert(it, name); } - module->setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); + module->setAttr(jump_markers, mlir::ArrayAttr::get(module.getContext(), name_list)); } @@ -56,5 +56,5 @@ void plier::remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr }); assert(it != name_list.end()); name_list.erase(it); - module->setAttr(jump_markers, mlir::ArrayAttr::get(name_list, module.getContext())); + module->setAttr(jump_markers, mlir::ArrayAttr::get(module.getContext(), name_list)); } From bf50f4f9cb51c0395f54a0f536faf249c2635d65 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 3 Mar 2021 03:21:46 +0300 Subject: [PATCH 246/259] fix linux (#198) --- mlir-compiler/mlir-compiler/CMakeLists.txt | 8 +++++--- .../mlir-compiler/src/pipelines/parallel_to_tbb.cpp | 2 +- .../mlir-compiler/src/pipelines/plier_to_std.cpp | 2 +- mlir-compiler/plier/CMakeLists.txt | 3 --- mlir-compiler/plier/include/plier/rewrites/cse.hpp | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt index 411d0257d7b..136b40f02d6 100644 --- a/mlir-compiler/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/mlir-compiler/CMakeLists.txt @@ -1,6 +1,3 @@ -if(UNIX) - add_link_options("-Wl,--exclude-libs,ALL") -endif() find_package(pybind11 REQUIRED) @@ -47,6 +44,11 @@ if (MSVC) target_compile_options(${PROJECT_NAME} PRIVATE /EHsc) endif () +if(UNIX) + target_link_options(${PROJECT_NAME} PRIVATE "LINKER:--exclude-libs,ALL") +endif() + + target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) target_link_libraries(${PROJECT_NAME} PRIVATE diff --git a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp index 3f708381340..f0046b1bd9e 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/parallel_to_tbb.cpp @@ -193,7 +193,7 @@ void ParallelToTbbPass::runOnOperation() ParallelToTbb >(&getContext()); - mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_parallel_to_tbb_pipeline(mlir::OpPassManager& pm) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp index 5786b2f40fd..4bd91a4f2db 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_std.cpp @@ -1131,7 +1131,7 @@ struct FoldTupleGetitem : public mlir::OpRewritePattern { FoldTupleGetitem(mlir::TypeConverter &/*typeConverter*/, mlir::MLIRContext *context): - OpRewritePattern(context) {} + mlir::OpRewritePattern(context) {} mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter &rewriter) const override diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt index 6de69b1d317..53430d38e4d 100644 --- a/mlir-compiler/plier/CMakeLists.txt +++ b/mlir-compiler/plier/CMakeLists.txt @@ -1,6 +1,3 @@ -if(UNIX) - add_link_options("-Wl,--exclude-libs,ALL") -endif() find_package(LLVM REQUIRED CONFIG) find_package(MLIR REQUIRED CONFIG) diff --git a/mlir-compiler/plier/include/plier/rewrites/cse.hpp b/mlir-compiler/plier/include/plier/rewrites/cse.hpp index fa29039d3ad..f5d41ba64be 100644 --- a/mlir-compiler/plier/include/plier/rewrites/cse.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/cse.hpp @@ -14,7 +14,7 @@ template struct CSERewrite : public mlir::OpRewritePattern { CSERewrite(mlir::MLIRContext *context): - OpRewritePattern(context, /*benefit*/1) {} // TODO: benefit=0 + mlir::OpRewritePattern(context, /*benefit*/1) {} // TODO: benefit=0 mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter &rewriter) const override From 4c147f6da82d0219ceb2a66f7ff12107a1743561 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 3 Mar 2021 17:56:09 +0300 Subject: [PATCH 247/259] [MLIR] Some linalg fixes and proper increfs/decrefs for array arguments (#199) * some broadcoasting opt * rework broadcast * add pass * move force inline to opt pass * Proper increfs/decrefs for input arrays --- .../src/pipelines/lower_to_llvm.cpp | 66 +++++++++++++------ .../src/pipelines/plier_to_linalg.cpp | 49 ++++++++++++-- .../mlir-compiler/src/py_linalg_resolver.cpp | 32 +++++---- mlir-compiler/plier/include/plier/PlierOps.td | 10 +++ mlir-compiler/plier/include/plier/dialect.hpp | 2 + mlir-compiler/plier/src/dialect.cpp | 5 ++ 6 files changed, 127 insertions(+), 37 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index 6825a55e5b3..c8d0badad17 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -610,12 +610,6 @@ struct ApplyFastmathFlags : public mlir::OpRewritePattern }; // Copypaste from StandardToLLVM -mlir::Value createIndexAttrConstant(mlir::OpBuilder &builder, mlir::Location loc, - mlir::Type resultType, int64_t value) { - return builder.create( - loc, resultType, builder.getIntegerAttr(builder.getIndexType(), value)); -} - struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern { using ConvertToLLVMPattern::createIndexConstant; using ConvertToLLVMPattern::getIndexType; @@ -625,19 +619,6 @@ struct AllocLikeOpLowering : public mlir::ConvertToLLVMPattern { : ConvertToLLVMPattern(opName, &converter.getContext(), converter, /*benefit*/99) {} protected: - // Returns 'input' aligned up to 'alignment'. Computes - // bumped = input + alignement - 1 - // aligned = bumped - bumped % alignment -// static mlir::Value createAligned(mlir::ConversionPatternRewriter &rewriter, mlir::Location loc, -// mlir::Value input, mlir::Value alignment) { -// using namespace mlir; -// Value one = createIndexAttrConstant(rewriter, loc, alignment.getType(), 1); -// Value bump = rewriter.create(loc, alignment, one); -// Value bumped = rewriter.create(loc, input, bump); -// Value mod = rewriter.create(loc, bumped, alignment); -// return rewriter.create(loc, bumped, mod); -// } - // Creates a call to an allocation function with params and casts the // resulting void pointer to ptrType. mlir::Value createAllocCall(mlir::Location loc, mlir::StringRef name, mlir::Type ptrType, @@ -1227,6 +1208,51 @@ struct PostLLVMLowering : } }; +struct LowerRetain : public mlir::OpConversionPattern +{ + using mlir::OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(plier::RetainOp op, llvm::ArrayRef operands, + mlir::ConversionPatternRewriter &rewriter) const override { + assert(operands.size() == 1); + auto arg = operands[0]; + if (!arg.getType().isa()) + { + return mlir::failure(); + } + + auto llvmVoidPointerType = + mlir::LLVM::LLVMPointerType::get(rewriter.getIntegerType(8)); + auto incref_func = [&]() + { + auto mod = op->getParentOfType(); + assert(mod); + auto func = mod.lookupSymbol("NRT_incref"); + if (!func) + { + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(mod.getBody()); + auto llvmVoidType = mlir::LLVM::LLVMVoidType::get(rewriter.getContext()); + func = rewriter.create( + rewriter.getUnknownLoc(), "NRT_incref", + mlir::LLVM::LLVMFunctionType::get(llvmVoidType, llvmVoidPointerType)); + } + return func; + }(); + + auto loc = op.getLoc(); + auto index = rewriter.getI64ArrayAttr(0); + auto elemType = arg.getType().cast().getBody()[0]; + mlir::Value ptr = rewriter.create(loc, elemType, arg, index); + ptr = rewriter.create(loc, llvmVoidPointerType, ptr); + rewriter.create(loc, incref_func, ptr); + rewriter.replaceOp(op, arg); + + return mlir::success(); + } +}; + struct LowerCasts : public mlir::OpConversionPattern { using mlir::OpConversionPattern::OpConversionPattern; @@ -1277,7 +1303,7 @@ struct LLVMLoweringPass : public mlir::PassWrapper(typeConverter, &getContext()); + patterns.insert(typeConverter, &getContext()); patterns.insert(typeConverter); LLVMConversionTarget target(getContext()); diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index df6a55c3ad6..4bc2ab30c8f 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -759,8 +759,7 @@ void PlierToLinalgPass::runOnOperation() patterns.insert< GetitemOpLowering, GetitemOpLowering, - SetitemOpLowering, - plier::ForceInline + SetitemOpLowering >(&getContext()); // range/prange lowering need dead branch pruning to properly @@ -802,8 +801,8 @@ void LowerLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -struct PostFusionOptPass : - public mlir::PassWrapper> +struct CommonOptPass : + public mlir::PassWrapper> { virtual void getDependentDialects( mlir::DialectRegistry ®istry) const override @@ -817,7 +816,7 @@ struct PostFusionOptPass : void runOnOperation() override; }; -void PostFusionOptPass::runOnOperation() +void CommonOptPass::runOnOperation() { mlir::OwningRewritePatternList patterns; @@ -829,6 +828,7 @@ void PostFusionOptPass::runOnOperation() patterns.insert< // LoopInvariantCodeMotion, TODO + plier::ForceInline, plier::CSERewrite >(&context); @@ -859,6 +859,41 @@ struct LoopInvariantCodeMotion : public mlir::OpRewritePattern } }; +struct RetainArgsPass : + public mlir::PassWrapper +{ + virtual void getDependentDialects( + mlir::DialectRegistry ®istry) const override + { + registry.insert(); + } + + void runOnFunction() override; +}; + +void RetainArgsPass::runOnFunction() +{ + auto func = getFunction(); + if (func.isPrivate() || func.isDeclaration() || func.body().empty()) + { + return; + } + + mlir::OpBuilder builder(&getContext()); + auto loc = builder.getUnknownLoc(); + auto block = &func.body().front(); + builder.setInsertionPointToStart(block); + for (auto arg : block->getArguments()) + { + if (arg.getType().isa()) + { + auto retained = builder.create(loc, arg); + llvm::SmallPtrSet except({retained}); + arg.replaceAllUsesExcept(retained, except); + } + } +} + struct PostLinalgOptPass : public mlir::PassWrapper> { @@ -900,13 +935,14 @@ void PostLinalgOptPass::runOnOperation() void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); } void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) { pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); - pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createTensorConstantBufferizePass()); pm.addNestedPass(mlir::createSCFBufferizePass()); @@ -920,6 +956,7 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(mlir::createBufferLoopHoistingPass()); pm.addNestedPass(mlir::createPromoteBuffersToStackPass()); + pm.addNestedPass(std::make_unique()); pm.addNestedPass(mlir::createBufferDeallocationPass()); pm.addPass(std::make_unique()); diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index c2585238349..a0f11554cff 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -104,6 +104,7 @@ void container_iterate(py::handle obj, F&& func) llvm::Optional make_py_literal(mlir::Value val) { + assert(val); if (auto int_val = plier::getConstVal(val)) { return py::int_(int_val.getInt()); @@ -144,6 +145,7 @@ struct PyLinalgResolver::Context py::object create_var(py::capsule context, mlir::Value value) { + assert(value); if (auto literal = make_py_literal(value)) { return *literal; @@ -423,7 +425,7 @@ mlir::Value broadcast_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Va return builder.create(loc, cond, val2, val1); } -mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, unsigned dim, mlir::ValueRange target_shape) +mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value initial, mlir::Value src, unsigned dim, mlir::ValueRange target_shape) { auto context = builder.getContext(); auto src_type = src.getType().cast(); @@ -431,11 +433,9 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value auto shape = llvm::to_vector<8>(src_type.getShape()); shape[dim] = -1; mlir::Type target_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); - auto dim_val = builder.create(loc, src, dim); + auto dim_val = builder.create(loc, initial, dim); auto one = builder.create(loc, 1); mlir::Value cond = builder.create(loc, mlir::CmpIPredicate::eq, one, dim_val); - mlir::Value cond2 = builder.create(loc, mlir::CmpIPredicate::ne, target_shape[dim], dim_val); - cond = builder.create(loc, cond, cond2); llvm::SmallVector new_shape(num_dims); for (unsigned i = 0 ; i < num_dims; ++i) { @@ -498,11 +498,12 @@ mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Valu { target_shape = target_shape.drop_front(target_shape.size() - num_dims); } + mlir::Value current = val; for (unsigned i = 0; i < num_dims; ++i) { - val = expand_dim(builder, loc, val, i, target_shape); + current = expand_dim(builder, loc, val, current, i, target_shape); } - return val; + return current; } py::object broadcast_impl(py::capsule context, py::tuple args) @@ -632,11 +633,20 @@ py::object broadcast_impl(py::capsule context, py::tuple args) { val = builder.create(loc, tensor_type.getElementType(), val); } - auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) + val = builder.create(loc, val); + auto num_dims = static_cast(tensor_type.getRank()); + auto init = builder.create(loc, shape_vals, tensor_type.getElementType()).getResult(); + mlir::AffineMap maps[] = { + mlir::AffineMap::get(num_dims, 0, mlir::getAffineConstantExpr(0, builder.getContext())), + mlir::AffineMap::getMultiDimIdentityMap(num_dims, builder.getContext()), + }; + llvm::SmallVector iterators(num_dims, "parallel"); + auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange values) { - builder.create(loc, val); + assert(values.size() == 2); + builder.create(loc, values[0]); }; - val = builder.create(loc, tensor_type, shape_vals, body); + val = builder.create(loc, tensor_type, val, init, maps, iterators, body).getResult(0); } } ret[it.index()] = ctx.context.create_var(context, val); @@ -688,12 +698,12 @@ py::object init_tensor_impl(py::capsule context, py::handle shape, py::handle dt else { auto val = do_cast(loc, builder, ctx.context.unwrap_val(loc, builder, init_val), elem_type); + llvm::SmallVector shape(count, -1); + auto type = mlir::RankedTensorType::get(shape, elem_type); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange /*indices*/) { builder.create(loc, val); }; - llvm::SmallVector shape(count, -1); - auto type = mlir::RankedTensorType::get(shape, elem_type); init = builder.create(loc, type, shape_val, body); } if (llvm::any_of(static_shape, [](auto val){ return val >= 0;})) diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td index e8aee0d66b1..24d46b7ff12 100644 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ b/mlir-compiler/plier/include/plier/PlierOps.td @@ -214,6 +214,16 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { ]; } +def RetainOp : Plier_Op<"retain"> { + let arguments = (ins AnyMemRef:$value); + + let results = (outs Res]>:$memref); + + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$value)> + ]; +} + def ParallelOp : Plier_Op<"parallel", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, diff --git a/mlir-compiler/plier/include/plier/dialect.hpp b/mlir-compiler/plier/include/plier/dialect.hpp index a4a8696892f..f859191215a 100644 --- a/mlir-compiler/plier/include/plier/dialect.hpp +++ b/mlir-compiler/plier/include/plier/dialect.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -14,6 +15,7 @@ using Value = ::mlir::Value; using Region = ::mlir::Region; using LogicalResult = ::mlir::LogicalResult; using Operation = ::mlir::Operation; +namespace MemoryEffects = ::mlir::MemoryEffects; template using ArrayRef = ::mlir::ArrayRef; diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp index ca7e9111931..d2a60358a00 100644 --- a/mlir-compiler/plier/src/dialect.cpp +++ b/mlir-compiler/plier/src/dialect.cpp @@ -336,6 +336,11 @@ void GetattrOp::getCanonicalizationPatterns( results.insert(context); } +void RetainOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value) { + RetainOp::build(builder, state, value.getType(), value); +} + mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef ops) { for (mlir::Operation *op : ops) From bf8a729e72db2253c2dde823ca2579d9847543ed Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 11 Mar 2021 22:20:03 +0300 Subject: [PATCH 248/259] [MLIR] Fusing optimizations (#200) * PostPlierToLinalgPass pass * refactor out populate_common_opts_patterns * propagate if const values * simplify expand dims * plier inliner interface * simplify if and select * fixes * fixes * EnforceShapeOp * copy removal pass and parallel loop fusion * fixes * more SelectOp folding * Add non-recursive CSE * more efficient reshape * run few rounds ParallelOp fusing --- .../src/pipelines/plier_to_linalg.cpp | 191 ++++++++++++++---- .../mlir-compiler/src/py_linalg_resolver.cpp | 6 +- mlir-compiler/plier/CMakeLists.txt | 4 + mlir-compiler/plier/include/plier/PlierOps.td | 14 ++ .../include/plier/rewrites/common_opts.hpp | 12 ++ .../plier/include/plier/rewrites/cse.hpp | 6 +- .../include/plier/rewrites/if_rewrites.cpp | 140 +++++++++++++ .../include/plier/rewrites/if_rewrites.hpp | 51 +++++ mlir-compiler/plier/src/dialect.cpp | 111 ++++++++++ .../plier/src/rewrites/common_opts.cpp | 33 +++ mlir-compiler/plier/src/rewrites/cse.cpp | 35 +++- numba/mlir/numpy/funcs.py | 11 +- 12 files changed, 555 insertions(+), 59 deletions(-) create mode 100644 mlir-compiler/plier/include/plier/rewrites/common_opts.hpp create mode 100644 mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp create mode 100644 mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp create mode 100644 mlir-compiler/plier/src/rewrites/common_opts.cpp diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 4bc2ab30c8f..ee48a1c2721 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -27,6 +27,7 @@ #include "plier/rewrites/call_lowering.hpp" #include "plier/rewrites/canonicalize_reductions.hpp" #include "plier/rewrites/cast_lowering.hpp" +#include "plier/rewrites/common_opts.hpp" #include "plier/rewrites/cse.hpp" #include "plier/rewrites/promote_to_parallel.hpp" #include "plier/rewrites/type_conversion.hpp" @@ -722,6 +723,103 @@ struct BinopRewriter : public mlir::OpRewritePattern resolver_t resolver; }; +struct SimplifyExpandDims : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::linalg::GenericOp op, mlir::PatternRewriter &rewriter) const override + { + if (!op.hasTensorSemantics()) + { + return mlir::failure(); + } + if (op.getNumInputs() != 1 || op.getNumOutputs() != 1) + { + return mlir::failure(); + } + + auto context = op.getContext(); + auto parallel_attr = mlir::StringAttr::get(context, "parallel"); + if (llvm::any_of(op.iterator_types(), [&](auto attr) { return attr != parallel_attr; })) + { + return mlir::failure(); + } + + auto maps = op.indexing_maps(); + assert(maps.size() == 2); + auto out_map = maps[1].cast().getValue(); + if (!out_map.isIdentity()) + { + return mlir::failure(); + } + auto in_map = maps[0].cast().getValue(); + auto num_dims = op.getNumLoops(); + if (in_map.getNumResults() != num_dims) + { + return mlir::failure(); + } + + bool changed = false; + auto out_shape = op.getOutput(0).getType().cast().getShape(); + llvm::SmallVector exprs(num_dims); + for (unsigned i = 0; i < num_dims; ++i) + { + auto prev_expr = in_map.getResult(i); + bool can_convert = [&]() + { + if (out_shape[i] == 1) + { + auto const_expr = prev_expr.dyn_cast(); + if (const_expr && const_expr.getValue() == 0) + { + return true; + } + } + return false; + }(); + if (can_convert) + { + changed = true; + exprs[i] = mlir::getAffineDimExpr(i, context); + } + else + { + exprs[i] = prev_expr; + } + } + + if (changed) + { + const mlir::Attribute new_maps[] = { + mlir::AffineMapAttr::get(mlir::AffineMap::get(num_dims, 0, exprs, context)), + maps[1] + }; + auto new_maps_attr = mlir::ArrayAttr::get(context, new_maps); + rewriter.updateRootInPlace(op, [&]() + { + op.indexing_mapsAttr(new_maps_attr); + }); + } + + return mlir::success(changed); + } +}; + +struct LowerEnforceShape : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + plier::EnforceShapeOp op, mlir::PatternRewriter &rewriter) const override + { + auto type = op.getType(); + auto src = op.value(); + rewriter.replaceOpWithNewOp(op, type, src); + return mlir::success(); + } +}; + void PlierToLinalgPass::runOnOperation() { auto context = &getContext(); @@ -801,38 +899,61 @@ void LowerLinalgPass::runOnOperation() (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -struct CommonOptPass : - public mlir::PassWrapper> +struct PostPlierToLinalgPass : + public mlir::PassWrapper> { - virtual void getDependentDialects( - mlir::DialectRegistry ®istry) const override - { - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - } + void runOnOperation() override; +}; + +void PostPlierToLinalgPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + patterns.insert< + SimplifyExpandDims + >(&getContext()); + + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct TensorFusionPass : + public mlir::PassWrapper> +{ void runOnOperation() override; }; -void CommonOptPass::runOnOperation() +void TensorFusionPass::runOnOperation() { mlir::OwningRewritePatternList patterns; auto& context = getContext(); - for (auto *op : context.getRegisteredOperations()) - { - op->getCanonicalizationPatterns(patterns, &context); - } + plier::populate_common_opts_patterns(context, patterns); patterns.insert< - // LoopInvariantCodeMotion, TODO - plier::ForceInline, - plier::CSERewrite - >(&context); + SimplifyExpandDims, + LowerEnforceShape + >(&getContext()); + + mlir::populateLinalgTensorOpsFusionPatterns(&context, patterns); - plier::populate_index_propagate_patterns(context, patterns); + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} + +struct CommonOptPass : + public mlir::PassWrapper> +{ + void runOnOperation() override; +}; + +void CommonOptPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } @@ -897,15 +1018,6 @@ void RetainArgsPass::runOnFunction() struct PostLinalgOptPass : public mlir::PassWrapper> { - virtual void getDependentDialects( - mlir::DialectRegistry ®istry) const override - { - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - } - void runOnOperation() override; }; @@ -914,35 +1026,26 @@ void PostLinalgOptPass::runOnOperation() mlir::OwningRewritePatternList patterns; auto& context = getContext(); - for (auto *op : context.getRegisteredOperations()) - { - op->getCanonicalizationPatterns(patterns, &context); - } + plier::populate_common_opts_patterns(context, patterns); patterns.insert< plier::CanonicalizeReduction, -// LoopInvariantCodeMotion, TODO - plier::PromoteToParallel, - plier::CmpLoopBoundsSimplify, - plier::CSERewrite + plier::PromoteToParallel >(&context); - plier::populate_index_propagate_patterns(context, patterns); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); - pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); } void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) { - pm.addPass(mlir::createLinalgFusionOfTensorOpsPass()); - pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); pm.addPass(mlir::createTensorConstantBufferizePass()); pm.addNestedPass(mlir::createSCFBufferizePass()); @@ -958,10 +1061,14 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addNestedPass(std::make_unique()); pm.addNestedPass(mlir::createBufferDeallocationPass()); + pm.addPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); + pm.addPass(mlir::createParallelLoopFusionPass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createParallelLoopFusionPass()); // TODO: make this rewrite and add to PostLinalgOptPass + pm.addPass(std::make_unique()); } } diff --git a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp index a0f11554cff..06804c756c3 100644 --- a/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp +++ b/mlir-compiler/mlir-compiler/src/py_linalg_resolver.cpp @@ -452,8 +452,9 @@ mlir::Value expand_dim(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value { assert(dim < shape.size()); shape[dim] = 1; - mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); - auto casted = builder.create(loc, casted_type, src).getResult(); +// mlir::Type casted_type = mlir::RankedTensorType::get(shape, src_type.getElementType()); +// auto casted = builder.create(loc, casted_type, src).getResult(); + auto casted = src; // TODO auto init = builder.create(loc, new_shape, src_type.getElementType()).getResult(); llvm::SmallVector exprs(num_dims); for (unsigned i = 0; i < num_dims; ++i) @@ -503,6 +504,7 @@ mlir::Value expand_dims(mlir::OpBuilder& builder, mlir::Location loc, mlir::Valu { current = expand_dim(builder, loc, val, current, i, target_shape); } + current = builder.create(loc, current, target_shape); return current; } diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt index 53430d38e4d..efa093339f1 100644 --- a/mlir-compiler/plier/CMakeLists.txt +++ b/mlir-compiler/plier/CMakeLists.txt @@ -18,9 +18,11 @@ set(SOURCES_LIST src/rewrites/call_lowering.cpp src/rewrites/canonicalize_reductions.cpp src/rewrites/cast_lowering.cpp + src/rewrites/common_opts.cpp src/rewrites/cse.cpp src/rewrites/force_inline.cpp src/rewrites/index_type_propagation.cpp + include/plier/rewrites/if_rewrites.cpp src/rewrites/loop_rewrites.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp @@ -39,9 +41,11 @@ set(HEADERS_LIST include/plier/rewrites/call_lowering.hpp include/plier/rewrites/canonicalize_reductions.hpp include/plier/rewrites/cast_lowering.hpp + include/plier/rewrites/common_opts.hpp include/plier/rewrites/cse.hpp include/plier/rewrites/force_inline.hpp include/plier/rewrites/index_type_propagation.hpp + include/plier/rewrites/if_rewrites.hpp include/plier/rewrites/loop_rewrites.hpp include/plier/rewrites/promote_to_parallel.hpp include/plier/rewrites/type_conversion.hpp diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td index 24d46b7ff12..0c2f71efa93 100644 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ b/mlir-compiler/plier/include/plier/PlierOps.td @@ -214,6 +214,20 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { ]; } +def EnforceShapeOp : Plier_Op<"enforce_shape"> { + let arguments = (ins AnyRankedTensor:$value, + Variadic:$sizes); + + let results = (outs AnyRankedTensor:$result); + + let builders = [ + OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::ValueRange":$shape)> + ]; + + let hasFolder = 1; + let hasCanonicalizer = 1; +} + def RetainOp : Plier_Op<"retain"> { let arguments = (ins AnyMemRef:$value); diff --git a/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp b/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp new file mode 100644 index 00000000000..5b31352f4cd --- /dev/null +++ b/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp @@ -0,0 +1,12 @@ +#pragma once + +namespace mlir +{ +class OwningRewritePatternList; +class MLIRContext; +} + +namespace plier +{ +void populate_common_opts_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns); +} diff --git a/mlir-compiler/plier/include/plier/rewrites/cse.hpp b/mlir-compiler/plier/include/plier/rewrites/cse.hpp index f5d41ba64be..bcc9b6578f5 100644 --- a/mlir-compiler/plier/include/plier/rewrites/cse.hpp +++ b/mlir-compiler/plier/include/plier/rewrites/cse.hpp @@ -7,10 +7,10 @@ namespace plier { namespace detail { -mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter); +mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter, bool recusive); } -template +template struct CSERewrite : public mlir::OpRewritePattern { CSERewrite(mlir::MLIRContext *context): @@ -19,7 +19,7 @@ struct CSERewrite : public mlir::OpRewritePattern mlir::LogicalResult matchAndRewrite( Op op, mlir::PatternRewriter &rewriter) const override { - return ::plier::detail::applyCSE(op.getRegion(), rewriter); + return ::plier::detail::applyCSE(op.getRegion(), rewriter, Recursive); } }; } diff --git a/mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp b/mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp new file mode 100644 index 00000000000..86e14a416ff --- /dev/null +++ b/mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp @@ -0,0 +1,140 @@ +#include "plier/rewrites/if_rewrites.hpp" + +#include +#include + +mlir::LogicalResult plier::IfOpConstCond::matchAndRewrite(mlir::scf::IfOp op, mlir::PatternRewriter& rewriter) const +{ + auto cond = mlir::dyn_cast_or_null(op.condition().getDefiningOp()); + if (!cond) + { + return mlir::failure(); + } + auto is_const = [](mlir::Value val) + { + if (auto parent = val.getDefiningOp()) + { + return parent->hasTrait(); + } + return false; + }; + + auto replace = [&](mlir::Block& block, mlir::Value to_replace, mlir::Value new_val) + { + for (auto& use : llvm::make_early_inc_range(to_replace.getUses())) + { + auto owner = use.getOwner(); + if (block.findAncestorOpInBlock(*owner)) + { + rewriter.updateRootInPlace(owner, [&]() + { + use.set(new_val); + }); + } + } + }; + + mlir::Value const_val; + mlir::Value to_replace; + if (is_const(cond.lhs())) + { + const_val = cond.lhs(); + to_replace = cond.rhs(); + } + else if (is_const(cond.rhs())) + { + const_val = cond.rhs(); + to_replace = cond.lhs(); + } + else + { + return mlir::failure(); + } + + if (cond.predicate() == mlir::CmpIPredicate::eq) + { + replace(op.thenRegion().front(), to_replace, const_val); + } + else if (cond.predicate() == mlir::CmpIPredicate::ne) + { + replace(op.elseRegion().front(), to_replace, const_val); + } + else + { + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult plier::SimplifyEmptyIf::matchAndRewrite(mlir::scf::IfOp op, mlir::PatternRewriter& rewriter) const +{ + if (op.getNumResults() == 0 || op.elseRegion().empty()) + { + return mlir::failure(); + } + if (!llvm::hasNItems(op.thenRegion().front(), 1) || + !llvm::hasNItems(op.elseRegion().front(), 1)) + { + return mlir::failure(); + } + auto then_yield_args = mlir::cast(op.thenRegion().front().getTerminator()).getOperands(); + auto else_yield_args = mlir::cast(op.elseRegion().front().getTerminator()).getOperands(); + for (auto it : llvm::zip(then_yield_args, else_yield_args)) + { + if (std::get<0>(it) != std::get<1>(it)) + { + return mlir::failure(); + } + } + llvm::SmallVector args(then_yield_args.begin(), then_yield_args.end()); + assert(args.size() == op.getNumResults()); + rewriter.replaceOp(op, args); + return mlir::success(); +} + +mlir::LogicalResult plier::SimplifySelect::matchAndRewrite(mlir::SelectOp op, mlir::PatternRewriter& rewriter) const +{ + auto true_val = op.getTrueValue(); + auto false_val = op.getFalseValue(); + if (true_val == false_val) + { + rewriter.replaceOp(op, true_val); + return mlir::success(); + } + return mlir::failure(); +} + +mlir::LogicalResult plier::SimplifySelectEq::matchAndRewrite(mlir::SelectOp op, mlir::PatternRewriter& rewriter) const +{ + auto cond = mlir::dyn_cast_or_null(op.condition().getDefiningOp()); + if (!cond) + { + return mlir::failure(); + } + if (cond.predicate() != mlir::CmpIPredicate::eq && + cond.predicate() != mlir::CmpIPredicate::ne) + { + return mlir::failure(); + } + + auto cond_lhs = cond.lhs(); + auto cond_rhs = cond.rhs(); + + auto true_val = op.getTrueValue(); + auto false_val = op.getFalseValue(); + + if (cond.predicate() == mlir::CmpIPredicate::ne) + { + std::swap(true_val, false_val); + } + + if ((cond_lhs == true_val && cond_rhs == false_val) || + (cond_rhs == true_val && cond_lhs == false_val)) + { + rewriter.replaceOp(op, false_val); + return mlir::success(); + } + + return mlir::failure(); +} diff --git a/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp new file mode 100644 index 00000000000..80302dcd07a --- /dev/null +++ b/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include + +namespace mlir +{ +class SelectOp; +namespace scf +{ +class IfOp; +} +} + +namespace plier +{ +struct IfOpConstCond : public mlir::OpRewritePattern +{ + IfOpConstCond(mlir::MLIRContext *context): + mlir::OpRewritePattern(context, /*benefit*/1) {} + + mlir::LogicalResult matchAndRewrite( + mlir::scf::IfOp op, mlir::PatternRewriter &rewriter) const override; +}; + +// TODO: upstream +struct SimplifyEmptyIf : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::scf::IfOp op, mlir::PatternRewriter &rewriter) const override; +}; + +// TODO: upstream +struct SimplifySelect : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override; +}; + +// TODO: upstream +struct SimplifySelectEq : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override; +}; +} diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp index d2a60358a00..892301c4215 100644 --- a/mlir-compiler/plier/src/dialect.cpp +++ b/mlir-compiler/plier/src/dialect.cpp @@ -5,11 +5,32 @@ #include #include #include +#include #include #include +#include "plier/transforms/const_utils.hpp" + +namespace +{ +struct PLierInlinerInterface : public mlir::DialectInlinerInterface +{ + using mlir::DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(mlir::Region *, mlir::Region *, bool, + mlir::BlockAndValueMapping &) const final override + { + return true; + } + bool isLegalToInline(mlir::Operation *op, mlir::Region *, bool, + mlir::BlockAndValueMapping &) const final override + { + return !mlir::isa(op); + } +}; +} + namespace plier { @@ -70,6 +91,7 @@ void PlierDialect::initialize() #include "plier/PlierOps.cpp.inc" >(); addTypes(); + addInterfaces(); } mlir::Type PlierDialect::parseType(mlir::DialectAsmParser &parser) const { @@ -336,6 +358,95 @@ void GetattrOp::getCanonicalizationPatterns( results.insert(context); } +void EnforceShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, + mlir::Value value, mlir::ValueRange shape) { + EnforceShapeOp::build(builder, state, value.getType(), value, shape); +} + +mlir::OpFoldResult EnforceShapeOp::fold(llvm::ArrayRef operands) { + operands = operands.drop_front(); + auto num_dims = static_cast(operands.size()); + auto src_type = getType().cast(); + llvm::SmallVector final_shape(num_dims, -1); + if (src_type.hasRank()) + { + auto shape = src_type.getShape(); + if (shape.size() != num_dims) + { + return nullptr; + } + final_shape.assign(shape.begin(), shape.end()); + } + bool changed = false; + for (unsigned i = 0; i < num_dims; ++i) + { + if (auto attr = operands[i].dyn_cast_or_null()) + { + auto val = attr.getInt(); + if (val != -1) + { + if (final_shape[i] != -1) + { + if (final_shape[i] != val) + { + return nullptr; + } + } + else + { + changed = true; + final_shape[i] = val; + } + } + } + } + + if (changed) + { + auto final_type = mlir::RankedTensorType::get(final_shape, src_type.getElementType()); + result().setType(final_type); + return result(); + } + return nullptr; +} + +namespace +{ +struct EnforceShapeDim : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::DimOp op, mlir::PatternRewriter &rewriter) const override + { + auto enforce_op = mlir::dyn_cast_or_null(op.memrefOrTensor().getDefiningOp()); + if (!enforce_op) + { + return mlir::failure(); + } + auto const_ind = plier::getConstVal(op.index()); + if (!const_ind) + { + return mlir::failure(); + } + auto index = const_ind.getInt(); + if (index < 0 || index >= static_cast(enforce_op.sizes().size())) + { + return mlir::failure(); + } + + rewriter.replaceOp(op, enforce_op.sizes()[static_cast(index)]); + return mlir::success(); + } +}; +} + +void EnforceShapeOp::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) +{ + results.insert(context); +} + void RetainOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value value) { RetainOp::build(builder, state, value.getType(), value); diff --git a/mlir-compiler/plier/src/rewrites/common_opts.cpp b/mlir-compiler/plier/src/rewrites/common_opts.cpp new file mode 100644 index 00000000000..844ccf89dd5 --- /dev/null +++ b/mlir-compiler/plier/src/rewrites/common_opts.cpp @@ -0,0 +1,33 @@ +#include "plier/rewrites/common_opts.hpp" + +#include "plier/rewrites/force_inline.hpp" +#include "plier/rewrites/index_type_propagation.hpp" +#include "plier/rewrites/loop_rewrites.hpp" +#include "plier/rewrites/cse.hpp" +#include "plier/rewrites/if_rewrites.hpp" + +#include +#include +#include +#include + +void plier::populate_common_opts_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) +{ + for (auto *op : context.getRegisteredOperations()) + { + op->getCanonicalizationPatterns(patterns, &context); + } + + patterns.insert< + // LoopInvariantCodeMotion, TODO + plier::ForceInline, + plier::CmpLoopBoundsSimplify, + SimplifyEmptyIf, + plier::IfOpConstCond, + SimplifySelect, + SimplifySelectEq, + plier::CSERewrite + >(&context); + + plier::populate_index_propagate_patterns(context, patterns); +} diff --git a/mlir-compiler/plier/src/rewrites/cse.cpp b/mlir-compiler/plier/src/rewrites/cse.cpp index 963b338b55e..ba38abb7aad 100644 --- a/mlir-compiler/plier/src/rewrites/cse.cpp +++ b/mlir-compiler/plier/src/rewrites/cse.cpp @@ -32,6 +32,7 @@ using AllocatorTy = llvm::RecyclingAllocator< using ScopedMapTy = llvm::ScopedHashTable; +template mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir::PatternRewriter& rewriter) { if (region.empty() || std::next(region.begin()) != region.end()) @@ -52,12 +53,27 @@ mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir: } if (!inst.getRegions().empty()) { - for (auto& reg : inst.getRegions()) + if (Recursive && !inst.hasTrait()) { - ScopedMapTy::ScopeTy scope(map); - if (mlir::succeeded(simplifyRegion(map, reg, rewriter))) + for (auto& reg : inst.getRegions()) { - success = true; + ScopedMapTy::ScopeTy scope(map); + if (mlir::succeeded(simplifyRegion(map, reg, rewriter))) + { + success = true; + } + } + } + else + { + for (auto& reg : inst.getRegions()) + { + ScopedMapTy new_map; + ScopedMapTy::ScopeTy scope(new_map); + if (mlir::succeeded(simplifyRegion(new_map, reg, rewriter))) + { + success = true; + } } } continue; @@ -78,9 +94,16 @@ mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir: } } -mlir::LogicalResult plier::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter) +mlir::LogicalResult plier::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter, bool recusive) { ScopedMapTy map; ScopedMapTy::ScopeTy scope(map); - return simplifyRegion(map, region, rewriter); + if (recusive) + { + return simplifyRegion(map, region, rewriter); + } + else + { + return simplifyRegion(map, region, rewriter); + } } diff --git a/numba/mlir/numpy/funcs.py b/numba/mlir/numpy/funcs.py index e3094418d14..ce82e4b0b90 100644 --- a/numba/mlir/numpy/funcs.py +++ b/numba/mlir/numpy/funcs.py @@ -211,12 +211,11 @@ def reshape_impl(builder, arg, new_shape): flat = flatten(builder, arg, src_count) init = builder.init_tensor(new_shape, arg.dtype) - iterators = ['parallel'] - # dims1 = ','.join(['d%s' % i for i in range(count)]) - # dims2 = ','.join(['d%s' % i if i == size_index else '0' for i in range(count)]) - dims3 = ','.join(['d0' if i == size_index else '0' for i in range(count)]) - expr1 = f'(d0) -> (d0)' - expr2 = f'(d0) -> ({dims3})' + iterators = ['parallel' for _ in range(count)] + dims1 = ','.join(['d%s' % i for i in range(count)]) + dims3 = ','.join(['d%s' % i if i == size_index else '0' for i in range(count)]) + expr1 = f'({dims1}) -> (d{size_index})' + expr2 = f'({dims1}) -> ({dims1})' maps = [expr1, expr2] def body(a, b): From 9ad86d3dad420795c77c8828d73f9dca920f0915 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 13 Mar 2021 18:47:00 +0300 Subject: [PATCH 249/259] Update to mlir master (#201) --- mlir-compiler/llvm-sha.txt | 2 +- mlir-compiler/mlir-compiler/CMakeLists.txt | 2 +- mlir-compiler/mlir-compiler/src/lowering.cpp | 2 +- .../src/pipelines/lower_to_llvm.cpp | 2 +- mlir-compiler/plier/CMakeLists.txt | 1 - mlir-compiler/plier/include/plier/PlierOps.td | 36 +++++++++---------- 6 files changed, 22 insertions(+), 23 deletions(-) diff --git a/mlir-compiler/llvm-sha.txt b/mlir-compiler/llvm-sha.txt index e0624e4fe90..0dea32271e1 100644 --- a/mlir-compiler/llvm-sha.txt +++ b/mlir-compiler/llvm-sha.txt @@ -1 +1 @@ -7d09e1d7cf27ce781e83f9d388a7a3e1e6487ead +7b153b43d3a14d76975039408c4b922beb576735 diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt index 136b40f02d6..791d54fef3a 100644 --- a/mlir-compiler/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/mlir-compiler/CMakeLists.txt @@ -58,7 +58,7 @@ target_link_libraries(${PROJECT_NAME} PRIVATE LLVMTarget MLIRIR MLIRLLVMIR - MLIRTargetLLVMIR + MLIRLLVMToLLVMIRTranslation MLIRTransforms MLIRStandardOpsTransforms MLIRLinalgTransforms diff --git a/mlir-compiler/mlir-compiler/src/lowering.cpp b/mlir-compiler/mlir-compiler/src/lowering.cpp index 96c0bcf92cc..6c34ac75278 100644 --- a/mlir-compiler/mlir-compiler/src/lowering.cpp +++ b/mlir-compiler/mlir-compiler/src/lowering.cpp @@ -12,8 +12,8 @@ #include #include -#include #include +#include #include diff --git a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp index c8d0badad17..70e94013769 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/lower_to_llvm.cpp @@ -587,7 +587,7 @@ struct ApplyFastmathFlags : public mlir::OpRewritePattern }); if (changed) { - op.fastmathFlagsAttr(mlir::LLVM::FMFAttr::get(fmf, op.getContext())); + op.fastmathFlagsAttr(mlir::LLVM::FMFAttr::get(op.getContext(), fmf)); rewriter.finalizeRootUpdate(op); } else diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt index efa093339f1..0d62fb8b4fa 100644 --- a/mlir-compiler/plier/CMakeLists.txt +++ b/mlir-compiler/plier/CMakeLists.txt @@ -70,7 +70,6 @@ target_compile_definitions(${PLIER_LIB} PRIVATE ${LLVM_DEFINITIONS}) target_link_libraries(${PLIER_LIB} PRIVATE MLIRIR MLIRLLVMIR - MLIRTargetLLVMIR MLIRTransforms MLIRStandardOpsTransforms MLIRLinalgTransforms diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td index 0c2f71efa93..3c222966ef3 100644 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ b/mlir-compiler/plier/include/plier/PlierOps.td @@ -28,7 +28,7 @@ def ArgOp : Plier_Op<"arg", [NoSideEffect]> { let hasFolder = 1; let builders = [ - OpBuilderDAG<(ins "unsigned":$index, "::mlir::StringRef":$name)> + OpBuilder<(ins "unsigned":$index, "::mlir::StringRef":$name)> ]; } @@ -39,7 +39,7 @@ def ConstOp : Plier_Op<"const", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Attribute":$val)> + OpBuilder<(ins "::mlir::Attribute":$val)> ]; } @@ -50,7 +50,7 @@ def GlobalOp : Plier_Op<"global", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::StringRef":$name)> + OpBuilder<(ins "::mlir::StringRef":$name)> ]; } @@ -63,7 +63,7 @@ def BinOp : Plier_Op<"binop", []> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs, "::mlir::StringRef ":$op)> + OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs, "::mlir::StringRef ":$op)> ]; } @@ -75,7 +75,7 @@ def UnaryOp : Plier_Op<"unary", []> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::StringRef ":$op)> + OpBuilder<(ins "::mlir::Value":$value, "::mlir::StringRef ":$op)> ]; } @@ -98,7 +98,7 @@ def PyCallOp : Plier_Op<"call", []> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$func, "::mlir::StringRef":$func_name, + OpBuilder<(ins "::mlir::Value":$func, "::mlir::StringRef":$func_name, "::mlir::ValueRange":$args, "::mlir::ArrayRef>":$kwargs)> ]; @@ -111,7 +111,7 @@ def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::ValueRange":$args)> + OpBuilder<(ins "::mlir::ValueRange":$args)> ]; } @@ -124,7 +124,7 @@ def GetItemOp : Plier_Op<"getitem", []> { let hasFolder = 1; let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index)> + OpBuilder<(ins "::mlir::Value":$value, "::mlir::Value":$index)> ]; } @@ -138,7 +138,7 @@ def StaticGetItemOp : Plier_Op<"static_getitem", []> { let hasFolder = 1; let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::Value":$index_var, "unsigned":$index)> + OpBuilder<(ins "::mlir::Value":$value, "::mlir::Value":$index_var, "unsigned":$index)> ]; } @@ -158,7 +158,7 @@ def GetiterOp : Plier_Op<"getiter", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value)> + OpBuilder<(ins "::mlir::Value":$value)> ]; } @@ -169,7 +169,7 @@ def IternextOp : Plier_Op<"iternext", []> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value)> + OpBuilder<(ins "::mlir::Value":$value)> ]; } @@ -180,7 +180,7 @@ def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value)> + OpBuilder<(ins "::mlir::Value":$value)> ]; } @@ -191,7 +191,7 @@ def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { let results = (outs AnyType); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value)> + OpBuilder<(ins "::mlir::Value":$value)> ]; } @@ -210,7 +210,7 @@ def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { let hasCanonicalizer = 1; let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::StringRef":$name)> + OpBuilder<(ins "::mlir::Value":$value, "::mlir::StringRef":$name)> ]; } @@ -221,7 +221,7 @@ def EnforceShapeOp : Plier_Op<"enforce_shape"> { let results = (outs AnyRankedTensor:$result); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value, "::mlir::ValueRange":$shape)> + OpBuilder<(ins "::mlir::Value":$value, "::mlir::ValueRange":$shape)> ]; let hasFolder = 1; @@ -234,7 +234,7 @@ def RetainOp : Plier_Op<"retain"> { let results = (outs Res]>:$memref); let builders = [ - OpBuilderDAG<(ins "::mlir::Value":$value)> + OpBuilder<(ins "::mlir::Value":$value)> ]; } @@ -251,7 +251,7 @@ def ParallelOp : Plier_Op<"parallel", let skipDefaultBuilders = 1; let builders = [ - OpBuilderDAG<(ins "::mlir::ValueRange":$lowerBounds, "::mlir::ValueRange":$upperBounds, "::mlir::ValueRange":$steps, + OpBuilder<(ins "::mlir::ValueRange":$lowerBounds, "::mlir::ValueRange":$upperBounds, "::mlir::ValueRange":$steps, CArg<"::mlir::function_ref", "nullptr">)> ]; @@ -264,7 +264,7 @@ def ParallelOp : Plier_Op<"parallel", def YieldOp : Plier_Op<"yield", [NoSideEffect, ReturnLike, Terminator, ParentOneOf<["ParallelOp"]>]> { let arguments = (ins Variadic:$results); - let builders = [OpBuilderDAG<(ins), [{ /* nothing to do */ }]>]; + let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; // Override default verifier (defined in SCF_Op), no custom verification // needed. let verifier = ?; From 948b8a2f4f408ccaa3df4c3e0d42c80fa5a77f58 Mon Sep 17 00:00:00 2001 From: Alexander-Makaryev <40917969+Alexander-Makaryev@users.noreply.github.com> Date: Tue, 16 Mar 2021 11:21:42 +0300 Subject: [PATCH 250/259] add C language in cmake (#202) --- mlir-compiler/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index eb63fbf04bd..1b3ae548899 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) -project(mlir_compiler LANGUAGES CXX) +project(mlir_compiler LANGUAGES CXX C) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) From f271a8eb0ec373727a47e58af8fc820ef4686369 Mon Sep 17 00:00:00 2001 From: Alexander-Makaryev <40917969+Alexander-Makaryev@users.noreply.github.com> Date: Wed, 17 Mar 2021 21:47:27 +0300 Subject: [PATCH 251/259] add var for custom plier/dpcomp (#204) --- mlir-compiler/CMakeLists.txt | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 1b3ae548899..2674a3adb76 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -6,5 +6,18 @@ set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_POSITION_INDEPENDENT_CODE ON) -add_subdirectory(plier) add_subdirectory(mlir-compiler) + +if(DEFINED DPCOMP_DIR) + message(STATUS "PLIER from DPCOMP_DIR is used") + target_include_directories(${PROJECT_NAME} PRIVATE + ${DPCOMP_DIR}/include + ) + target_link_directories(${PROJECT_NAME} PRIVATE + ${DPCOMP_DIR} + ${DPCOMP_DIR}/Release + ) +else() + message(STATUS "PLIER from local directory is used") + add_subdirectory(plier) +endif() From 71c8661d08c147939e14b6925739b7664604de78 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Mar 2021 18:50:07 +0300 Subject: [PATCH 252/259] Use out-of-tree dpcomp sources (#206) --- mlir-compiler/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index 2674a3adb76..e71d55c5147 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -8,7 +8,10 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) add_subdirectory(mlir-compiler) -if(DEFINED DPCOMP_DIR) +if(DEFINED DPCOMP_TREE) + message(STATUS "Out of tree PLIER is used") + add_subdirectory(${DPCOMP_TREE} plier) +elseif(DEFINED DPCOMP_DIR) message(STATUS "PLIER from DPCOMP_DIR is used") target_include_directories(${PROJECT_NAME} PRIVATE ${DPCOMP_DIR}/include From 15bc6784bfba7e97cce3ede79c1c66a67bbd214f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Mar 2021 18:51:05 +0300 Subject: [PATCH 253/259] [MLIR] Optimizations (#205) * refac * refac * extend CanonicalizeReduction * PromoteParallelPass pass * promote some simple loads * Rewrite single-element single-store memref * copypaste naiveParallelLoopFusion from mlir * Extend mem rewrites --- .../src/pipelines/plier_to_linalg.cpp | 47 ++++- mlir-compiler/plier/CMakeLists.txt | 6 +- .../plier/rewrites/memory_rewrites.hpp | 28 +++ .../include/plier/transforms/block_utils.hpp | 20 ++ .../include/plier/transforms/loop_utils.hpp | 3 + .../src/rewrites/canonicalize_reductions.cpp | 126 +++++++----- .../plier/src/rewrites/common_opts.cpp | 5 +- .../plier => src}/rewrites/if_rewrites.cpp | 0 .../plier/src/rewrites/memory_rewrites.cpp | 191 ++++++++++++++++++ .../plier/src/transforms/block_utils.cpp | 64 ++++++ .../plier/src/transforms/loop_utils.cpp | 158 +++++++++++++++ 11 files changed, 594 insertions(+), 54 deletions(-) create mode 100644 mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp create mode 100644 mlir-compiler/plier/include/plier/transforms/block_utils.hpp rename mlir-compiler/plier/{include/plier => src}/rewrites/if_rewrites.cpp (100%) create mode 100644 mlir-compiler/plier/src/rewrites/memory_rewrites.cpp create mode 100644 mlir-compiler/plier/src/transforms/block_utils.cpp diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index ee48a1c2721..c3559085f71 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -1028,9 +1028,50 @@ void PostLinalgOptPass::runOnOperation() auto& context = getContext(); plier::populate_common_opts_patterns(context, patterns); + patterns.insert< + plier::CanonicalizeReduction + >(&context); + + mlir::FrozenRewritePatternList frozenPatterns(std::move(patterns)); + + while (true) + { + (void)mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); + bool rerun = false; + for (auto& op : getOperation().getRegion().front()) + { + if (auto func = mlir::dyn_cast(op)) + { + if (mlir::succeeded(plier::naivelyFuseParallelOps(func.getRegion()))) + { + rerun = true; + } + } + } + if (!rerun) + { + break; + } + } + +} + +struct PromoteParallelPass : + public mlir::PassWrapper> +{ + void runOnOperation() override; +}; + +void PromoteParallelPass::runOnOperation() +{ + mlir::OwningRewritePatternList patterns; + + auto& context = getContext(); + plier::populate_common_opts_patterns(context, patterns); + patterns.insert< plier::CanonicalizeReduction, - plier::PromoteToParallel + plier::PromoteToParallel // TODO >(&context); (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); @@ -1064,11 +1105,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); - pm.addPass(mlir::createParallelLoopFusionPass()); pm.addPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(mlir::createParallelLoopFusionPass()); // TODO: make this rewrite and add to PostLinalgOptPass - pm.addPass(std::make_unique()); + pm.addPass(std::make_unique()); } } diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt index 0d62fb8b4fa..24c448a5eef 100644 --- a/mlir-compiler/plier/CMakeLists.txt +++ b/mlir-compiler/plier/CMakeLists.txt @@ -22,10 +22,12 @@ set(SOURCES_LIST src/rewrites/cse.cpp src/rewrites/force_inline.cpp src/rewrites/index_type_propagation.cpp - include/plier/rewrites/if_rewrites.cpp + src/rewrites/if_rewrites.cpp src/rewrites/loop_rewrites.cpp + src/rewrites/memory_rewrites.cpp src/rewrites/promote_to_parallel.cpp src/rewrites/type_conversion.cpp + src/transforms/block_utils.cpp src/transforms/cast_utils.cpp src/transforms/const_utils.cpp src/transforms/func_utils.cpp @@ -47,8 +49,10 @@ set(HEADERS_LIST include/plier/rewrites/index_type_propagation.hpp include/plier/rewrites/if_rewrites.hpp include/plier/rewrites/loop_rewrites.hpp + include/plier/rewrites/memory_rewrites.hpp include/plier/rewrites/promote_to_parallel.hpp include/plier/rewrites/type_conversion.hpp + include/plier/transforms/block_utils.hpp include/plier/transforms/cast_utils.hpp include/plier/transforms/const_utils.hpp include/plier/transforms/func_utils.hpp diff --git a/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp new file mode 100644 index 00000000000..967b51efacc --- /dev/null +++ b/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +namespace mlir +{ +class FuncOp; +class StoreOp; +} + +namespace plier +{ +struct PromoteLoads : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::FuncOp op, mlir::PatternRewriter &rewriter) const override; +}; + +struct SingeWriteMemref : public mlir::OpRewritePattern +{ + using mlir::OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult matchAndRewrite( + mlir::StoreOp op, mlir::PatternRewriter &rewriter) const override; +}; +} diff --git a/mlir-compiler/plier/include/plier/transforms/block_utils.hpp b/mlir-compiler/plier/include/plier/transforms/block_utils.hpp new file mode 100644 index 00000000000..fcaf961be80 --- /dev/null +++ b/mlir-compiler/plier/include/plier/transforms/block_utils.hpp @@ -0,0 +1,20 @@ +#pragma once + +namespace mlir +{ +class Operation; +} + +namespace plier +{ + +enum class OpRelation +{ + Before, + After, + In, + Unknown +}; + +OpRelation relativeTo(mlir::Operation* op, mlir::Operation* relativeTo); +} diff --git a/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp b/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp index b5dec36ca62..2069f51dff6 100644 --- a/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp +++ b/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp @@ -10,6 +10,7 @@ class Value; class Location; class OpBuilder; class Type; +class Region; namespace scf { class ForOp; @@ -27,4 +28,6 @@ mlir::LogicalResult lower_while_to_for(plier::GetiterOp getiter, mlir::PatternRe llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, llvm::function_ref get_iter_val, llvm::function_ref results = nullptr); + +mlir::LogicalResult naivelyFuseParallelOps(mlir::Region ®ion); } diff --git a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp index 37c0f31edf9..afb9e789d4d 100644 --- a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp +++ b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp @@ -4,28 +4,46 @@ #include #include +#include "plier/transforms/block_utils.hpp" + namespace { bool checkMemrefType(mlir::Value value) { if (auto type = value.getType().dyn_cast()) { - auto shape = type.getShape(); - return shape.empty() || (1 == shape.size() && 1 == shape[0]); +// auto shape = type.getShape(); +// return shape.empty() || (1 == shape.size() && 1 == shape[0]); + return true; } return false; } -bool checkForPotentialAliases(mlir::Value value) +bool isOutsideBlock(mlir::ValueRange values, mlir::Block& block) { - auto def_op = value.getDefiningOp(); - if (nullptr == def_op) + auto blockArgs = block.getArguments(); + for (auto val : values) { - return false; + if (llvm::find(blockArgs, val) != blockArgs.end()) + { + return false; + } + auto op = val.getDefiningOp(); + if (op && block.findAncestorOpInBlock(*op)) + { + return false; + } } - if (auto effects = mlir::dyn_cast(def_op)) + return true; +} + +bool checkForPotentialAliases(mlir::Value value, mlir::Operation* parent) +{ + assert(parent->getRegions().size() == 1); + assert(llvm::hasNItems(parent->getRegions().front(), 1)); + if (auto effects = mlir::dyn_cast_or_null(value.getDefiningOp())) { - if (!effects.hasEffect()) + if (!effects.onlyHasEffect()) { return false; } @@ -34,6 +52,10 @@ bool checkForPotentialAliases(mlir::Value value) { return false; } + + mlir::LoadOp load; + mlir::StoreOp store; + auto& parentBlock = parent->getRegions().front().front(); for (auto user : value.getUsers()) { if (mlir::isa(user)) @@ -41,6 +63,42 @@ bool checkForPotentialAliases(mlir::Value value) // TODO: very conservative return false; } + auto relation = plier::relativeTo(user, parent); + if (plier::OpRelation::Unknown == relation) + { + return false; + } + if (plier::OpRelation::In == relation) + { + if (auto effects = mlir::dyn_cast_or_null(user)) + { + if (user->getBlock() != &parentBlock) + { + return false; + } + if (effects.hasEffect()) + { + if (load || !mlir::isa(user)) + { + return false; + } + load = mlir::cast(user); + } + if (effects.hasEffect()) + { + if (store || !mlir::isa(user)) + { + return false; + } + store = mlir::cast(user); + } + } + } + } + if (!load || !store || !load->isBeforeInBlock(store) || + load.indices() != store.indices() || !isOutsideBlock(load.indices(), parentBlock)) + { + return false; } return true; } @@ -59,55 +117,27 @@ bool checkSupportedOps(mlir::Value value, mlir::Operation* parent) bool checkMemref(mlir::Value value, mlir::Operation* parent) { - return checkMemrefType(value) && checkForPotentialAliases(value) && + return checkMemrefType(value) && + checkForPotentialAliases(value, parent) && checkSupportedOps(value, parent); } -mlir::Value createScalarLoad( - mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value memref) +mlir::Value createScalarLoad(mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value memref, mlir::ValueRange indices) { - auto shape = memref.getType().cast().getShape(); - if (shape.empty()) - { - return builder.create(loc, memref); - } - else if (llvm::all_of(shape, [](auto s) { return s == 1; })) - { - auto index = builder.create(loc, 0); - llvm::SmallVector indices(shape.size(), index); - return builder.create(loc, memref, indices); - } - else - { - llvm_unreachable("Invalid shape"); - } + return builder.create(loc, memref, indices); } void createScalarStore( mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value val, - mlir::Value memref) + mlir::Value memref, mlir::ValueRange indices) { - auto shape = memref.getType().cast().getShape(); - if (shape.empty()) - { - builder.create(loc, val, memref); - } - else if (llvm::all_of(shape, [](auto s) { return s == 1; })) - { - auto index = builder.create(loc, 0); - llvm::SmallVector indices(shape.size(), index); - builder.create(loc, val, memref, indices); - } - else - { - llvm_unreachable("Invalid shape"); - } + builder.create(loc, val, memref, indices); } } mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const { - llvm::SmallVector to_process; + llvm::SmallVector> to_process; for (auto& current : op.getLoopBody().front()) { if (auto load = mlir::dyn_cast(current)) @@ -115,7 +145,7 @@ mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::For auto memref = load.memref(); if (checkMemref(memref, op)) { - to_process.emplace_back(memref); + to_process.push_back({memref, load.indices()}); } } } @@ -124,9 +154,9 @@ mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::For { auto loc = op.getLoc(); auto init_args = llvm::to_vector<8>(op.initArgs()); - for (auto val : to_process) + for (auto it : to_process) { - init_args.emplace_back(createScalarLoad(rewriter, loc, val)); + init_args.emplace_back(createScalarLoad(rewriter, loc, it.first, it.second)); } auto prev_args_offset = op.initArgs().size(); auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange iter_vals) @@ -142,7 +172,7 @@ mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::For auto get_iter_index = [&](auto op)->unsigned { auto arg = op.memref(); - for (auto it : llvm::enumerate(to_process)) + for (auto it : llvm::enumerate(llvm::make_first_range(to_process))) { if (arg == it.value()) { @@ -189,7 +219,7 @@ mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::For { auto index = prev_args_offset + it.index(); auto result = results[static_cast(index)]; - createScalarStore(rewriter, loc, result, it.value()); + createScalarStore(rewriter, loc, result, it.value().first, it.value().second); } rewriter.replaceOp(op, results.take_front(prev_args_offset)); return mlir::success(); diff --git a/mlir-compiler/plier/src/rewrites/common_opts.cpp b/mlir-compiler/plier/src/rewrites/common_opts.cpp index 844ccf89dd5..84b5a168d1d 100644 --- a/mlir-compiler/plier/src/rewrites/common_opts.cpp +++ b/mlir-compiler/plier/src/rewrites/common_opts.cpp @@ -3,6 +3,7 @@ #include "plier/rewrites/force_inline.hpp" #include "plier/rewrites/index_type_propagation.hpp" #include "plier/rewrites/loop_rewrites.hpp" +#include "plier/rewrites/memory_rewrites.hpp" #include "plier/rewrites/cse.hpp" #include "plier/rewrites/if_rewrites.hpp" @@ -26,7 +27,9 @@ void plier::populate_common_opts_patterns(mlir::MLIRContext& context, mlir::Owni plier::IfOpConstCond, SimplifySelect, SimplifySelectEq, - plier::CSERewrite + plier::CSERewrite, + PromoteLoads, + SingeWriteMemref >(&context); plier::populate_index_propagate_patterns(context, patterns); diff --git a/mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp b/mlir-compiler/plier/src/rewrites/if_rewrites.cpp similarity index 100% rename from mlir-compiler/plier/include/plier/rewrites/if_rewrites.cpp rename to mlir-compiler/plier/src/rewrites/if_rewrites.cpp diff --git a/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp b/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp new file mode 100644 index 00000000000..c21bbe22e44 --- /dev/null +++ b/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp @@ -0,0 +1,191 @@ +#include "plier/rewrites/memory_rewrites.hpp" + +#include + +#include + +namespace +{ +bool isWrite(mlir::Operation& op) +{ + if (auto effects = mlir::dyn_cast(op)) + { + return effects.hasEffect(); + } + return false; +} + +bool isRead(mlir::Operation& op) +{ + if (auto effects = mlir::dyn_cast(op)) + { + return effects.hasEffect(); + } + return false; +} + +struct Result +{ + bool changed; + bool hasWrites; + bool hasReads; +}; + +Result promoteLoads(llvm::MutableArrayRef regions, mlir::PatternRewriter& rewriter) +{ + bool changed = false; + bool hasWrites = false; + bool hasReads = false; + bool storeDead = false; + for (auto& region : regions) + { + for (auto& block : region.getBlocks()) + { + mlir::StoreOp currentStore; + for (auto& op : llvm::make_early_inc_range(block)) + { + if (!op.getRegions().empty()) + { + auto res = promoteLoads(op.getRegions(), rewriter); + if (res.changed) + { + changed = true; + } + if (res.hasWrites) + { + currentStore = {}; + } + if (res.hasReads) + { + storeDead = false; + } + continue; + } + + if (auto load = mlir::dyn_cast(op)) + { + hasReads = true; + if (currentStore) + { + if (load.memref() == currentStore.memref() && + load.indices() == currentStore.indices()) + { + rewriter.replaceOp(&op, currentStore.value()); + changed = true; + } + else + { + storeDead = false; + } + } + } + else if (auto store = mlir::dyn_cast(op)) + { + if (currentStore && storeDead && + currentStore.memref() == store.memref() && + currentStore.indices() == store.indices()) + { + rewriter.eraseOp(currentStore); + } + hasWrites = true; + currentStore = store; + storeDead = true; + } + else if (isWrite(op)) + { + hasWrites = true; + currentStore = {}; + } + else if (isRead(op)) + { + hasReads = true; + storeDead = false; + } + else if(op.hasTrait()) + { + currentStore = {}; + hasWrites = true; + hasReads = true; + storeDead = false; + } + } + } + } + return Result{changed, hasWrites, hasReads}; +} + +bool checkIsSingleElementsMemref(mlir::ShapedType type) +{ + if (!type.hasRank()) + { + return false; + } + return llvm::all_of(type.getShape(), [](auto val) { return val == 1; }); +} +} + +mlir::LogicalResult plier::PromoteLoads::matchAndRewrite(mlir::FuncOp op, mlir::PatternRewriter& rewriter) const +{ + auto res = promoteLoads(op->getRegions(), rewriter); + return mlir::success(res.changed); +} + +mlir::LogicalResult plier::SingeWriteMemref::matchAndRewrite(mlir::StoreOp op, mlir::PatternRewriter& rewriter) const +{ + auto memref = op.memref(); + if (!checkIsSingleElementsMemref(memref.getType().cast())) + { + return mlir::failure(); + } + auto parent = memref.getDefiningOp(); + if (!mlir::isa_and_nonnull(parent)) + { + return mlir::failure(); + } + + mlir::StoreOp valueStore; + llvm::SmallVector loads; + for (auto user : memref.getUsers()) + { + if (auto store = mlir::dyn_cast(user)) + { + if (valueStore) + { + // More than one store + return mlir::failure(); + } + valueStore = store; + } + else if (auto load = mlir::dyn_cast(user)) + { + loads.emplace_back(load); + } + else if (mlir::isa(user)) + { + // nothing + } + else + { + // Unsupported op + return mlir::failure(); + } + } + + auto parentBlock = parent->getBlock(); + if (!valueStore || valueStore->getBlock() != parentBlock) + { + return mlir::failure(); + } + + auto val = valueStore.value(); + for (auto load : loads) + { + rewriter.replaceOp(load, val); + } + for (auto user : llvm::make_early_inc_range(parent->getUsers())) + { + rewriter.eraseOp(user); + } + rewriter.eraseOp(parent); + return mlir::success(); +} diff --git a/mlir-compiler/plier/src/transforms/block_utils.cpp b/mlir-compiler/plier/src/transforms/block_utils.cpp new file mode 100644 index 00000000000..e9b346cd59d --- /dev/null +++ b/mlir-compiler/plier/src/transforms/block_utils.cpp @@ -0,0 +1,64 @@ +#include "plier/transforms/block_utils.hpp" + +#include + +namespace +{ +auto collectParentOps(mlir::Operation* op) +{ + llvm::SmallVector ops; + while (true) + { + assert(op); + ops.emplace_back(op); + auto parent = op->getParentOp(); + if (!parent) + { + break; + } + op = parent; + } + return ops; +} +} + +plier::OpRelation plier::relativeTo(mlir::Operation* op, mlir::Operation* relativeTo) +{ + assert(op); + assert(relativeTo); + + for (auto& reg : relativeTo->getRegions()) + { + for (auto& block : reg) + { + if (block.findAncestorOpInBlock(*op)) + { + return OpRelation::In; + } + } + } + + auto ops1 = collectParentOps(op); + auto ops2 = collectParentOps(relativeTo); + + for (auto op1 : ops1) + { + assert(op1); + for (auto op2 : ops2) + { + assert(op2); + if (op1->getBlock() == op1->getBlock()) + { + if (op1->isBeforeInBlock(op1)) + { + return OpRelation::Before; + } + else + { + return OpRelation::After; + } + } + } + } + return OpRelation::Unknown; +} diff --git a/mlir-compiler/plier/src/transforms/loop_utils.cpp b/mlir-compiler/plier/src/transforms/loop_utils.cpp index a5898cf7d70..970669acc6e 100644 --- a/mlir-compiler/plier/src/transforms/loop_utils.cpp +++ b/mlir-compiler/plier/src/transforms/loop_utils.cpp @@ -237,3 +237,161 @@ mlir::LogicalResult plier::lower_while_to_for( return mlir::success(changed); } +// TODO: Copypasted from mlir +namespace +{ +using namespace mlir; + +/// Verify there are no nested ParallelOps. +static bool hasNestedParallelOp(scf::ParallelOp ploop) { + auto walkResult = + ploop.getBody()->walk([](scf::ParallelOp) { return WalkResult::interrupt(); }); + return walkResult.wasInterrupted(); +} + +/// Verify equal iteration spaces. +static bool equalIterationSpaces(scf::ParallelOp firstPloop, + scf::ParallelOp secondPloop) { + if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) + return false; + + auto matchOperands = [&](const OperandRange &lhs, + const OperandRange &rhs) -> bool { + // TODO: Extend this to support aliases and equal constants. + return std::equal(lhs.begin(), lhs.end(), rhs.begin()); + }; + return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) && + matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) && + matchOperands(firstPloop.step(), secondPloop.step()); +} + +/// Checks if the parallel loops have mixed access to the same buffers. Returns +/// `true` if the first parallel loop writes to the same indices that the second +/// loop reads. +static bool haveNoReadsAfterWriteExceptSameIndex( + scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + const BlockAndValueMapping &firstToSecondPloopIndices) { + DenseMap> bufferStores; + firstPloop.getBody()->walk([&](StoreOp store) { + bufferStores[store.getMemRef()].push_back(store.indices()); + }); + auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) { + // Stop if the memref is defined in secondPloop body. Careful alias analysis + // is needed. + auto *memrefDef = load.getMemRef().getDefiningOp(); + if (memrefDef && memrefDef->getBlock() == load->getBlock()) + return WalkResult::interrupt(); + + auto write = bufferStores.find(load.getMemRef()); + if (write == bufferStores.end()) + return WalkResult::advance(); + + // Allow only single write access per buffer. + if (write->second.size() != 1) + return WalkResult::interrupt(); + + // Check that the load indices of secondPloop coincide with store indices of + // firstPloop for the same memrefs. + auto storeIndices = write->second.front(); + auto loadIndices = load.indices(); + if (storeIndices.size() != loadIndices.size()) + return WalkResult::interrupt(); + for (size_t i = 0, e = storeIndices.size(); i < e; ++i) { + if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != + loadIndices[i]) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + +/// Analyzes dependencies in the most primitive way by checking simple read and +/// write patterns. +static LogicalResult +verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + const BlockAndValueMapping &firstToSecondPloopIndices) { + if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, + firstToSecondPloopIndices)) + return failure(); + + BlockAndValueMapping secondToFirstPloopIndices; + secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), + firstPloop.getBody()->getArguments()); + return success(haveNoReadsAfterWriteExceptSameIndex( + secondPloop, firstPloop, secondToFirstPloopIndices)); +} + +static bool +isFusionLegal(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + const BlockAndValueMapping &firstToSecondPloopIndices) { + return !hasNestedParallelOp(firstPloop) && + !hasNestedParallelOp(secondPloop) && + equalIterationSpaces(firstPloop, secondPloop) && + succeeded(verifyDependencies(firstPloop, secondPloop, + firstToSecondPloopIndices)); +} + +/// Prepends operations of firstPloop's body into secondPloop's body. +static bool fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, + OpBuilder &b) { + BlockAndValueMapping firstToSecondPloopIndices; + firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), + secondPloop.getBody()->getArguments()); + + if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) + return false; + + b.setInsertionPointToStart(secondPloop.getBody()); + for (auto &op : firstPloop.getBody()->without_terminator()) + b.clone(op, firstToSecondPloopIndices); + firstPloop.erase(); + return true; +} + +bool hasNoEffect(mlir::Operation* op) +{ + if (op->getNumRegions() != 0) + { + return false; + } + if (auto interface = dyn_cast(op)) + { + return !interface.hasEffect() && + !interface.hasEffect(); + } + return !op->hasTrait<::mlir::OpTrait::HasRecursiveSideEffects>(); +} +} + +mlir::LogicalResult plier::naivelyFuseParallelOps(Region ®ion) { + OpBuilder b(region); + // Consider every single block and attempt to fuse adjacent loops. + bool changed = false; + for (auto &block : region) { + SmallVector, 1> ploopChains{{}}; + // Not using `walk()` to traverse only top-level parallel loops and also + // make sure that there are no side-effecting ops between the parallel + // loops. + bool noSideEffects = true; + for (auto &op : block) { + if (auto ploop = dyn_cast(op)) { + if (noSideEffects) { + ploopChains.back().push_back(ploop); + } else { + ploopChains.push_back({ploop}); + noSideEffects = true; + } + continue; + } + // TODO: Handle region side effects properly. + noSideEffects &= hasNoEffect(&op); + } + for (llvm::ArrayRef ploops : ploopChains) { + for (size_t i = 0, e = ploops.size(); i + 1 < e; ++i) + if (fuseIfLegal(ploops[i], ploops[i + 1], b)) + changed = true; + } + } + return mlir::success(changed); +} From c84a4171578b5a77bbbd094851ebff7aaf9901ab Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 19 Mar 2021 19:31:08 +0300 Subject: [PATCH 254/259] remove plier (#207) --- mlir-compiler/CMakeLists.txt | 3 +- mlir-compiler/plier/CMakeLists.txt | 95 ---- .../plier/include/plier/CMakeLists.txt | 11 - mlir-compiler/plier/include/plier/PlierOps.td | 273 ---------- .../plier/include/plier/compiler/compiler.hpp | 39 -- .../plier/compiler/pipeline_registry.hpp | 42 -- mlir-compiler/plier/include/plier/dialect.hpp | 59 -- .../include/plier/rewrites/call_lowering.hpp | 30 -- .../rewrites/canonicalize_reductions.hpp | 22 - .../include/plier/rewrites/cast_lowering.hpp | 31 -- .../include/plier/rewrites/common_opts.hpp | 12 - .../plier/include/plier/rewrites/cse.hpp | 25 - .../include/plier/rewrites/force_inline.hpp | 19 - .../include/plier/rewrites/if_rewrites.hpp | 51 -- .../plier/rewrites/index_type_propagation.hpp | 12 - .../include/plier/rewrites/loop_rewrites.hpp | 22 - .../plier/rewrites/memory_rewrites.hpp | 28 - .../plier/rewrites/promote_to_parallel.hpp | 22 - .../plier/rewrites/type_conversion.hpp | 25 - .../include/plier/transforms/block_utils.hpp | 20 - .../include/plier/transforms/cast_utils.hpp | 15 - .../include/plier/transforms/const_utils.hpp | 29 - .../include/plier/transforms/func_utils.hpp | 20 - .../include/plier/transforms/loop_utils.hpp | 33 -- .../plier/transforms/pipeline_utils.hpp | 15 - mlir-compiler/plier/include/plier/utils.hpp | 27 - mlir-compiler/plier/src/compiler/compiler.cpp | 236 -------- .../plier/src/compiler/pipeline_registry.cpp | 256 --------- mlir-compiler/plier/src/dialect.cpp | 509 ------------------ .../plier/src/rewrites/call_lowering.cpp | 39 -- .../src/rewrites/canonicalize_reductions.cpp | 229 -------- .../plier/src/rewrites/cast_lowering.cpp | 34 -- .../plier/src/rewrites/common_opts.cpp | 36 -- mlir-compiler/plier/src/rewrites/cse.cpp | 109 ---- .../plier/src/rewrites/force_inline.cpp | 43 -- .../plier/src/rewrites/if_rewrites.cpp | 140 ----- .../src/rewrites/index_type_propagation.cpp | 161 ------ .../plier/src/rewrites/loop_rewrites.cpp | 110 ---- .../plier/src/rewrites/memory_rewrites.cpp | 191 ------- .../src/rewrites/promote_to_parallel.cpp | 143 ----- .../plier/src/rewrites/type_conversion.cpp | 185 ------- .../plier/src/transforms/block_utils.cpp | 64 --- .../plier/src/transforms/cast_utils.cpp | 19 - .../plier/src/transforms/const_utils.cpp | 39 -- .../plier/src/transforms/func_utils.cpp | 19 - .../plier/src/transforms/loop_utils.cpp | 397 -------------- .../plier/src/transforms/pipeline_utils.cpp | 60 --- mlir-compiler/plier/src/utils.cpp | 10 - 48 files changed, 1 insertion(+), 4008 deletions(-) delete mode 100644 mlir-compiler/plier/CMakeLists.txt delete mode 100644 mlir-compiler/plier/include/plier/CMakeLists.txt delete mode 100644 mlir-compiler/plier/include/plier/PlierOps.td delete mode 100644 mlir-compiler/plier/include/plier/compiler/compiler.hpp delete mode 100644 mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp delete mode 100644 mlir-compiler/plier/include/plier/dialect.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/common_opts.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/cse.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/force_inline.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp delete mode 100644 mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/block_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/cast_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/const_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/func_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/loop_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp delete mode 100644 mlir-compiler/plier/include/plier/utils.hpp delete mode 100644 mlir-compiler/plier/src/compiler/compiler.cpp delete mode 100644 mlir-compiler/plier/src/compiler/pipeline_registry.cpp delete mode 100644 mlir-compiler/plier/src/dialect.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/call_lowering.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/cast_lowering.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/common_opts.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/cse.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/force_inline.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/if_rewrites.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/index_type_propagation.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/loop_rewrites.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/memory_rewrites.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp delete mode 100644 mlir-compiler/plier/src/rewrites/type_conversion.cpp delete mode 100644 mlir-compiler/plier/src/transforms/block_utils.cpp delete mode 100644 mlir-compiler/plier/src/transforms/cast_utils.cpp delete mode 100644 mlir-compiler/plier/src/transforms/const_utils.cpp delete mode 100644 mlir-compiler/plier/src/transforms/func_utils.cpp delete mode 100644 mlir-compiler/plier/src/transforms/loop_utils.cpp delete mode 100644 mlir-compiler/plier/src/transforms/pipeline_utils.cpp delete mode 100644 mlir-compiler/plier/src/utils.cpp diff --git a/mlir-compiler/CMakeLists.txt b/mlir-compiler/CMakeLists.txt index e71d55c5147..9d3b7b20117 100644 --- a/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/CMakeLists.txt @@ -21,6 +21,5 @@ elseif(DEFINED DPCOMP_DIR) ${DPCOMP_DIR}/Release ) else() - message(STATUS "PLIER from local directory is used") - add_subdirectory(plier) + message(FATAL_ERROR "dpcomp not found") endif() diff --git a/mlir-compiler/plier/CMakeLists.txt b/mlir-compiler/plier/CMakeLists.txt deleted file mode 100644 index 24c448a5eef..00000000000 --- a/mlir-compiler/plier/CMakeLists.txt +++ /dev/null @@ -1,95 +0,0 @@ - -find_package(LLVM REQUIRED CONFIG) -find_package(MLIR REQUIRED CONFIG) - -list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") -list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") -include(TableGen) -include(AddLLVM) -include(AddMLIR) -include(HandleLLVMOptions) - -add_subdirectory(include/plier) - -set(SOURCES_LIST - src/compiler/compiler.cpp - src/compiler/pipeline_registry.cpp - src/dialect.cpp - src/rewrites/call_lowering.cpp - src/rewrites/canonicalize_reductions.cpp - src/rewrites/cast_lowering.cpp - src/rewrites/common_opts.cpp - src/rewrites/cse.cpp - src/rewrites/force_inline.cpp - src/rewrites/index_type_propagation.cpp - src/rewrites/if_rewrites.cpp - src/rewrites/loop_rewrites.cpp - src/rewrites/memory_rewrites.cpp - src/rewrites/promote_to_parallel.cpp - src/rewrites/type_conversion.cpp - src/transforms/block_utils.cpp - src/transforms/cast_utils.cpp - src/transforms/const_utils.cpp - src/transforms/func_utils.cpp - src/transforms/loop_utils.cpp - src/transforms/pipeline_utils.cpp - src/utils.cpp - ) -set(HEADERS_LIST - include/plier/compiler/compiler.hpp - include/plier/compiler/pipeline_registry.hpp - include/plier/dialect.hpp - include/plier/PlierOps.td - include/plier/rewrites/call_lowering.hpp - include/plier/rewrites/canonicalize_reductions.hpp - include/plier/rewrites/cast_lowering.hpp - include/plier/rewrites/common_opts.hpp - include/plier/rewrites/cse.hpp - include/plier/rewrites/force_inline.hpp - include/plier/rewrites/index_type_propagation.hpp - include/plier/rewrites/if_rewrites.hpp - include/plier/rewrites/loop_rewrites.hpp - include/plier/rewrites/memory_rewrites.hpp - include/plier/rewrites/promote_to_parallel.hpp - include/plier/rewrites/type_conversion.hpp - include/plier/transforms/block_utils.hpp - include/plier/transforms/cast_utils.hpp - include/plier/transforms/const_utils.hpp - include/plier/transforms/func_utils.hpp - include/plier/transforms/loop_utils.hpp - include/plier/transforms/pipeline_utils.hpp - include/plier/utils.hpp - ) - -set(PLIER_LIB "plier") - -add_library(${PLIER_LIB} STATIC ${SOURCES_LIST} ${HEADERS_LIST}) - -if (MSVC) - target_compile_options(${PLIER_LIB} PRIVATE /EHsc) -endif () - -target_compile_definitions(${PLIER_LIB} PRIVATE ${LLVM_DEFINITIONS}) - -target_link_libraries(${PLIER_LIB} PRIVATE - MLIRIR - MLIRLLVMIR - MLIRTransforms - MLIRStandardOpsTransforms - MLIRLinalgTransforms - MLIRSCFToStandard - MLIRTensorTransforms - ) - -target_include_directories(${PLIER_LIB} PRIVATE - ./src - ${LLVM_INCLUDE_DIRS} - ${MLIR_INCLUDE_DIRS} - ) - -target_include_directories(${PLIER_LIB} PUBLIC - ./include - ${PROJECT_BINARY_DIR}/plier/include - ) - -add_dependencies(${PLIER_LIB} MLIRPlierOpsIncGen) diff --git a/mlir-compiler/plier/include/plier/CMakeLists.txt b/mlir-compiler/plier/include/plier/CMakeLists.txt deleted file mode 100644 index 2966b672b41..00000000000 --- a/mlir-compiler/plier/include/plier/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -include_directories(${MLIR_INCLUDE_DIRS}) -set(dialect PlierOps) -set(dialect_namespace plier) -set(LLVM_TARGET_DEFINITIONS ${dialect}.td) -mlir_tablegen(${dialect}Enums.h.inc -gen-enum-decls) -mlir_tablegen(${dialect}Enums.cpp.inc -gen-enum-defs) -mlir_tablegen(${dialect}.h.inc -gen-op-decls) -mlir_tablegen(${dialect}.cpp.inc -gen-op-defs) -mlir_tablegen(${dialect}Dialect.h.inc -gen-dialect-decls -dialect=${dialect_namespace}) -add_public_tablegen_target(MLIR${dialect}IncGen) -add_dependencies(mlir-headers MLIR${dialect}IncGen) diff --git a/mlir-compiler/plier/include/plier/PlierOps.td b/mlir-compiler/plier/include/plier/PlierOps.td deleted file mode 100644 index 3c222966ef3..00000000000 --- a/mlir-compiler/plier/include/plier/PlierOps.td +++ /dev/null @@ -1,273 +0,0 @@ -#ifndef PLIER_OPS -#define PLIER_OPS - -include "mlir/IR/OpBase.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" -include "mlir/Interfaces/LoopLikeInterface.td" -include "mlir/Interfaces/SideEffectInterfaces.td" - -def Plier_Dialect : Dialect { - let name = "plier"; - let cppNamespace = "plier"; -} - -def Plier_PyType : DialectType()">, "pytype">, - BuildableType<"$_builder.getType<::plier::PyType>()"> { -} - -class Plier_Op traits = []> : - Op; - -def ArgOp : Plier_Op<"arg", [NoSideEffect]> { - let arguments = (ins - UI32Attr:$index, - StrAttr:$name); - - let results = (outs AnyType); - let hasFolder = 1; - - let builders = [ - OpBuilder<(ins "unsigned":$index, "::mlir::StringRef":$name)> - ]; -} - -def ConstOp : Plier_Op<"const", [NoSideEffect]> { - let arguments = (ins - AnyAttr:$val); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Attribute":$val)> - ]; -} - -def GlobalOp : Plier_Op<"global", [NoSideEffect]> { - let arguments = (ins - StrAttr:$name); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::StringRef":$name)> - ]; -} - -def BinOp : Plier_Op<"binop", []> { - let arguments = (ins - AnyType:$lhs, - AnyType:$rhs, - StrAttr:$op); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$lhs, "::mlir::Value":$rhs, "::mlir::StringRef ":$op)> - ]; -} - -def UnaryOp : Plier_Op<"unary", []> { - let arguments = (ins - AnyType:$value, - StrAttr:$op); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value, "::mlir::StringRef ":$op)> - ]; -} - -def CastOp : Plier_Op<"cast", []> { - let arguments = (ins - AnyType:$value); - - let results = (outs AnyType); - let hasFolder = 1; -} - -def PyCallOp : Plier_Op<"call", []> { - let arguments = (ins - AnyType:$func, - Variadic:$args, - StrAttr:$func_name, - UI32Attr:$kw_start, - ArrayAttr:$kw_names); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$func, "::mlir::StringRef":$func_name, - "::mlir::ValueRange":$args, - "::mlir::ArrayRef>":$kwargs)> - ]; -} - -def BuildTupleOp : Plier_Op<"build_tuple", [NoSideEffect]> { - let arguments = (ins - Variadic:$args); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::ValueRange":$args)> - ]; -} - -def GetItemOp : Plier_Op<"getitem", []> { - let arguments = (ins - AnyType:$value, - AnyType:$index); - - let results = (outs AnyType); - let hasFolder = 1; - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value, "::mlir::Value":$index)> - ]; -} - -def StaticGetItemOp : Plier_Op<"static_getitem", []> { - let arguments = (ins - AnyType:$value, - AnyType:$index_var, - UI32Attr:$index); - - let results = (outs AnyType); - let hasFolder = 1; - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value, "::mlir::Value":$index_var, "unsigned":$index)> - ]; -} - -def SetItemOp : Plier_Op<"setitem", []> { - let arguments = (ins - AnyType:$target, - AnyType:$index, - AnyType:$value); - - let builders = []; -} - -def GetiterOp : Plier_Op<"getiter", [NoSideEffect]> { - let arguments = (ins - AnyType:$value); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value)> - ]; -} - -def IternextOp : Plier_Op<"iternext", []> { - let arguments = (ins - AnyType:$value); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value)> - ]; -} - -def PairfirstOp : Plier_Op<"pair_first", [NoSideEffect]> { - let arguments = (ins - AnyType:$value); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value)> - ]; -} - -def PairsecondOp : Plier_Op<"pair_second", [NoSideEffect]> { - let arguments = (ins - AnyType:$value); - - let results = (outs AnyType); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value)> - ]; -} - -def DelOp : Plier_Op<"del", []> { - let arguments = (ins - AnyType:$value); -} - -def GetattrOp : Plier_Op<"getattr", [NoSideEffect]> { - let arguments = (ins - AnyType:$value, - StrAttr:$name); - - let results = (outs AnyType); - - let hasCanonicalizer = 1; - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value, "::mlir::StringRef":$name)> - ]; -} - -def EnforceShapeOp : Plier_Op<"enforce_shape"> { - let arguments = (ins AnyRankedTensor:$value, - Variadic:$sizes); - - let results = (outs AnyRankedTensor:$result); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value, "::mlir::ValueRange":$shape)> - ]; - - let hasFolder = 1; - let hasCanonicalizer = 1; -} - -def RetainOp : Plier_Op<"retain"> { - let arguments = (ins AnyMemRef:$value); - - let results = (outs Res]>:$memref); - - let builders = [ - OpBuilder<(ins "::mlir::Value":$value)> - ]; -} - -def ParallelOp : Plier_Op<"parallel", - [AttrSizedOperandSegments, - DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"plier::YieldOp">, - RecursiveSideEffects]> { - - let arguments = (ins Variadic:$lowerBounds, - Variadic:$upperBounds, - Variadic:$steps); - let regions = (region SizedRegion<1>:$region); - - let skipDefaultBuilders = 1; - let builders = [ - OpBuilder<(ins "::mlir::ValueRange":$lowerBounds, "::mlir::ValueRange":$upperBounds, "::mlir::ValueRange":$steps, - CArg<"::mlir::function_ref", - "nullptr">)> - ]; - - let extraClassDeclaration = [{ - unsigned getNumLoops() { return steps().size(); } - }]; -} - -def YieldOp : Plier_Op<"yield", [NoSideEffect, ReturnLike, Terminator, - ParentOneOf<["ParallelOp"]>]> { - let arguments = (ins Variadic:$results); - let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - // Override default verifier (defined in SCF_Op), no custom verification - // needed. - let verifier = ?; -} - -#endif // PLIER_OPS diff --git a/mlir-compiler/plier/include/plier/compiler/compiler.hpp b/mlir-compiler/plier/include/plier/compiler/compiler.hpp deleted file mode 100644 index 7d18abbfcbf..00000000000 --- a/mlir-compiler/plier/include/plier/compiler/compiler.hpp +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -class MLIRContext; -class ModuleOp; -} - -namespace plier -{ -class PipelineRegistry; - -class CompilerContext -{ -public: - struct Settings - { - bool verify = false; - bool pass_statistics = false; - bool pass_timings = false; - bool ir_printing = false; - }; - - class CompilerContextImpl; - - CompilerContext(mlir::MLIRContext& ctx, const Settings& settings, - const PipelineRegistry& registry); - ~CompilerContext(); - - CompilerContext(CompilerContext&&) = default; - - void run(mlir::ModuleOp module); - -private: - std::unique_ptr impl; -}; -} diff --git a/mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp b/mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp deleted file mode 100644 index 1c7e20272f5..00000000000 --- a/mlir-compiler/plier/include/plier/compiler/pipeline_registry.hpp +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once - -#include -#include -#include - -#include -#include - -namespace mlir -{ -class OpPassManager; -} - -namespace plier -{ -class PipelineRegistry -{ -public: - PipelineRegistry() = default; - PipelineRegistry(const PipelineRegistry&) = delete; - - using pipeline_funt_t = void(*)(mlir::OpPassManager&); - using registry_entry_sink_t = void( - llvm::StringRef pipeline_name, - llvm::ArrayRef prev_pipelines, - llvm::ArrayRef next_pipelines, - llvm::ArrayRef jumps, - pipeline_funt_t func); - using registry_entry_t = std::function)>; - - void register_pipeline(registry_entry_t func); - - using fill_stage_sink_t = llvm::function_ref jumps, llvm::function_ref)>; - using populate_pass_manager_sink_t = llvm::function_ref; - using populate_pass_manager_t = llvm::function_ref; - void populate_pass_manager(populate_pass_manager_t result_sink) const; - -private: - std::vector pipelines; -}; -} diff --git a/mlir-compiler/plier/include/plier/dialect.hpp b/mlir-compiler/plier/include/plier/dialect.hpp deleted file mode 100644 index f859191215a..00000000000 --- a/mlir-compiler/plier/include/plier/dialect.hpp +++ /dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace plier -{ -// TODO: needed for LoopLikeInterface -using Value = ::mlir::Value; -using Region = ::mlir::Region; -using LogicalResult = ::mlir::LogicalResult; -using Operation = ::mlir::Operation; -namespace MemoryEffects = ::mlir::MemoryEffects; - -template -using ArrayRef = ::mlir::ArrayRef; -} - -#include "plier/PlierOpsEnums.h.inc" -#include "plier/PlierOpsDialect.h.inc" -#define GET_OP_CLASSES -#include "plier/PlierOps.h.inc" - -namespace plier -{ -namespace attributes -{ -llvm::StringRef getFastmathName(); -llvm::StringRef getJumpMarkersName(); -llvm::StringRef getParallelName(); -llvm::StringRef getMaxConcurrencyName(); -llvm::StringRef getForceInlineName(); -} - -namespace detail -{ -struct PyTypeStorage; -} - -class PyType : public mlir::Type::TypeBase<::plier::PyType, mlir::Type, - ::plier::detail::PyTypeStorage> -{ -public: - using Base::Base; - - static PyType get(mlir::MLIRContext *context, mlir::StringRef name); - static PyType getUndefined(mlir::MLIRContext *context); - static PyType getNone(mlir::MLIRContext *context); - - mlir::StringRef getName() const; -}; - - -} diff --git a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp deleted file mode 100644 index b87c11b0ad3..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/call_lowering.hpp +++ /dev/null @@ -1,30 +0,0 @@ -#pragma once - -#include - -#include "plier/dialect.hpp" - -#include - -namespace mlir -{ -class TypeConverter; -} - -namespace plier -{ -struct CallOpLowering : public mlir::OpRewritePattern -{ - using resolver_t = std::function, llvm::ArrayRef> , mlir::PatternRewriter&)>; - - CallOpLowering(mlir::TypeConverter &typeConverter, - mlir::MLIRContext *context, - resolver_t resolver); - - mlir::LogicalResult matchAndRewrite( - plier::PyCallOp op, mlir::PatternRewriter &rewriter) const override; - -private: - resolver_t resolver; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp b/mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp deleted file mode 100644 index 45f66b7590d..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/canonicalize_reductions.hpp +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -namespace scf -{ -class ForOp; -} -} - -namespace plier -{ -struct CanonicalizeReduction : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp b/mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp deleted file mode 100644 index a003eca8568..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/cast_lowering.hpp +++ /dev/null @@ -1,31 +0,0 @@ -#pragma once - -#include - -#include "plier/dialect.hpp" - -#include - -namespace mlir -{ -class TypeConverter; -} - -namespace plier -{ -struct CastOpLowering : public mlir::OpRewritePattern -{ - using cast_t = std::function; - - CastOpLowering(mlir::TypeConverter &typeConverter, - mlir::MLIRContext *context, - cast_t cast_func = nullptr); - - mlir::LogicalResult matchAndRewrite( - plier::CastOp op, mlir::PatternRewriter &rewriter) const override; - -private: - mlir::TypeConverter& converter; - cast_t cast_func; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp b/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp deleted file mode 100644 index 5b31352f4cd..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/common_opts.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -namespace mlir -{ -class OwningRewritePatternList; -class MLIRContext; -} - -namespace plier -{ -void populate_common_opts_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns); -} diff --git a/mlir-compiler/plier/include/plier/rewrites/cse.hpp b/mlir-compiler/plier/include/plier/rewrites/cse.hpp deleted file mode 100644 index bcc9b6578f5..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/cse.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include - -namespace plier -{ -namespace detail -{ -mlir::LogicalResult applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter, bool recusive); -} - -template -struct CSERewrite : public mlir::OpRewritePattern -{ - CSERewrite(mlir::MLIRContext *context): - mlir::OpRewritePattern(context, /*benefit*/1) {} // TODO: benefit=0 - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter &rewriter) const override - { - return ::plier::detail::applyCSE(op.getRegion(), rewriter, Recursive); - } -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/force_inline.hpp b/mlir-compiler/plier/include/plier/rewrites/force_inline.hpp deleted file mode 100644 index 5c518c71211..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/force_inline.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -class CallOp; -} - -namespace plier -{ -struct ForceInline : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::CallOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp deleted file mode 100644 index 80302dcd07a..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/if_rewrites.hpp +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -class SelectOp; -namespace scf -{ -class IfOp; -} -} - -namespace plier -{ -struct IfOpConstCond : public mlir::OpRewritePattern -{ - IfOpConstCond(mlir::MLIRContext *context): - mlir::OpRewritePattern(context, /*benefit*/1) {} - - mlir::LogicalResult matchAndRewrite( - mlir::scf::IfOp op, mlir::PatternRewriter &rewriter) const override; -}; - -// TODO: upstream -struct SimplifyEmptyIf : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::scf::IfOp op, mlir::PatternRewriter &rewriter) const override; -}; - -// TODO: upstream -struct SimplifySelect : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override; -}; - -// TODO: upstream -struct SimplifySelectEq : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::SelectOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp b/mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp deleted file mode 100644 index d5e1f5c14c3..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/index_type_propagation.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#pragma once - -namespace mlir -{ -class OwningRewritePatternList; -class MLIRContext; -} - -namespace plier -{ -void populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns); -} diff --git a/mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp deleted file mode 100644 index 9e2ecf684af..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/loop_rewrites.hpp +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -namespace scf -{ -class ForOp; -} -} - -namespace plier -{ -struct CmpLoopBoundsSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp b/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp deleted file mode 100644 index 967b51efacc..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/memory_rewrites.hpp +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -class FuncOp; -class StoreOp; -} - -namespace plier -{ -struct PromoteLoads : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::FuncOp op, mlir::PatternRewriter &rewriter) const override; -}; - -struct SingeWriteMemref : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::StoreOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp b/mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp deleted file mode 100644 index eabbd31d359..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/promote_to_parallel.hpp +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -namespace scf -{ -class ForOp; -} -} - -namespace plier -{ -struct PromoteToParallel : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::scf::ForOp op, mlir::PatternRewriter &rewriter) const override; -}; -} diff --git a/mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp b/mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp deleted file mode 100644 index 80fa189d476..00000000000 --- a/mlir-compiler/plier/include/plier/rewrites/type_conversion.hpp +++ /dev/null @@ -1,25 +0,0 @@ -#pragma once - -#include -#include - -namespace mlir -{ -class TypeConverter; -} - -namespace plier -{ -struct FuncOpSignatureConversion : public mlir::OpRewritePattern -{ - FuncOpSignatureConversion(mlir::TypeConverter& conv, - mlir::MLIRContext* ctx); - - /// Hook for derived classes to implement combined matching and rewriting. - mlir::LogicalResult - matchAndRewrite(mlir::FuncOp funcOp, mlir::PatternRewriter &rewriter) const override; - -private: - mlir::TypeConverter& converter; -}; -} diff --git a/mlir-compiler/plier/include/plier/transforms/block_utils.hpp b/mlir-compiler/plier/include/plier/transforms/block_utils.hpp deleted file mode 100644 index fcaf961be80..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/block_utils.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -namespace mlir -{ -class Operation; -} - -namespace plier -{ - -enum class OpRelation -{ - Before, - After, - In, - Unknown -}; - -OpRelation relativeTo(mlir::Operation* op, mlir::Operation* relativeTo); -} diff --git a/mlir-compiler/plier/include/plier/transforms/cast_utils.hpp b/mlir-compiler/plier/include/plier/transforms/cast_utils.hpp deleted file mode 100644 index d36a4e95719..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/cast_utils.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -namespace mlir -{ -class Value; -class Location; -class OpBuilder; -class Type; -} - -namespace plier -{ -mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type); -mlir::Value index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src); -} diff --git a/mlir-compiler/plier/include/plier/transforms/const_utils.hpp b/mlir-compiler/plier/include/plier/transforms/const_utils.hpp deleted file mode 100644 index ea4f22b0eec..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/const_utils.hpp +++ /dev/null @@ -1,29 +0,0 @@ -#pragma once - -#include -#include - -namespace mlir -{ -class Operation; -} - -namespace plier -{ -mlir::Attribute getConstVal(mlir::Operation* op); -mlir::Attribute getConstVal(mlir::Value op); - -template -T getConstVal(mlir::Operation* op) -{ - return getConstVal(op).dyn_cast_or_null(); -} - -template -T getConstVal(mlir::Value op) -{ - return getConstVal(op).dyn_cast_or_null(); -} - -mlir::Attribute getZeroVal(mlir::Type type); -} diff --git a/mlir-compiler/plier/include/plier/transforms/func_utils.hpp b/mlir-compiler/plier/include/plier/transforms/func_utils.hpp deleted file mode 100644 index 8065ddc8ede..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/func_utils.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -namespace mlir -{ -class ModuleOp; -class FuncOp; -class OpBuilder; -class FunctionType; -} - -namespace llvm -{ -class StringRef; -} - -namespace plier -{ -mlir::FuncOp add_function(mlir::OpBuilder& builder, mlir::ModuleOp module, - llvm::StringRef name, mlir::FunctionType type); -} diff --git a/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp b/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp deleted file mode 100644 index 2069f51dff6..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/loop_utils.hpp +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#include - -namespace mlir -{ -struct LogicalResult; -class PatternRewriter; -class Value; -class Location; -class OpBuilder; -class Type; -class Region; -namespace scf -{ -class ForOp; -} -} - -namespace plier -{ -class GetiterOp; -} - -namespace plier -{ -mlir::LogicalResult lower_while_to_for(plier::GetiterOp getiter, mlir::PatternRewriter& builder, - llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, - llvm::function_ref get_iter_val, - llvm::function_ref results = nullptr); - -mlir::LogicalResult naivelyFuseParallelOps(mlir::Region ®ion); -} diff --git a/mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp b/mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp deleted file mode 100644 index 0d80ebfb1eb..00000000000 --- a/mlir-compiler/plier/include/plier/transforms/pipeline_utils.hpp +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -namespace mlir -{ -class ArrayAttr; -class ModuleOp; -class StringAttr; -} - -namespace plier -{ -mlir::ArrayAttr get_pipeline_jump_markers(mlir::ModuleOp module); -void add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); -void remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name); -} diff --git a/mlir-compiler/plier/include/plier/utils.hpp b/mlir-compiler/plier/include/plier/utils.hpp deleted file mode 100644 index 83610967d66..00000000000 --- a/mlir-compiler/plier/include/plier/utils.hpp +++ /dev/null @@ -1,27 +0,0 @@ -#pragma once - -#include - -#include - -namespace llvm -{ -class Twine; -} - -namespace plier -{ -[[noreturn]] void report_error(const llvm::Twine& msg); - -template -void scoped_diag_handler(T& ctx, H&& diag_handler, F&& func) -{ - auto& diag_engine = ctx.getDiagEngine(); - auto diag_id = diag_engine.registerHandler(std::forward(diag_handler)); - auto diag_guard = llvm::make_scope_exit([&]() - { - diag_engine.eraseHandler(diag_id); - }); - func(); -} -} diff --git a/mlir-compiler/plier/src/compiler/compiler.cpp b/mlir-compiler/plier/src/compiler/compiler.cpp deleted file mode 100644 index d469fa79547..00000000000 --- a/mlir-compiler/plier/src/compiler/compiler.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include "plier/compiler/compiler.hpp" - -#include -#include -#include - -#include - -#include - -#include - -#include "plier/utils.hpp" - -#include "plier/compiler/pipeline_registry.hpp" - -#include "plier/transforms/pipeline_utils.hpp" - -namespace -{ -struct PassManagerStage -{ - template - PassManagerStage(mlir::MLIRContext& ctx, - const plier::CompilerContext::Settings& settings, - F&& init_func): - pm(&ctx) - { - pm.enableVerifier(settings.verify); - - if (settings.pass_statistics) - { - pm.enableStatistics(); - } - if (settings.pass_timings) - { - pm.enableTiming(); - } - if (settings.ir_printing) - { - ctx.enableMultithreading(false); - pm.enableIRPrinting(); - } - - init_func(pm); - } - - void add_jump(mlir::StringAttr name, PassManagerStage* stage) - { - assert(!name.getValue().empty()); - assert(nullptr != stage); - jumps.emplace_back(name, stage); - } - - std::pair get_jump(mlir::ArrayAttr names) const - { - if (names) - { - for (auto& it : jumps) - { - for (auto name : names) - { - auto str = name.cast(); - if (it.first == str) - { - return {it.second, str}; - } - } - } - } - return {nullptr, nullptr}; - } - - void set_next_stage(PassManagerStage* stage) - { - assert(nullptr == next_stage); - assert(nullptr != stage); - next_stage = stage; - } - - PassManagerStage* get_next_stage() const - { - return next_stage; - } - - mlir::LogicalResult run(mlir::ModuleOp op) - { - return pm.run(op); - } - -private: - mlir::PassManager pm; - llvm::SmallVector, 1> jumps; - PassManagerStage* next_stage = nullptr; -}; - -struct PassManagerSchedule -{ - PassManagerSchedule(mlir::MLIRContext& ctx, - const plier::CompilerContext::Settings& settings, - const plier::PipelineRegistry& registry) - { - auto func = [&](auto sink) - { - struct StageDesc - { - llvm::StringRef name; - llvm::ArrayRef jumps; - std::unique_ptr stage; - }; - - assert(nullptr == stages); - llvm::SmallVector stages_temp; - std::unordered_map stages_map; - - auto add_stage = [&](llvm::StringRef name, llvm::ArrayRef jumps, auto pm_init_func) - { - assert(!name.empty()); - auto prev_stage = (stages_map.empty() ? nullptr : stages_temp.back().stage.get()); - stages_temp.push_back({name, jumps, std::make_unique(ctx, settings, pm_init_func)}); - assert(stages_map.count(name.data()) == 0); - stages_map.insert({name.data(), stages_temp.back().stage.get()}); - if (nullptr != prev_stage) - { - prev_stage->set_next_stage(stages_temp.back().stage.get()); - } - }; - - sink(add_stage); - - for (auto& stage : stages_temp) - { - for (auto jump : stage.jumps) - { - assert(!jump.empty()); - auto it = stages_map.find(jump.data()); - assert(it != stages_map.end()); - assert(nullptr != it->second); - auto name = mlir::StringAttr::get(&ctx, jump); - stage.stage->add_jump(name, it->second); - } - } - - stages = std::make_unique[]>(stages_temp.size()); - for (auto it : llvm::enumerate(stages_temp)) - { - stages[it.index()] = std::move(it.value().stage); - } - }; - registry.populate_pass_manager(func); - } - - mlir::LogicalResult run(mlir::ModuleOp module) - { - assert(nullptr != stages); - auto current = stages[0].get(); - do - { - assert(nullptr != current); - if (mlir::failed(current->run(module))) - { - return mlir::failure(); - } - auto markers = plier::get_pipeline_jump_markers(module); - auto jump_target = current->get_jump(markers); - if (nullptr != jump_target.first) - { - plier::remove_pipeline_jump_marker(module, jump_target.second); - current = jump_target.first; - } - else - { - current = current->get_next_stage(); - } - } - while (nullptr != current); - return mlir::success(); - } - -private: - std::unique_ptr[]> stages; -}; -} - -class plier::CompilerContext::CompilerContextImpl -{ -public: - CompilerContextImpl(mlir::MLIRContext& ctx, - const CompilerContext::Settings& settings, - const plier::PipelineRegistry& registry): - schedule(ctx, settings, registry) {} - - void run(mlir::ModuleOp module) - { - std::string err; - llvm::raw_string_ostream err_stream(err); - auto diag_handler = [&](mlir::Diagnostic& diag) - { - if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) - { - err_stream << diag; - } - }; - - plier::scoped_diag_handler(*module.getContext(), diag_handler, [&]() - { - if (mlir::failed(schedule.run(module))) - { - err_stream << "\n"; - module.print(err_stream); - err_stream.flush(); - plier::report_error(llvm::Twine("MLIR pipeline failed\n") + err); - } - }); - } -private: - PassManagerSchedule schedule; -}; - -plier::CompilerContext::CompilerContext(mlir::MLIRContext& ctx, - const Settings& settings, - const PipelineRegistry& registry): - impl(std::make_unique(ctx, settings, registry)) -{ - -} - -plier::CompilerContext::~CompilerContext() -{ - -} - -void plier::CompilerContext::run(mlir::ModuleOp module) -{ - impl->run(module); -} diff --git a/mlir-compiler/plier/src/compiler/pipeline_registry.cpp b/mlir-compiler/plier/src/compiler/pipeline_registry.cpp deleted file mode 100644 index 8a16446bc5f..00000000000 --- a/mlir-compiler/plier/src/compiler/pipeline_registry.cpp +++ /dev/null @@ -1,256 +0,0 @@ -#include "plier/compiler/pipeline_registry.hpp" - -#include -#include -#include - -#include "plier/utils.hpp" - -#include -#include -#include - -void plier::PipelineRegistry::register_pipeline(PipelineRegistry::registry_entry_t func) -{ - assert(nullptr != func); - pipelines.push_back(std::move(func)); -} - -namespace -{ -template -void topo_visit(T& elem, IterF&& iter_func, VisitF&& func) -{ - if (elem.visited) - { - return; - } - elem.visited = true; - iter_func(elem, [&](T& next) - { - topo_visit(next, std::forward(iter_func), std::forward(func)); - }); - func(elem); -} -} - -void plier::PipelineRegistry::populate_pass_manager(populate_pass_manager_t result_sink) const -{ - llvm::BumpPtrAllocator allocator; - llvm::UniqueStringSaver string_set(allocator); - - using name_id = const void*; - auto get_id = [](llvm::StringRef name)->name_id - { - assert(!name.empty()); - return name.data(); - }; - std::set pipelines_ordered; // sorted set to make order consistent - - auto get_pipeline = [&](llvm::StringRef name)->llvm::StringRef - { - if (name.empty()) - { - report_error("Empty pipeline name"); - } - auto str = string_set.save(name); - pipelines_ordered.insert(str); - return str; - }; - - struct PipelineSet : protected llvm::SmallVector - { - using Base = llvm::SmallVector; - using Base::begin; - using Base::end; - using Base::value_type; - void push_back(llvm::StringRef id) - { - auto it = std::equal_range(begin(), end(), id); - if (it.first == it.second) - { - insert(it.first, id); - } - } - }; - - struct PipelineInfo - { - llvm::StringRef name; - PipelineSet prev_pipelines; - PipelineSet next_pipelines; - pipeline_funt_t func = nullptr; - PipelineInfo* next = nullptr; - llvm::ArrayRef jumps; - bool visited = false; - bool iterating = false; - bool jump_target = false; - }; - - std::unordered_map pipelines_map; - - auto sink = [&](llvm::StringRef pipeline_name, - llvm::ArrayRef prev_pipelines, - llvm::ArrayRef next_pipelines, - llvm::ArrayRef jumps, - pipeline_funt_t func) - { - assert(!pipeline_name.empty()); - assert(nullptr != func); - auto i = get_pipeline(pipeline_name); - auto it = pipelines_map.insert({get_id(i), {}}); - if (!it.second) - { - report_error("Duplicated pipeline name"); - } - auto& info = it.first->second; - info.name = i; - info.func = func; - llvm::transform(prev_pipelines, std::back_inserter(info.prev_pipelines), get_pipeline); - llvm::transform(next_pipelines, std::back_inserter(info.next_pipelines), get_pipeline); - if (!jumps.empty()) - { - auto data = allocator.Allocate(jumps.size()); - llvm::transform(jumps, data, [&](llvm::StringRef str) - { - assert(!str.empty()); - return string_set.save(str); - }); - info.jumps = { data, jumps.size() }; - } - }; - - for (auto& p : pipelines) - { - assert(nullptr != p); - p(sink); - } - - auto get_pipeline_info = [&](llvm::StringRef name)->PipelineInfo& - { - auto id = get_id(name); - auto it = pipelines_map.find(id); - if (it == pipelines_map.end()) - { - report_error(llvm::Twine("Pipeline not found") + name); - } - return it->second; - }; - - // Make all deps bidirectional - for (auto name : pipelines_ordered) - { - auto& info = get_pipeline_info(name); - for (auto prev : info.prev_pipelines) - { - auto& prev_info = get_pipeline_info(prev); - prev_info.next_pipelines.push_back(name); - } - for (auto next : info.next_pipelines) - { - auto& next_info = get_pipeline_info(next); - next_info.prev_pipelines.push_back(name); - } - } - - // toposort - PipelineInfo* first_pipeline = nullptr; - PipelineInfo* current_pipeline = nullptr; - for (auto name : pipelines_ordered) - { - auto iter_func = [&](PipelineInfo& elem, auto func) - { - elem.iterating = true; - for (auto it : elem.prev_pipelines) - { - auto& info = get_pipeline_info(it); - if (info.iterating) - { - report_error(llvm::Twine("Pipeline depends on itself: ") + elem.name); - } - func(info); - } - elem.iterating = false; - }; - auto visit_func = [&](PipelineInfo& elem) - { - assert(nullptr == elem.next); - auto current = &elem; - if (nullptr == first_pipeline) - { - first_pipeline = current; - } - else - { - assert(nullptr != current_pipeline); - current_pipeline->next = current; - } - current_pipeline = current; - }; - topo_visit(get_pipeline_info(name), iter_func, visit_func); - } - - assert(nullptr != first_pipeline); - - auto iterate_pipelines = [&](auto func) - { - for (auto current = first_pipeline; nullptr != current; - current = current->next) - { - func(*current); - } - }; - - iterate_pipelines([&](PipelineInfo& pipeline) - { - if (!pipeline.jumps.empty()) - { - for (auto jump : pipeline.jumps) - { - get_pipeline_info(jump).jump_target = true; - } - if (nullptr != pipeline.next) - { - pipeline.next->jump_target = true; - } - } - }); - - llvm::SmallVector funcs; - llvm::StringRef current_name = first_pipeline->name; - llvm::ArrayRef current_jumps; - result_sink([&](auto add_stage) - { - auto flush_stages = [&]() - { - if (!funcs.empty()) - { - assert(!current_name.empty()); - auto flusher = [&](mlir::OpPassManager& pm) - { - for (auto f : funcs) - { - f(pm); - } - }; - add_stage(current_name, current_jumps, flusher); - funcs.clear(); - current_name = {}; - current_jumps = {}; - } - assert(current_name.empty()); - assert(current_jumps.empty()); - }; - iterate_pipelines([&](PipelineInfo& pipeline) - { - if (pipeline.jump_target) - { - flush_stages(); - current_name = pipeline.name; - } - funcs.emplace_back(pipeline.func); - current_jumps = pipeline.jumps; - }); - flush_stages(); - }); -} diff --git a/mlir-compiler/plier/src/dialect.cpp b/mlir-compiler/plier/src/dialect.cpp deleted file mode 100644 index 892301c4215..00000000000 --- a/mlir-compiler/plier/src/dialect.cpp +++ /dev/null @@ -1,509 +0,0 @@ -#include "plier/dialect.hpp" - -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include "plier/transforms/const_utils.hpp" - -namespace -{ -struct PLierInlinerInterface : public mlir::DialectInlinerInterface -{ - using mlir::DialectInlinerInterface::DialectInlinerInterface; - bool isLegalToInline(mlir::Region *, mlir::Region *, bool, - mlir::BlockAndValueMapping &) const final override - { - return true; - } - bool isLegalToInline(mlir::Operation *op, mlir::Region *, bool, - mlir::BlockAndValueMapping &) const final override - { - return !mlir::isa(op); - } -}; -} - -namespace plier -{ - -llvm::StringRef attributes::getFastmathName() -{ - return "#plier.fastmath"; -} - -llvm::StringRef attributes::getJumpMarkersName() -{ - return "#plier.pipeline_jump_markers"; -} - -llvm::StringRef attributes::getParallelName() -{ - return "#plier.parallel"; -} - -llvm::StringRef attributes::getMaxConcurrencyName() -{ - return "#plier.max_concurrency"; -} - -llvm::StringRef attributes::getForceInlineName() -{ - return "#plier.force_inline"; -} - - -namespace detail -{ -struct PyTypeStorage : public mlir::TypeStorage -{ - using KeyTy = mlir::StringRef; - - PyTypeStorage(mlir::StringRef name): name(name) {} - - bool operator==(const KeyTy& key) const - { - return key == name; - } - - static PyTypeStorage* construct(mlir::TypeStorageAllocator& allocator, - const KeyTy& key) - { - return new(allocator.allocate()) - PyTypeStorage(allocator.copyInto(key)); - } - - mlir::StringRef name; -}; -} - -void PlierDialect::initialize() -{ - addOperations< -#define GET_OP_LIST -#include "plier/PlierOps.cpp.inc" - >(); - addTypes(); - addInterfaces(); -} - -mlir::Type PlierDialect::parseType(mlir::DialectAsmParser &parser) const { - parser.emitError(parser.getNameLoc(), "unknown type"); - return mlir::Type(); -} - -void PlierDialect::printType(mlir::Type type, mlir::DialectAsmPrinter &os) const { - llvm::TypeSwitch(type) - .Case([&](auto t){ os << "PyType<" << t.getName() << ">"; }) - .Default([](auto){ llvm_unreachable("unexpected type"); }); -} - -PyType PyType::get(mlir::MLIRContext* context, llvm::StringRef name) -{ - assert(!name.empty()); - return Base::get(context, name); -} - -PyType PyType::getUndefined(mlir::MLIRContext* context) -{ - return Base::get(context, ""); -} - -PyType PyType::getNone(mlir::MLIRContext* context) -{ - return Base::get(context, "none"); -} - -llvm::StringRef PyType::getName() const -{ - return getImpl()->name; -} - -void ArgOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - unsigned index, mlir::StringRef name) { - ArgOp::build(builder, state, PyType::getUndefined(state.getContext()), - index, name); -} - -mlir::OpFoldResult ArgOp::fold(llvm::ArrayRef /*operands*/) -{ - auto func = getOperation()->getParentOfType(); - if (func) - { - auto ind = index(); - if (ind < func.getNumArguments() && - func.getArgument(ind).getType() == getType()) - { - return func.getArgument(ind); - } - } - return nullptr; -} - -void ConstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - - mlir::Attribute val) { - ConstOp::build(builder, state, PyType::getUndefined(state.getContext()), - val); -} - -void GlobalOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::StringRef name) { - GlobalOp::build(builder, state, PyType::getUndefined(state.getContext()), - name); -} - -void BinOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value lhs, mlir::Value rhs, mlir::StringRef op) { - BinOp::build(builder, state, PyType::getUndefined(state.getContext()), lhs, - rhs, op); -} - -void UnaryOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value, mlir::StringRef op) { - UnaryOp::build(builder, state, PyType::getUndefined(state.getContext()), - value, op); -} - -mlir::OpFoldResult CastOp::fold(llvm::ArrayRef /*operands*/) -{ - auto op_type = getOperand().getType(); - auto ret_type = getType(); - if (op_type == ret_type && op_type != PyType::getUndefined(getContext())) - { - return getOperand(); - } - return nullptr; -} - -void PyCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::Value func, - llvm::StringRef func_name, mlir::ValueRange args, - mlir::ArrayRef> kwargs) { - auto ctx = builder.getContext(); - mlir::SmallVector all_args; - all_args.reserve(args.size() + kwargs.size()); - std::copy(args.begin(), args.end(), std::back_inserter(all_args)); - auto kw_start = static_cast(all_args.size()); - mlir::SmallVector kw_names; - kw_names.reserve(kwargs.size()); - for (auto& a : kwargs) - { - kw_names.push_back(mlir::StringAttr::get(ctx, a.first)); - all_args.push_back(a.second); - } - PyCallOp::build(builder, state, PyType::getUndefined(state.getContext()), - func, all_args, func_name, kw_start, mlir::ArrayAttr::get(ctx, kw_names)); -} - -void BuildTupleOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::ValueRange args) -{ - BuildTupleOp::build(builder, state, - PyType::getUndefined(state.getContext()), args); -} - -//mlir::LogicalResult BuildTupleOp::fold( -// llvm::ArrayRef /*operands*/, -// llvm::SmallVectorImpl &results) -//{ -// auto res_types = getResultTypes(); -// auto args = getOperands(); -// if (res_types.size() == args.size()) -// { -// std::copy(args.begin(), args.end(), std::back_inserter(results)); -// return mlir::success(); -// } -// return mlir::failure(); -//} - -mlir::Value fold_build_tuple_getitem(mlir::Value val, mlir::Type type, llvm::ArrayRef operands) -{ - auto build_tuple = val.getDefiningOp(); - if (build_tuple) - { - if (auto val = operands[1].dyn_cast_or_null()) - { - auto index = val.getInt(); - if (index >= 0 && index < build_tuple.getNumOperands()) - { - auto op = build_tuple.getOperand(static_cast(index)); - if (op.getType() == type) - { - return op; - } - } - } - } - return {}; -} - -void GetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value, ::mlir::Value index) -{ - GetItemOp::build(builder, state, - PyType::getUndefined(state.getContext()), value, index); -} - -mlir::OpFoldResult GetItemOp::fold(llvm::ArrayRef operands) -{ - if (auto val = fold_build_tuple_getitem(value(), getType(), operands)) - { - return val; - } - return nullptr; -} - -void StaticGetItemOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value, ::mlir::Value index_var, - unsigned int index) -{ - StaticGetItemOp::build(builder, state, - PyType::getUndefined(state.getContext()), - value, index_var, index); -} - -mlir::OpFoldResult StaticGetItemOp::fold(llvm::ArrayRef operands) -{ - if (auto val = fold_build_tuple_getitem(value(), getType(), operands)) - { - return val; - } - return nullptr; -} - -void GetiterOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value) -{ - GetiterOp::build(builder, state, PyType::getUndefined(state.getContext()), - value); -} - -void IternextOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value) -{ - IternextOp::build(builder, state, PyType::getUndefined(state.getContext()), - value); -} - -void PairfirstOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value) -{ - PairfirstOp::build(builder, state, PyType::getUndefined(state.getContext()), - value); -} - -//mlir::OpFoldResult PairfirstOp::fold(llvm::ArrayRef /*operands*/) -//{ -// if (getNumOperands() == 2) -// { -// return getOperand(0); -// } -// return nullptr; -//} - -void PairsecondOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - ::mlir::Value value) -{ - PairsecondOp::build(builder, state, - PyType::getUndefined(state.getContext()), value); -} - -//mlir::OpFoldResult PairsecondOp::fold(llvm::ArrayRef /*operands*/) -//{ -// if (getNumOperands() == 2) -// { -// return getOperand(1); -// } -// return nullptr; -//} - -void GetattrOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value, mlir::StringRef name) { - GetattrOp::build(builder, state, PyType::getUndefined(state.getContext()), - value, name); -} - -namespace -{ -struct GetattrGlobalRewrite : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - GetattrOp op, mlir::PatternRewriter &rewriter) const override - { - auto prev_op = mlir::dyn_cast_or_null(op.getOperand().getDefiningOp()); - if (prev_op) - { - auto new_name = llvm::Twine(prev_op.name() + "." + op.name()).str(); - auto new_op = rewriter.create(op.getLoc(), op.getType(), new_name); - rewriter.replaceOp(op, new_op.getResult()); - return mlir::success(); - } - return mlir::failure(); - } -}; -} - -void GetattrOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) -{ - results.insert(context); -} - -void EnforceShapeOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value, mlir::ValueRange shape) { - EnforceShapeOp::build(builder, state, value.getType(), value, shape); -} - -mlir::OpFoldResult EnforceShapeOp::fold(llvm::ArrayRef operands) { - operands = operands.drop_front(); - auto num_dims = static_cast(operands.size()); - auto src_type = getType().cast(); - llvm::SmallVector final_shape(num_dims, -1); - if (src_type.hasRank()) - { - auto shape = src_type.getShape(); - if (shape.size() != num_dims) - { - return nullptr; - } - final_shape.assign(shape.begin(), shape.end()); - } - bool changed = false; - for (unsigned i = 0; i < num_dims; ++i) - { - if (auto attr = operands[i].dyn_cast_or_null()) - { - auto val = attr.getInt(); - if (val != -1) - { - if (final_shape[i] != -1) - { - if (final_shape[i] != val) - { - return nullptr; - } - } - else - { - changed = true; - final_shape[i] = val; - } - } - } - } - - if (changed) - { - auto final_type = mlir::RankedTensorType::get(final_shape, src_type.getElementType()); - result().setType(final_type); - return result(); - } - return nullptr; -} - -namespace -{ -struct EnforceShapeDim : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::DimOp op, mlir::PatternRewriter &rewriter) const override - { - auto enforce_op = mlir::dyn_cast_or_null(op.memrefOrTensor().getDefiningOp()); - if (!enforce_op) - { - return mlir::failure(); - } - auto const_ind = plier::getConstVal(op.index()); - if (!const_ind) - { - return mlir::failure(); - } - auto index = const_ind.getInt(); - if (index < 0 || index >= static_cast(enforce_op.sizes().size())) - { - return mlir::failure(); - } - - rewriter.replaceOp(op, enforce_op.sizes()[static_cast(index)]); - return mlir::success(); - } -}; -} - -void EnforceShapeOp::getCanonicalizationPatterns( - ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) -{ - results.insert(context); -} - -void RetainOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, - mlir::Value value) { - RetainOp::build(builder, state, value.getType(), value); -} - -mlir::LogicalResult ParallelOp::moveOutOfLoop(mlir::ArrayRef ops) -{ - for (mlir::Operation *op : ops) - { - op->moveBefore(*this); - } - return mlir::success(); -} - -mlir::Region &ParallelOp::getLoopBody() { return region(); } - -bool ParallelOp::isDefinedOutsideOfLoop(mlir::Value value) -{ - return !region().isAncestor(value.getParentRegion()); -} - -void ParallelOp::build( - mlir::OpBuilder &odsBuilder, mlir::OperationState &odsState, - mlir::ValueRange lowerBounds, mlir::ValueRange upperBounds, mlir::ValueRange steps, - mlir::function_ref bodyBuilder) { - assert(lowerBounds.size() == upperBounds.size()); - assert(lowerBounds.size() == steps.size()); - odsState.addOperands(lowerBounds); - odsState.addOperands(upperBounds); - odsState.addOperands(steps); - odsState.addAttribute( - ParallelOp::getOperandSegmentSizeAttr(), - odsBuilder.getI32VectorAttr({static_cast(lowerBounds.size()), - static_cast(upperBounds.size()), - static_cast(steps.size())})); - auto bodyRegion = odsState.addRegion(); - auto count = lowerBounds.size(); - mlir::OpBuilder::InsertionGuard guard(odsBuilder); - llvm::SmallVector argTypes(count * 2 + 1, odsBuilder.getIndexType()); - auto *bodyBlock = odsBuilder.createBlock(bodyRegion, {}, argTypes); - - if (bodyBuilder) - { - odsBuilder.setInsertionPointToStart(bodyBlock); - auto args = bodyBlock->getArguments(); - bodyBuilder(odsBuilder, odsState.location, - args.take_front(count), - args.drop_front(count).take_front(count), - args.back()); - ParallelOp::ensureTerminator(*bodyRegion, odsBuilder, odsState.location); - } -} - -} - -#define GET_OP_CLASSES -#include "plier/PlierOps.cpp.inc" - -#include "plier/PlierOpsEnums.cpp.inc" diff --git a/mlir-compiler/plier/src/rewrites/call_lowering.cpp b/mlir-compiler/plier/src/rewrites/call_lowering.cpp deleted file mode 100644 index 29bd46e1627..00000000000 --- a/mlir-compiler/plier/src/rewrites/call_lowering.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "plier/rewrites/call_lowering.hpp" - -plier::CallOpLowering::CallOpLowering( - mlir::TypeConverter&, mlir::MLIRContext* context, - CallOpLowering::resolver_t resolver): - OpRewritePattern(context), resolver(resolver) {} - -mlir::LogicalResult plier::CallOpLowering::matchAndRewrite(plier::PyCallOp op, mlir::PatternRewriter& rewriter) const -{ - auto operands = op.getOperands(); - if (operands.empty()) - { - return mlir::failure(); - } - auto func_type = operands[0].getType(); - if (!func_type.isa()) - { - return mlir::failure(); - } - - llvm::SmallVector args; - llvm::SmallVector> kwargs; - auto getattr = mlir::dyn_cast_or_null(operands[0].getDefiningOp()); - if (getattr) - { - args.push_back(getattr.getOperand()); - } - auto kw_start = op.kw_start(); - operands = operands.drop_front(); - llvm::copy(operands.take_front(kw_start), std::back_inserter(args)); - for (auto it : llvm::zip(operands.drop_front(kw_start), op.kw_names())) - { - auto arg = std::get<0>(it); - auto name = std::get<1>(it).cast(); - kwargs.emplace_back(name.getValue(), arg); - } - - return resolver(op, op.func_name(), args, kwargs, rewriter); -} diff --git a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp b/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp deleted file mode 100644 index afb9e789d4d..00000000000 --- a/mlir-compiler/plier/src/rewrites/canonicalize_reductions.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include "plier/rewrites/canonicalize_reductions.hpp" - -#include -#include -#include - -#include "plier/transforms/block_utils.hpp" - -namespace -{ -bool checkMemrefType(mlir::Value value) -{ - if (auto type = value.getType().dyn_cast()) - { -// auto shape = type.getShape(); -// return shape.empty() || (1 == shape.size() && 1 == shape[0]); - return true; - } - return false; -} - -bool isOutsideBlock(mlir::ValueRange values, mlir::Block& block) -{ - auto blockArgs = block.getArguments(); - for (auto val : values) - { - if (llvm::find(blockArgs, val) != blockArgs.end()) - { - return false; - } - auto op = val.getDefiningOp(); - if (op && block.findAncestorOpInBlock(*op)) - { - return false; - } - } - return true; -} - -bool checkForPotentialAliases(mlir::Value value, mlir::Operation* parent) -{ - assert(parent->getRegions().size() == 1); - assert(llvm::hasNItems(parent->getRegions().front(), 1)); - if (auto effects = mlir::dyn_cast_or_null(value.getDefiningOp())) - { - if (!effects.onlyHasEffect()) - { - return false; - } - } - else - { - return false; - } - - mlir::LoadOp load; - mlir::StoreOp store; - auto& parentBlock = parent->getRegions().front().front(); - for (auto user : value.getUsers()) - { - if (mlir::isa(user)) - { - // TODO: very conservative - return false; - } - auto relation = plier::relativeTo(user, parent); - if (plier::OpRelation::Unknown == relation) - { - return false; - } - if (plier::OpRelation::In == relation) - { - if (auto effects = mlir::dyn_cast_or_null(user)) - { - if (user->getBlock() != &parentBlock) - { - return false; - } - if (effects.hasEffect()) - { - if (load || !mlir::isa(user)) - { - return false; - } - load = mlir::cast(user); - } - if (effects.hasEffect()) - { - if (store || !mlir::isa(user)) - { - return false; - } - store = mlir::cast(user); - } - } - } - } - if (!load || !store || !load->isBeforeInBlock(store) || - load.indices() != store.indices() || !isOutsideBlock(load.indices(), parentBlock)) - { - return false; - } - return true; -} - -bool checkSupportedOps(mlir::Value value, mlir::Operation* parent) -{ - for (auto user : value.getUsers()) - { - if (user->getParentOp() == parent && !mlir::isa(user)) - { - return false; - } - } - return true; -} - -bool checkMemref(mlir::Value value, mlir::Operation* parent) -{ - return checkMemrefType(value) && - checkForPotentialAliases(value, parent) && - checkSupportedOps(value, parent); -} - -mlir::Value createScalarLoad(mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value memref, mlir::ValueRange indices) -{ - return builder.create(loc, memref, indices); -} - -void createScalarStore( - mlir::PatternRewriter &builder, mlir::Location loc, mlir::Value val, - mlir::Value memref, mlir::ValueRange indices) -{ - builder.create(loc, val, memref, indices); -} -} - -mlir::LogicalResult plier::CanonicalizeReduction::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const -{ - llvm::SmallVector> to_process; - for (auto& current : op.getLoopBody().front()) - { - if (auto load = mlir::dyn_cast(current)) - { - auto memref = load.memref(); - if (checkMemref(memref, op)) - { - to_process.push_back({memref, load.indices()}); - } - } - } - - if (!to_process.empty()) - { - auto loc = op.getLoc(); - auto init_args = llvm::to_vector<8>(op.initArgs()); - for (auto it : to_process) - { - init_args.emplace_back(createScalarLoad(rewriter, loc, it.first, it.second)); - } - auto prev_args_offset = op.initArgs().size(); - auto body = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value iter, mlir::ValueRange iter_vals) - { - auto& old_body = op.getLoopBody().front(); - mlir::BlockAndValueMapping mapping; - mapping.map(old_body.getArguments().front(), iter); - mapping.map(old_body.getArguments().drop_front(), iter_vals); - auto yield_args = llvm::to_vector<8>(iter_vals); - for (auto& body_op : old_body.without_terminator()) - { - auto invalid_index = static_cast(-1); - auto get_iter_index = [&](auto op)->unsigned - { - auto arg = op.memref(); - for (auto it : llvm::enumerate(llvm::make_first_range(to_process))) - { - if (arg == it.value()) - { - return static_cast(it.index() + prev_args_offset); - } - } - return invalid_index; - }; - if (auto load = mlir::dyn_cast(body_op)) - { - auto index = get_iter_index(load); - if (index != invalid_index) - { - mapping.map(body_op.getResults().front(), yield_args[index]); - } - else - { - builder.clone(body_op, mapping); - } - } - else if (auto store = mlir::dyn_cast(body_op)) - { - auto index = get_iter_index(store); - if (index != invalid_index) - { - yield_args[index] = mapping.lookup(store.value()); - } - else - { - builder.clone(body_op, mapping); - } - } - else - { - builder.clone(body_op, mapping); - } - } - auto yield = mlir::cast(old_body.getTerminator()); - llvm::copy(yield.results(), yield_args.begin()); - builder.create(loc, yield_args); - }; - auto results = rewriter.create(loc, op.lowerBound(), op.upperBound(), op.step(), init_args, body).results(); - for (auto it : llvm::enumerate(to_process)) - { - auto index = prev_args_offset + it.index(); - auto result = results[static_cast(index)]; - createScalarStore(rewriter, loc, result, it.value().first, it.value().second); - } - rewriter.replaceOp(op, results.take_front(prev_args_offset)); - return mlir::success(); - } - - return mlir::failure(); -} diff --git a/mlir-compiler/plier/src/rewrites/cast_lowering.cpp b/mlir-compiler/plier/src/rewrites/cast_lowering.cpp deleted file mode 100644 index b30ad1da4a1..00000000000 --- a/mlir-compiler/plier/src/rewrites/cast_lowering.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "plier/rewrites/cast_lowering.hpp" - -#include - -plier::CastOpLowering::CastOpLowering( - mlir::TypeConverter& typeConverter, mlir::MLIRContext* context, - CastOpLowering::cast_t cast_func): - OpRewritePattern(context), converter(typeConverter), - cast_func(std::move(cast_func)) {} - -mlir::LogicalResult plier::CastOpLowering::matchAndRewrite( - plier::CastOp op, mlir::PatternRewriter& rewriter) const -{ - auto src = op.getOperand(); - auto src_type = src.getType(); - auto dst_type = converter.convertType(op.getType()); - if (dst_type) - { - if (src_type == dst_type) - { - rewriter.replaceOp(op, src); - return mlir::success(); - } - if (nullptr != cast_func) - { - if (auto new_op = cast_func(dst_type, src, rewriter)) - { - rewriter.replaceOp(op, new_op); - return mlir::success(); - } - } - } - return mlir::failure(); -} diff --git a/mlir-compiler/plier/src/rewrites/common_opts.cpp b/mlir-compiler/plier/src/rewrites/common_opts.cpp deleted file mode 100644 index 84b5a168d1d..00000000000 --- a/mlir-compiler/plier/src/rewrites/common_opts.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "plier/rewrites/common_opts.hpp" - -#include "plier/rewrites/force_inline.hpp" -#include "plier/rewrites/index_type_propagation.hpp" -#include "plier/rewrites/loop_rewrites.hpp" -#include "plier/rewrites/memory_rewrites.hpp" -#include "plier/rewrites/cse.hpp" -#include "plier/rewrites/if_rewrites.hpp" - -#include -#include -#include -#include - -void plier::populate_common_opts_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) -{ - for (auto *op : context.getRegisteredOperations()) - { - op->getCanonicalizationPatterns(patterns, &context); - } - - patterns.insert< - // LoopInvariantCodeMotion, TODO - plier::ForceInline, - plier::CmpLoopBoundsSimplify, - SimplifyEmptyIf, - plier::IfOpConstCond, - SimplifySelect, - SimplifySelectEq, - plier::CSERewrite, - PromoteLoads, - SingeWriteMemref - >(&context); - - plier::populate_index_propagate_patterns(context, patterns); -} diff --git a/mlir-compiler/plier/src/rewrites/cse.cpp b/mlir-compiler/plier/src/rewrites/cse.cpp deleted file mode 100644 index ba38abb7aad..00000000000 --- a/mlir-compiler/plier/src/rewrites/cse.cpp +++ /dev/null @@ -1,109 +0,0 @@ -#include "plier/rewrites/cse.hpp" - -#include -#include -#include - -#include -#include - -namespace -{ -struct SimpleOperationInfo : public llvm::DenseMapInfo { - static unsigned getHashValue(const mlir::Operation *opC) { - return static_cast(mlir::OperationEquivalence::computeHash(const_cast(opC))); - } - static bool isEqual(const mlir::Operation *lhsC, const mlir::Operation *rhsC) { - auto *lhs = const_cast(lhsC); - auto *rhs = const_cast(rhsC); - if (lhs == rhs) - return true; - if (lhs == getTombstoneKey() || lhs == getEmptyKey() || - rhs == getTombstoneKey() || rhs == getEmptyKey()) - return false; - return mlir::OperationEquivalence::isEquivalentTo(const_cast(lhsC), - const_cast(rhsC)); - } -}; - -using AllocatorTy = llvm::RecyclingAllocator< - llvm::BumpPtrAllocator, - llvm::ScopedHashTableVal>; -using ScopedMapTy = llvm::ScopedHashTable; - -template -mlir::LogicalResult simplifyRegion(ScopedMapTy& map, mlir::Region& region, mlir::PatternRewriter& rewriter) -{ - if (region.empty() || std::next(region.begin()) != region.end()) - { - return mlir::failure(); - } - - bool success = false; - for (auto &inst : llvm::make_early_inc_range(region.front())) - { - if (inst.hasTrait()) - { - break; - } - if (!mlir::MemoryEffectOpInterface::hasNoEffect(&inst)) - { - continue; - } - if (!inst.getRegions().empty()) - { - if (Recursive && !inst.hasTrait()) - { - for (auto& reg : inst.getRegions()) - { - ScopedMapTy::ScopeTy scope(map); - if (mlir::succeeded(simplifyRegion(map, reg, rewriter))) - { - success = true; - } - } - } - else - { - for (auto& reg : inst.getRegions()) - { - ScopedMapTy new_map; - ScopedMapTy::ScopeTy scope(new_map); - if (mlir::succeeded(simplifyRegion(new_map, reg, rewriter))) - { - success = true; - } - } - } - continue; - } - - auto* previous_op = map.lookup(&inst); - if (previous_op != nullptr) - { - rewriter.replaceOp(&inst, previous_op->getResults()); - success = true; - } - else - { - map.insert(&inst, &inst); - } - } - return mlir::success(success); -} -} - -mlir::LogicalResult plier::detail::applyCSE(mlir::Region& region, mlir::PatternRewriter& rewriter, bool recusive) -{ - ScopedMapTy map; - ScopedMapTy::ScopeTy scope(map); - if (recusive) - { - return simplifyRegion(map, region, rewriter); - } - else - { - return simplifyRegion(map, region, rewriter); - } -} diff --git a/mlir-compiler/plier/src/rewrites/force_inline.cpp b/mlir-compiler/plier/src/rewrites/force_inline.cpp deleted file mode 100644 index a6b0b5440f2..00000000000 --- a/mlir-compiler/plier/src/rewrites/force_inline.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "plier/rewrites/force_inline.hpp" - -#include -#include - -#include "plier/dialect.hpp" - -mlir::LogicalResult plier::ForceInline::matchAndRewrite(mlir::CallOp op, mlir::PatternRewriter& rewriter) const -{ - auto attr_name = plier::attributes::getForceInlineName(); - auto mod = op->getParentOfType(); - assert(mod); - auto func = mod.lookupSymbol(op.callee()); - if (!func) - { - return mlir::failure(); - } - if (!op->hasAttr(attr_name) && - !func->hasAttr(attr_name)) - { - return mlir::failure(); - } - - if (!llvm::hasNItems(func.getRegion(), 1)) - { - return mlir::failure(); - } - mlir::InlinerInterface inliner_interface(op->getContext()); - auto parent = op->getParentOp(); - rewriter.startRootUpdate(parent); - auto res = mlir::inlineCall(inliner_interface, op, func, &func.getRegion()); - if (mlir::succeeded(res)) - { - assert(op->getUsers().empty()); - rewriter.eraseOp(op); - rewriter.finalizeRootUpdate(parent); - } - else - { - rewriter.cancelRootUpdate(parent); - } - return res; -} diff --git a/mlir-compiler/plier/src/rewrites/if_rewrites.cpp b/mlir-compiler/plier/src/rewrites/if_rewrites.cpp deleted file mode 100644 index 86e14a416ff..00000000000 --- a/mlir-compiler/plier/src/rewrites/if_rewrites.cpp +++ /dev/null @@ -1,140 +0,0 @@ -#include "plier/rewrites/if_rewrites.hpp" - -#include -#include - -mlir::LogicalResult plier::IfOpConstCond::matchAndRewrite(mlir::scf::IfOp op, mlir::PatternRewriter& rewriter) const -{ - auto cond = mlir::dyn_cast_or_null(op.condition().getDefiningOp()); - if (!cond) - { - return mlir::failure(); - } - auto is_const = [](mlir::Value val) - { - if (auto parent = val.getDefiningOp()) - { - return parent->hasTrait(); - } - return false; - }; - - auto replace = [&](mlir::Block& block, mlir::Value to_replace, mlir::Value new_val) - { - for (auto& use : llvm::make_early_inc_range(to_replace.getUses())) - { - auto owner = use.getOwner(); - if (block.findAncestorOpInBlock(*owner)) - { - rewriter.updateRootInPlace(owner, [&]() - { - use.set(new_val); - }); - } - } - }; - - mlir::Value const_val; - mlir::Value to_replace; - if (is_const(cond.lhs())) - { - const_val = cond.lhs(); - to_replace = cond.rhs(); - } - else if (is_const(cond.rhs())) - { - const_val = cond.rhs(); - to_replace = cond.lhs(); - } - else - { - return mlir::failure(); - } - - if (cond.predicate() == mlir::CmpIPredicate::eq) - { - replace(op.thenRegion().front(), to_replace, const_val); - } - else if (cond.predicate() == mlir::CmpIPredicate::ne) - { - replace(op.elseRegion().front(), to_replace, const_val); - } - else - { - return mlir::failure(); - } - - return mlir::success(); -} - -mlir::LogicalResult plier::SimplifyEmptyIf::matchAndRewrite(mlir::scf::IfOp op, mlir::PatternRewriter& rewriter) const -{ - if (op.getNumResults() == 0 || op.elseRegion().empty()) - { - return mlir::failure(); - } - if (!llvm::hasNItems(op.thenRegion().front(), 1) || - !llvm::hasNItems(op.elseRegion().front(), 1)) - { - return mlir::failure(); - } - auto then_yield_args = mlir::cast(op.thenRegion().front().getTerminator()).getOperands(); - auto else_yield_args = mlir::cast(op.elseRegion().front().getTerminator()).getOperands(); - for (auto it : llvm::zip(then_yield_args, else_yield_args)) - { - if (std::get<0>(it) != std::get<1>(it)) - { - return mlir::failure(); - } - } - llvm::SmallVector args(then_yield_args.begin(), then_yield_args.end()); - assert(args.size() == op.getNumResults()); - rewriter.replaceOp(op, args); - return mlir::success(); -} - -mlir::LogicalResult plier::SimplifySelect::matchAndRewrite(mlir::SelectOp op, mlir::PatternRewriter& rewriter) const -{ - auto true_val = op.getTrueValue(); - auto false_val = op.getFalseValue(); - if (true_val == false_val) - { - rewriter.replaceOp(op, true_val); - return mlir::success(); - } - return mlir::failure(); -} - -mlir::LogicalResult plier::SimplifySelectEq::matchAndRewrite(mlir::SelectOp op, mlir::PatternRewriter& rewriter) const -{ - auto cond = mlir::dyn_cast_or_null(op.condition().getDefiningOp()); - if (!cond) - { - return mlir::failure(); - } - if (cond.predicate() != mlir::CmpIPredicate::eq && - cond.predicate() != mlir::CmpIPredicate::ne) - { - return mlir::failure(); - } - - auto cond_lhs = cond.lhs(); - auto cond_rhs = cond.rhs(); - - auto true_val = op.getTrueValue(); - auto false_val = op.getFalseValue(); - - if (cond.predicate() == mlir::CmpIPredicate::ne) - { - std::swap(true_val, false_val); - } - - if ((cond_lhs == true_val && cond_rhs == false_val) || - (cond_rhs == true_val && cond_lhs == false_val)) - { - rewriter.replaceOp(op, false_val); - return mlir::success(); - } - - return mlir::failure(); -} diff --git a/mlir-compiler/plier/src/rewrites/index_type_propagation.cpp b/mlir-compiler/plier/src/rewrites/index_type_propagation.cpp deleted file mode 100644 index 9531bd57ebf..00000000000 --- a/mlir-compiler/plier/src/rewrites/index_type_propagation.cpp +++ /dev/null @@ -1,161 +0,0 @@ -#include "plier/rewrites/index_type_propagation.hpp" - -#include -#include - -namespace -{ -bool is_index_compatible(mlir::Type lhs_type, mlir::Type rhs_type) -{ - if (!lhs_type.isa() || lhs_type != rhs_type) - { - return false; - } - - if (lhs_type.cast().getWidth() < 64) - { - return false; - } - return true; -} - -template -struct ArithIndexCastSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - Op op, mlir::PatternRewriter &rewriter) const override - { - auto lhs_type = op.lhs().getType(); - auto rhs_type = op.rhs().getType(); - if (!is_index_compatible(lhs_type, rhs_type)) - { - return mlir::failure(); - } - - auto get_cast = [](mlir::Value val)->mlir::Value - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getOperand(); - } - return {}; - }; - - auto get_const = [](mlir::Value val)->mlir::IntegerAttr - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getValue().cast(); - } - return {}; - }; - - auto lhs = get_cast(op.lhs()); - auto rhs = get_cast(op.rhs()); - auto lhs_const = get_const(op.lhs()); - auto rhs_const = get_const(op.rhs()); - if (lhs && rhs) - { - auto new_op = rewriter.create(op.getLoc(), lhs, rhs); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - if (lhs && rhs_const) - { - auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); - auto new_op = rewriter.create(op.getLoc(), lhs, new_const); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - if (lhs_const && rhs) - { - auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); - auto new_op = rewriter.create(op.getLoc(), new_const, rhs); - auto result = rewriter.create(op.getLoc(), new_op.getResult(), lhs_type); - rewriter.replaceOp(op, result.getResult()); - return mlir::success(); - } - - return mlir::failure(); - } -}; - -struct CmpIndexCastSimplify : public mlir::OpRewritePattern -{ - using mlir::OpRewritePattern::OpRewritePattern; - - mlir::LogicalResult matchAndRewrite( - mlir::CmpIOp op, mlir::PatternRewriter &rewriter) const override - { - auto lhs_type = op.lhs().getType(); - auto rhs_type = op.rhs().getType(); - if (!is_index_compatible(lhs_type, rhs_type)) - { - return mlir::failure(); - } - - auto get_cast = [](mlir::Value val)->mlir::Value - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getOperand(); - } - return {}; - }; - - auto get_const = [](mlir::Value val)->mlir::IntegerAttr - { - if (auto op = mlir::dyn_cast_or_null(val.getDefiningOp())) - { - return op.getValue().cast(); - } - return {}; - }; - - auto lhs = get_cast(op.lhs()); - auto rhs = get_cast(op.rhs()); - auto lhs_const = get_const(op.lhs()); - auto rhs_const = get_const(op.rhs()); - if (lhs && rhs) - { - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, rhs); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - if (lhs && rhs_const) - { - auto new_const = rewriter.create(op.getLoc(), rhs_const.getInt()); - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), lhs, new_const); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - if (lhs_const && rhs) - { - auto new_const = rewriter.create(op.getLoc(), lhs_const.getInt()); - auto new_cmp = rewriter.create(op.getLoc(), op.predicate(), new_const, rhs); - rewriter.replaceOp(op, new_cmp.getResult()); - return mlir::success(); - } - - return mlir::failure(); - } -}; -} - -void plier::populate_index_propagate_patterns(mlir::MLIRContext& context, mlir::OwningRewritePatternList& patterns) -{ - patterns.insert< - CmpIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify, - ArithIndexCastSimplify - >(&context); -} diff --git a/mlir-compiler/plier/src/rewrites/loop_rewrites.cpp b/mlir-compiler/plier/src/rewrites/loop_rewrites.cpp deleted file mode 100644 index 1637587ecf9..00000000000 --- a/mlir-compiler/plier/src/rewrites/loop_rewrites.cpp +++ /dev/null @@ -1,110 +0,0 @@ -#include "plier/rewrites/loop_rewrites.hpp" - -#include -#include - -namespace -{ -template -bool norm_impl2(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) -{ - if (pred != SrcPred) - { - return false; - } - if (index != lhs) - { - std::swap(lhs, rhs); - pred = DstPred; - } - return true; -} - -template -bool norm_impl(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs) -{ - return norm_impl2(pred, index, lhs, rhs) || - norm_impl2(pred, index, lhs, rhs); -} - -enum EBound -{ - LowerBound, - UpperBound, -}; -template -llvm::Optional handler_impl(mlir::CmpIPredicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound) -{ - if (pred != Pred) - { - return {}; - } - auto bound = (Bound == LowerBound ? lowerBound : upperBound); - if(rhs == bound && lhs == index) - { - return Value; - } - return {}; -} -} - -mlir::LogicalResult plier::CmpLoopBoundsSimplify::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const -{ - auto index_var = op.getLoopBody().front().getArgument(0); - if (auto step_var = mlir::dyn_cast_or_null(op.step().getDefiningOp())) - { - assert(step_var.value().cast().getInt() > 0); - } - bool matched = false; - for (auto user : llvm::make_early_inc_range(index_var.getUsers())) - { - auto cmp = mlir::dyn_cast(user); - if (cmp) - { - auto pred = cmp.predicate(); - auto lhs = cmp.lhs(); - auto rhs = cmp.rhs(); - // Normalize index and predicate (index always on the left) - using norm_fptr_t = bool(*)(mlir::CmpIPredicate& pred, mlir::Value index, mlir::Value& lhs, mlir::Value& rhs); - using Predicate = mlir::CmpIPredicate; - const norm_fptr_t norm_handlers[] = { - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - &norm_impl, - }; - - for (auto h : norm_handlers) - { - if (h(pred, index_var, lhs, rhs)) - { - break; - } - } - - using fptr_t = llvm::Optional(*)(Predicate pred, mlir::Value lhs, mlir::Value rhs, mlir::Value index, mlir::Value lowerBound, mlir::Value upperBound); - const fptr_t handlers[] = { - &handler_impl, - &handler_impl, - &handler_impl, - &handler_impl, - }; - - for (auto h : handlers) - { - if (auto c = h(pred, lhs, rhs, index_var, op.lowerBound(), op.upperBound())) - { - auto type = rewriter.getI1Type(); - auto val = rewriter.getIntegerAttr(type, *c); - auto const_val = rewriter.create(cmp.getLoc(), val); - rewriter.replaceOp(cmp, const_val.getResult()); - matched = true; - break; - } - } - } - } - return mlir::success(matched); -} diff --git a/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp b/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp deleted file mode 100644 index c21bbe22e44..00000000000 --- a/mlir-compiler/plier/src/rewrites/memory_rewrites.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include "plier/rewrites/memory_rewrites.hpp" - -#include - -#include - -namespace -{ -bool isWrite(mlir::Operation& op) -{ - if (auto effects = mlir::dyn_cast(op)) - { - return effects.hasEffect(); - } - return false; -} - -bool isRead(mlir::Operation& op) -{ - if (auto effects = mlir::dyn_cast(op)) - { - return effects.hasEffect(); - } - return false; -} - -struct Result -{ - bool changed; - bool hasWrites; - bool hasReads; -}; - -Result promoteLoads(llvm::MutableArrayRef regions, mlir::PatternRewriter& rewriter) -{ - bool changed = false; - bool hasWrites = false; - bool hasReads = false; - bool storeDead = false; - for (auto& region : regions) - { - for (auto& block : region.getBlocks()) - { - mlir::StoreOp currentStore; - for (auto& op : llvm::make_early_inc_range(block)) - { - if (!op.getRegions().empty()) - { - auto res = promoteLoads(op.getRegions(), rewriter); - if (res.changed) - { - changed = true; - } - if (res.hasWrites) - { - currentStore = {}; - } - if (res.hasReads) - { - storeDead = false; - } - continue; - } - - if (auto load = mlir::dyn_cast(op)) - { - hasReads = true; - if (currentStore) - { - if (load.memref() == currentStore.memref() && - load.indices() == currentStore.indices()) - { - rewriter.replaceOp(&op, currentStore.value()); - changed = true; - } - else - { - storeDead = false; - } - } - } - else if (auto store = mlir::dyn_cast(op)) - { - if (currentStore && storeDead && - currentStore.memref() == store.memref() && - currentStore.indices() == store.indices()) - { - rewriter.eraseOp(currentStore); - } - hasWrites = true; - currentStore = store; - storeDead = true; - } - else if (isWrite(op)) - { - hasWrites = true; - currentStore = {}; - } - else if (isRead(op)) - { - hasReads = true; - storeDead = false; - } - else if(op.hasTrait()) - { - currentStore = {}; - hasWrites = true; - hasReads = true; - storeDead = false; - } - } - } - } - return Result{changed, hasWrites, hasReads}; -} - -bool checkIsSingleElementsMemref(mlir::ShapedType type) -{ - if (!type.hasRank()) - { - return false; - } - return llvm::all_of(type.getShape(), [](auto val) { return val == 1; }); -} -} - -mlir::LogicalResult plier::PromoteLoads::matchAndRewrite(mlir::FuncOp op, mlir::PatternRewriter& rewriter) const -{ - auto res = promoteLoads(op->getRegions(), rewriter); - return mlir::success(res.changed); -} - -mlir::LogicalResult plier::SingeWriteMemref::matchAndRewrite(mlir::StoreOp op, mlir::PatternRewriter& rewriter) const -{ - auto memref = op.memref(); - if (!checkIsSingleElementsMemref(memref.getType().cast())) - { - return mlir::failure(); - } - auto parent = memref.getDefiningOp(); - if (!mlir::isa_and_nonnull(parent)) - { - return mlir::failure(); - } - - mlir::StoreOp valueStore; - llvm::SmallVector loads; - for (auto user : memref.getUsers()) - { - if (auto store = mlir::dyn_cast(user)) - { - if (valueStore) - { - // More than one store - return mlir::failure(); - } - valueStore = store; - } - else if (auto load = mlir::dyn_cast(user)) - { - loads.emplace_back(load); - } - else if (mlir::isa(user)) - { - // nothing - } - else - { - // Unsupported op - return mlir::failure(); - } - } - - auto parentBlock = parent->getBlock(); - if (!valueStore || valueStore->getBlock() != parentBlock) - { - return mlir::failure(); - } - - auto val = valueStore.value(); - for (auto load : loads) - { - rewriter.replaceOp(load, val); - } - for (auto user : llvm::make_early_inc_range(parent->getUsers())) - { - rewriter.eraseOp(user); - } - rewriter.eraseOp(parent); - return mlir::success(); -} diff --git a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp b/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp deleted file mode 100644 index 2c1b84391eb..00000000000 --- a/mlir-compiler/plier/src/rewrites/promote_to_parallel.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include "plier/rewrites/promote_to_parallel.hpp" - -#include -#include -#include - -#include "plier/dialect.hpp" - -namespace -{ -bool hasSideEffects(mlir::Operation *op) -{ - return op->walk([&](mlir::Operation *op) - { - if (auto effects = mlir::dyn_cast(op)) - { - if(effects.hasEffect()) - { - return mlir::WalkResult::interrupt(); - } - } -// if (op->hasTrait()) -// { -// return mlir::WalkResult::interrupt(); -// } - if (mlir::isa(op)) - { - return mlir::WalkResult::interrupt(); - } - return mlir::WalkResult::advance(); - }).wasInterrupted(); -} -} - -mlir::LogicalResult plier::PromoteToParallel::matchAndRewrite(mlir::scf::ForOp op, mlir::PatternRewriter& rewriter) const -{ - auto has_parallel_attr = op->hasAttr(plier::attributes::getParallelName()); - if (!has_parallel_attr && hasSideEffects(op)) - { - return mlir::failure(); - } - - auto& old_body = op.getLoopBody().front(); - auto old_yield = mlir::cast(old_body.getTerminator()); - auto reduce_args = old_body.getArguments().drop_front(); - llvm::SmallVector> reduce_bodies(reduce_args.size()); - llvm::DenseSet reduce_ops; - for (auto it : llvm::enumerate(reduce_args)) - { - auto reduce_arg = it.value(); - auto reduce_index = it.index(); - if (!reduce_arg.hasOneUse()) - { - return mlir::failure(); - } - auto reduce_op = *reduce_arg.user_begin(); - if (reduce_op->getNumOperands() != 2) - { - return mlir::failure(); - } - auto& reduce_body = reduce_bodies[reduce_index]; - while (true) - { - if (!reduce_op->hasOneUse()) - { - return mlir::failure(); - } - reduce_body.push_back(reduce_op); - reduce_ops.insert(reduce_op); - auto next_op = *reduce_op->user_begin(); - if (next_op == old_yield) - { - auto yield_operand = old_yield.getOperand(static_cast(reduce_index)); - if (yield_operand != reduce_op->getResult(0)) - { - return mlir::failure(); - } - break; - } - for (auto operand : next_op->getOperands()) - { - if (operand.getDefiningOp() != reduce_op && - operand.getParentBlock() == &old_body) - { - return mlir::failure(); - } - } - reduce_op = next_op; - } - } - - auto body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::ValueRange iter_vals, mlir::ValueRange temp) - { - assert(1 == iter_vals.size()); - assert(temp.empty()); - mlir::BlockAndValueMapping mapping; - mapping.map(old_body.getArguments().front(), iter_vals.front()); - for (auto& old_op : old_body.without_terminator()) - { - if (0 == reduce_ops.count(&old_op)) - { - builder.clone(old_op, mapping); - } - } - mlir::BlockAndValueMapping reduce_mapping; - for (auto it : llvm::enumerate(reduce_bodies)) - { - auto& reduce_body = it.value(); - assert(!reduce_body.empty()); - reduce_mapping = mapping; - auto first_op = reduce_body.front(); - assert(first_op->getNumOperands() == 2); - auto reduce_body_builder = [&](mlir::OpBuilder &builder, mlir::Location loc, mlir::Value val0, mlir::Value val1) - { - reduce_mapping.map(first_op->getOperand(0), val0); - reduce_mapping.map(first_op->getOperand(1), val1); - mlir::Operation* last_op = nullptr; - for (auto reduce_op : reduce_body) - { - last_op = builder.clone(*reduce_op, reduce_mapping); - assert(1 == last_op->getNumResults()); - } - builder.create(loc, last_op->getResult(0)); - }; - auto reduce_arg = reduce_args[it.index()]; - auto first_op_operands = first_op->getOperands(); - auto reduce_operand = (first_op_operands[0] == reduce_arg ? first_op_operands[1] : first_op_operands[0]); - assert(reduce_operand != reduce_arg); - reduce_operand = mapping.lookupOrDefault(reduce_operand); - assert(reduce_operand); - builder.create(loc, reduce_operand, reduce_body_builder); - } - }; - - auto parallel_op = rewriter.create(op.getLoc(), op.lowerBound(), op.upperBound(), op.step(), op.initArgs(), body_builder); - if (has_parallel_attr) - { - parallel_op->setAttr(plier::attributes::getParallelName(), rewriter.getUnitAttr()); - } - rewriter.replaceOp(op, parallel_op.getResults()); - - return mlir::success(); -} diff --git a/mlir-compiler/plier/src/rewrites/type_conversion.cpp b/mlir-compiler/plier/src/rewrites/type_conversion.cpp deleted file mode 100644 index 76652d9e4ab..00000000000 --- a/mlir-compiler/plier/src/rewrites/type_conversion.cpp +++ /dev/null @@ -1,185 +0,0 @@ -#include "plier/rewrites/type_conversion.hpp" - -#include -#include - -#include "plier/dialect.hpp" - -namespace -{ -mlir::LogicalResult setBlockSig( - mlir::Block& block, mlir::OpBuilder& builder, - const mlir::TypeConverter::SignatureConversion& conversion) -{ - if (conversion.getConvertedTypes().size() != block.getNumArguments()) - { - return mlir::failure(); - } - unsigned i = 0; - for (auto it : llvm::zip(block.getArguments(), conversion.getConvertedTypes())) - { - auto arg = std::get<0>(it); - auto type = std::get<1>(it); - if (arg.getType() != type) - { - builder.setInsertionPointToStart(&block); - auto res = builder.create(builder.getUnknownLoc(), arg.getType(), arg); - arg.replaceUsesWithIf(res, [&](mlir::OpOperand& op) - { - return op.getOwner() != res; - }); - - for (auto& use : block.getUses()) - { - auto op = use.getOwner(); - builder.setInsertionPoint(op); - if (auto br = mlir::dyn_cast(op)) - { - assert(&block == br.dest()); - auto src = br.destOperands()[i]; - auto new_op = builder.create(op->getLoc(), type, src); - br.destOperandsMutable().slice(i, 1).assign(new_op); - } - else if (auto cond_br = mlir::dyn_cast(op)) - { - if (&block == cond_br.trueDest()) - { - auto src = cond_br.trueDestOperands()[i]; - auto new_op = builder.create(op->getLoc(), type, src); - cond_br.trueDestOperandsMutable().slice(i, 1).assign(new_op); - } - if (&block == cond_br.falseDest()) - { - auto src = cond_br.falseDestOperands()[i]; - auto new_op = builder.create(op->getLoc(), type, src); - cond_br.falseDestOperandsMutable().slice(i, 1).assign(new_op); - } - } - else - { - llvm_unreachable("setBlockSig: unknown operation type"); - } - } - arg.setType(type); - } - ++i; - } - return mlir::success(); -} - -mlir::LogicalResult convertRegionTypes( - mlir::Region *region, mlir::TypeConverter &converter, bool apply) -{ - assert(nullptr != region); - if (region->empty()) - { - return mlir::failure(); - } - - mlir::OpBuilder builder(region->getContext()); - - // Convert the arguments of each block within the region. - auto sig = converter.convertBlockSignature(®ion->front()); - assert(static_cast(sig)); - if (apply) - { - auto res = setBlockSig(region->front(), builder, *sig); - assert(mlir::succeeded(res)); - (void)res; - } - for (auto &block : llvm::make_early_inc_range(llvm::drop_begin(*region, 1))) - { - sig = converter.convertBlockSignature(&block); - if (!sig) - { - return mlir::failure(); - } - if (apply) - { - if (mlir::failed(setBlockSig(block, builder, *sig))) - { - return mlir::failure(); - } - } - } - return mlir::success(); -} -} - -plier::FuncOpSignatureConversion::FuncOpSignatureConversion(mlir::TypeConverter& conv, - mlir::MLIRContext* ctx) - : OpRewritePattern(ctx), converter(conv) {} - -mlir::LogicalResult plier::FuncOpSignatureConversion::matchAndRewrite( - mlir::FuncOp funcOp, mlir::PatternRewriter& rewriter) const -{ - auto type = funcOp.getType(); - - // Convert the original function types. - mlir::TypeConverter::SignatureConversion result(type.getNumInputs()); - llvm::SmallVector newResults; - if (mlir::failed(converter.convertSignatureArgs(type.getInputs(), result)) || - mlir::failed(converter.convertTypes(type.getResults(), newResults)) || - mlir::failed(convertRegionTypes(&funcOp.getBody(), converter, false))) - { - return mlir::failure(); - } - - bool ret_type_changed = false; - // Update the function signature in-place. - rewriter.updateRootInPlace(funcOp, [&] { - ret_type_changed = (newResults != funcOp.getType().getResults()); - funcOp.setType(mlir::FunctionType::get( - funcOp.getContext(), result.getConvertedTypes(), newResults)); - auto res = convertRegionTypes(&funcOp.getBody(), converter, true); - assert(mlir::succeeded(res)); - }); - - if (ret_type_changed) - { - auto ret_types = funcOp.getType().getResults(); - funcOp.walk([&](mlir::ReturnOp ret) - { - if (ret->getParentOp() == funcOp) - { - mlir::OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(ret); - for (auto it : llvm::enumerate(llvm::zip(ret.getOperandTypes(), ret_types))) - { - auto prev_type = std::get<0>(it.value()); - auto new_type = std::get<1>(it.value()); - if (prev_type != new_type) - { - auto index = static_cast(it.index()); - auto cast = rewriter.create(ret.getLoc(), new_type, ret.getOperand(index)); - rewriter.updateRootInPlace(ret, [&]() - { - ret.setOperand(index, cast); - }); - } - } - } - }); - auto mod = funcOp->getParentOfType(); - auto uses = funcOp.getSymbolUses(mod); - if (uses) - { - for (auto use : *uses) - { - if (auto call = mlir::dyn_cast(use.getUser())) - { - rewriter.updateRootInPlace(call, [&]() - { - for (auto it : llvm::zip(call.getResults(), ret_types)) - { - auto res = std::get<0>(it); - auto type = std::get<1>(it); - res.setType(type); - } - }); - } - } - } - } - return mlir::success(); -} diff --git a/mlir-compiler/plier/src/transforms/block_utils.cpp b/mlir-compiler/plier/src/transforms/block_utils.cpp deleted file mode 100644 index e9b346cd59d..00000000000 --- a/mlir-compiler/plier/src/transforms/block_utils.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "plier/transforms/block_utils.hpp" - -#include - -namespace -{ -auto collectParentOps(mlir::Operation* op) -{ - llvm::SmallVector ops; - while (true) - { - assert(op); - ops.emplace_back(op); - auto parent = op->getParentOp(); - if (!parent) - { - break; - } - op = parent; - } - return ops; -} -} - -plier::OpRelation plier::relativeTo(mlir::Operation* op, mlir::Operation* relativeTo) -{ - assert(op); - assert(relativeTo); - - for (auto& reg : relativeTo->getRegions()) - { - for (auto& block : reg) - { - if (block.findAncestorOpInBlock(*op)) - { - return OpRelation::In; - } - } - } - - auto ops1 = collectParentOps(op); - auto ops2 = collectParentOps(relativeTo); - - for (auto op1 : ops1) - { - assert(op1); - for (auto op2 : ops2) - { - assert(op2); - if (op1->getBlock() == op1->getBlock()) - { - if (op1->isBeforeInBlock(op1)) - { - return OpRelation::Before; - } - else - { - return OpRelation::After; - } - } - } - } - return OpRelation::Unknown; -} diff --git a/mlir-compiler/plier/src/transforms/cast_utils.cpp b/mlir-compiler/plier/src/transforms/cast_utils.cpp deleted file mode 100644 index 600a0af86ac..00000000000 --- a/mlir-compiler/plier/src/transforms/cast_utils.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "plier/transforms/cast_utils.hpp" - -#include - -mlir::Value plier::index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src, mlir::Type dst_type) -{ - auto src_type = src.getType(); - assert(src_type.isa() || dst_type.isa()); - if (src_type != dst_type) - { - return builder.create(loc, src, dst_type); - } - return src; -} - -mlir::Value plier::index_cast(mlir::OpBuilder& builder, mlir::Location loc, mlir::Value src) -{ - return index_cast(builder, loc, src, mlir::IndexType::get(builder.getContext())); -} diff --git a/mlir-compiler/plier/src/transforms/const_utils.cpp b/mlir-compiler/plier/src/transforms/const_utils.cpp deleted file mode 100644 index c0bd0646799..00000000000 --- a/mlir-compiler/plier/src/transforms/const_utils.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "plier/transforms/const_utils.hpp" - -#include -#include - -mlir::Attribute plier::getConstVal(mlir::Operation* op) -{ - assert(op); - if (!op->hasTrait()) - { - return {}; - } - - return op->getAttr("value"); -} - -mlir::Attribute plier::getConstVal(mlir::Value op) -{ - assert(op); - if (auto parent_op = op.getDefiningOp()) - { - return getConstVal(parent_op); - } - return {}; -} - -mlir::Attribute plier::getZeroVal(mlir::Type type) -{ - assert(type); - if (type.isa()) - { - return mlir::FloatAttr::get(type, 0.0); - } - if (type.isa()) - { - return mlir::IntegerAttr::get(type, 0); - } - return {}; -} diff --git a/mlir-compiler/plier/src/transforms/func_utils.cpp b/mlir-compiler/plier/src/transforms/func_utils.cpp deleted file mode 100644 index 73a0edd89fb..00000000000 --- a/mlir-compiler/plier/src/transforms/func_utils.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "plier/transforms/func_utils.hpp" - -#include -#include - -#include - -mlir::FuncOp plier::add_function( - mlir::OpBuilder& builder, mlir::ModuleOp module, llvm::StringRef name, - mlir::FunctionType type) -{ - mlir::OpBuilder::InsertionGuard guard(builder); - // Insert before module terminator. - builder.setInsertionPoint(module.getBody(), - std::prev(module.getBody()->end())); - auto func = builder.create(builder.getUnknownLoc(), name, type); - func.setPrivate(); - return func; -} diff --git a/mlir-compiler/plier/src/transforms/loop_utils.cpp b/mlir-compiler/plier/src/transforms/loop_utils.cpp deleted file mode 100644 index 970669acc6e..00000000000 --- a/mlir-compiler/plier/src/transforms/loop_utils.cpp +++ /dev/null @@ -1,397 +0,0 @@ -#include "plier/transforms/loop_utils.hpp" - -#include - -#include -#include -#include -#include -#include - -#include "plier/dialect.hpp" - -#include "plier/transforms/cast_utils.hpp" - -namespace -{ -template -Op get_next_op(llvm::iterator_range& iters) -{ - if (iters.empty()) - { - return nullptr; - } - auto res = mlir::dyn_cast(iters.begin()); - if (res) - { - auto next = std::next(iters.begin()); - iters = {next, iters.end()}; - } - return res; -} - -mlir::Value get_last_iter_value( - mlir::PatternRewriter& builder, mlir::Location loc, - mlir::Value lower_bound, mlir::Value upper_bound, mlir::Value step) -{ - auto len = builder.create(loc, upper_bound, lower_bound); - auto count = builder.create(loc, len, step); - auto inc = builder.create(loc, count, step); - return builder.create(loc, lower_bound, inc); -} - -} - -mlir::LogicalResult plier::lower_while_to_for( - plier::GetiterOp getiter, mlir::PatternRewriter& builder, - llvm::function_ref(mlir::OpBuilder&, mlir::Location)> get_bounds, - llvm::function_ref get_iter_val, - llvm::function_ref results) -{ - llvm::SmallVector to_process; - for (auto user : getiter.getOperation()->getUsers()) - { - if( auto while_op = mlir::dyn_cast(user->getParentOp())) - { - to_process.emplace_back(while_op); - } - } - - auto loc = getiter.getLoc(); - mlir::Value zero_val; - auto get_zero_index = [&]() - { - if (!zero_val) - { - mlir::OpBuilder::InsertionGuard guard(builder); - builder.setInsertionPoint(getiter); - zero_val = builder.create(loc, 0); - } - return zero_val; - }; - - auto get_neg = [&](mlir::Value value) - { - return builder.create(loc, get_zero_index(), value); - }; - - bool changed = false; - for (auto while_op : to_process) - { - auto& before_block = while_op.before().front(); - auto iters = llvm::iterator_range(before_block); - auto iternext = get_next_op(iters); - auto pairfirst = get_next_op(iters); - auto pairsecond = get_next_op(iters); - while (get_next_op(iters)) {} // skip casts - auto before_term = get_next_op(iters); - - auto skip_casts = [](mlir::Value op) - { - while (auto cast = mlir::dyn_cast_or_null(op.getDefiningOp())) - { - op = cast.getOperand(); - } - return op; - }; - if (!iternext || !pairsecond || !before_term || - skip_casts(before_term.condition()) != pairsecond) - { - continue; - } - - auto& after_block = while_op.after().front(); - - auto index_cast = [&](mlir::Value val)->mlir::Value - { - return ::plier::index_cast(builder, loc, val); - }; - - auto bounds = get_bounds(builder, loc); - auto orig_lower_bound = index_cast(std::get<0>(bounds)); - auto orig_upper_bound = index_cast(std::get<1>(bounds)); - auto orig_step = index_cast(std::get<2>(bounds)); - - // scf::ForOp/ParallelOp doesn't support negative step, so generate - // IfOp and 2 version for different step signs - // branches for const steps will be pruned later - auto gen_for = [&](bool positive) - { - auto get_loop_body_builder = [&](bool positive) - { - return [&, positive](mlir::OpBuilder& builder, mlir::Location loc, mlir::Value iv, mlir::ValueRange iterargs) - { - if (!positive) - { - iv = get_neg(iv); - } - mlir::BlockAndValueMapping mapper; - assert(before_block.getNumArguments() == iterargs.size()); - assert(after_block.getNumArguments() == before_term.args().size()); - mapper.map(before_block.getArguments(), iterargs); - for (auto it : llvm::zip(after_block.getArguments(), before_term.args())) - { - auto block_arg = std::get<0>(it); - auto term_arg = std::get<1>(it); - if (pairfirst && term_arg == pairfirst) // iter arg - { - auto iter_val = get_iter_val(builder, loc, pairfirst.getType(), iv); - mapper.map(block_arg, iter_val); - } - else - { - mapper.map(block_arg, mapper.lookupOrDefault(term_arg)); - } - } - - for (auto& op : after_block) // with terminator - { - builder.clone(op, mapper); - } - }; - }; - - auto lower_bound = orig_lower_bound; - auto upper_bound = orig_upper_bound; - auto step = orig_step; - - if (!positive) - { - lower_bound = get_neg(lower_bound); - upper_bound = get_neg(upper_bound); - step = get_neg(step); - } - - return builder.create( - loc, - lower_bound, - upper_bound, - step, - while_op.getOperands(), // iterArgs - get_loop_body_builder(positive) - ); - }; - - - auto get_if_body_builder = [&](bool positive) - { - return [&, positive](mlir::OpBuilder& builder, mlir::Location loc) - { - auto loop_op = gen_for(positive); - if (results) - { - results(loop_op); - } - builder.create(loc, loop_op.getResults()); - }; - }; - - builder.setInsertionPoint(while_op); - auto step_sign = builder.create(loc, mlir::CmpIPredicate::sge, orig_step, get_zero_index()); - auto loop_op = builder.create( - loc, - while_op.getOperands().getTypes(), - step_sign, - get_if_body_builder(true), - get_if_body_builder(false)); - - assert(while_op.getNumResults() >= loop_op.getNumResults()); - builder.updateRootInPlace(while_op, [&]() - { - assert(while_op.getNumResults() == before_term.args().size()); - for (auto it : llvm::zip(while_op.getResults(), before_term.args())) - { - auto old_res = std::get<0>(it); - auto operand = std::get<1>(it); - for (auto it2 : llvm::enumerate(before_block.getArguments())) - { - auto arg = it2.value(); - if (arg == operand) - { - assert(it2.index() < loop_op.getNumResults()); - auto new_res = loop_op.getResult(static_cast(it2.index())); - old_res.replaceAllUsesWith(new_res); - break; - } - } - if (pairfirst && operand == pairfirst && !old_res.getUsers().empty()) - { - auto val = get_last_iter_value(builder, loc, orig_lower_bound, orig_upper_bound, orig_step); - auto new_res = builder.create(loc, old_res.getType(), val); - old_res.replaceAllUsesWith(new_res); - } - assert(old_res.getUsers().empty()); - } - }); - - assert(while_op.getOperation()->getUsers().empty()); - builder.eraseOp(while_op); - changed = true; - } - - if (getiter.getOperation()->getUsers().empty()) - { - builder.eraseOp(getiter); - changed = true; - } - return mlir::success(changed); -} - -// TODO: Copypasted from mlir -namespace -{ -using namespace mlir; - -/// Verify there are no nested ParallelOps. -static bool hasNestedParallelOp(scf::ParallelOp ploop) { - auto walkResult = - ploop.getBody()->walk([](scf::ParallelOp) { return WalkResult::interrupt(); }); - return walkResult.wasInterrupted(); -} - -/// Verify equal iteration spaces. -static bool equalIterationSpaces(scf::ParallelOp firstPloop, - scf::ParallelOp secondPloop) { - if (firstPloop.getNumLoops() != secondPloop.getNumLoops()) - return false; - - auto matchOperands = [&](const OperandRange &lhs, - const OperandRange &rhs) -> bool { - // TODO: Extend this to support aliases and equal constants. - return std::equal(lhs.begin(), lhs.end(), rhs.begin()); - }; - return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) && - matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) && - matchOperands(firstPloop.step(), secondPloop.step()); -} - -/// Checks if the parallel loops have mixed access to the same buffers. Returns -/// `true` if the first parallel loop writes to the same indices that the second -/// loop reads. -static bool haveNoReadsAfterWriteExceptSameIndex( - scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - const BlockAndValueMapping &firstToSecondPloopIndices) { - DenseMap> bufferStores; - firstPloop.getBody()->walk([&](StoreOp store) { - bufferStores[store.getMemRef()].push_back(store.indices()); - }); - auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) { - // Stop if the memref is defined in secondPloop body. Careful alias analysis - // is needed. - auto *memrefDef = load.getMemRef().getDefiningOp(); - if (memrefDef && memrefDef->getBlock() == load->getBlock()) - return WalkResult::interrupt(); - - auto write = bufferStores.find(load.getMemRef()); - if (write == bufferStores.end()) - return WalkResult::advance(); - - // Allow only single write access per buffer. - if (write->second.size() != 1) - return WalkResult::interrupt(); - - // Check that the load indices of secondPloop coincide with store indices of - // firstPloop for the same memrefs. - auto storeIndices = write->second.front(); - auto loadIndices = load.indices(); - if (storeIndices.size() != loadIndices.size()) - return WalkResult::interrupt(); - for (size_t i = 0, e = storeIndices.size(); i < e; ++i) { - if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) != - loadIndices[i]) - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - return !walkResult.wasInterrupted(); -} - -/// Analyzes dependencies in the most primitive way by checking simple read and -/// write patterns. -static LogicalResult -verifyDependencies(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - const BlockAndValueMapping &firstToSecondPloopIndices) { - if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop, - firstToSecondPloopIndices)) - return failure(); - - BlockAndValueMapping secondToFirstPloopIndices; - secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(), - firstPloop.getBody()->getArguments()); - return success(haveNoReadsAfterWriteExceptSameIndex( - secondPloop, firstPloop, secondToFirstPloopIndices)); -} - -static bool -isFusionLegal(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - const BlockAndValueMapping &firstToSecondPloopIndices) { - return !hasNestedParallelOp(firstPloop) && - !hasNestedParallelOp(secondPloop) && - equalIterationSpaces(firstPloop, secondPloop) && - succeeded(verifyDependencies(firstPloop, secondPloop, - firstToSecondPloopIndices)); -} - -/// Prepends operations of firstPloop's body into secondPloop's body. -static bool fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp secondPloop, - OpBuilder &b) { - BlockAndValueMapping firstToSecondPloopIndices; - firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(), - secondPloop.getBody()->getArguments()); - - if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices)) - return false; - - b.setInsertionPointToStart(secondPloop.getBody()); - for (auto &op : firstPloop.getBody()->without_terminator()) - b.clone(op, firstToSecondPloopIndices); - firstPloop.erase(); - return true; -} - -bool hasNoEffect(mlir::Operation* op) -{ - if (op->getNumRegions() != 0) - { - return false; - } - if (auto interface = dyn_cast(op)) - { - return !interface.hasEffect() && - !interface.hasEffect(); - } - return !op->hasTrait<::mlir::OpTrait::HasRecursiveSideEffects>(); -} -} - -mlir::LogicalResult plier::naivelyFuseParallelOps(Region ®ion) { - OpBuilder b(region); - // Consider every single block and attempt to fuse adjacent loops. - bool changed = false; - for (auto &block : region) { - SmallVector, 1> ploopChains{{}}; - // Not using `walk()` to traverse only top-level parallel loops and also - // make sure that there are no side-effecting ops between the parallel - // loops. - bool noSideEffects = true; - for (auto &op : block) { - if (auto ploop = dyn_cast(op)) { - if (noSideEffects) { - ploopChains.back().push_back(ploop); - } else { - ploopChains.push_back({ploop}); - noSideEffects = true; - } - continue; - } - // TODO: Handle region side effects properly. - noSideEffects &= hasNoEffect(&op); - } - for (llvm::ArrayRef ploops : ploopChains) { - for (size_t i = 0, e = ploops.size(); i + 1 < e; ++i) - if (fuseIfLegal(ploops[i], ploops[i + 1], b)) - changed = true; - } - } - return mlir::success(changed); -} diff --git a/mlir-compiler/plier/src/transforms/pipeline_utils.cpp b/mlir-compiler/plier/src/transforms/pipeline_utils.cpp deleted file mode 100644 index c7127522bb9..00000000000 --- a/mlir-compiler/plier/src/transforms/pipeline_utils.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "plier/transforms/pipeline_utils.hpp" - -#include -#include - -#include "plier/dialect.hpp" - -mlir::ArrayAttr plier::get_pipeline_jump_markers(mlir::ModuleOp module) -{ - return module->getAttrOfType(plier::attributes::getJumpMarkersName()); -} - -void plier::add_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) -{ - assert(name); - assert(!name.getValue().empty()); - - auto jump_markers = plier::attributes::getJumpMarkersName(); - llvm::SmallVector name_list; - if (auto old_attr = module->getAttrOfType(jump_markers)) - { - name_list.assign(old_attr.begin(), old_attr.end()); - } - auto it = llvm::lower_bound(name_list, name, - [](mlir::Attribute lhs, mlir::StringAttr rhs) - { - return lhs.cast().getValue() < rhs.getValue(); - }); - if (it == name_list.end()) - { - name_list.emplace_back(name); - } - else if (*it != name) - { - name_list.insert(it, name); - } - module->setAttr(jump_markers, mlir::ArrayAttr::get(module.getContext(), name_list)); -} - - -void plier::remove_pipeline_jump_marker(mlir::ModuleOp module, mlir::StringAttr name) -{ - assert(name); - assert(!name.getValue().empty()); - - auto jump_markers = plier::attributes::getJumpMarkersName(); - llvm::SmallVector name_list; - if (auto old_attr = module->getAttrOfType(jump_markers)) - { - name_list.assign(old_attr.begin(), old_attr.end()); - } - auto it = llvm::lower_bound(name_list, name, - [](mlir::Attribute lhs, mlir::StringAttr rhs) - { - return lhs.cast().getValue() < rhs.getValue(); - }); - assert(it != name_list.end()); - name_list.erase(it); - module->setAttr(jump_markers, mlir::ArrayAttr::get(module.getContext(), name_list)); -} diff --git a/mlir-compiler/plier/src/utils.cpp b/mlir-compiler/plier/src/utils.cpp deleted file mode 100644 index 36a9a414a92..00000000000 --- a/mlir-compiler/plier/src/utils.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "plier/utils.hpp" - -#include - -#include "llvm/ADT/Twine.h" - -void plier::report_error(const llvm::Twine& msg) -{ - throw std::runtime_error(msg.str()); -} From 9799bf28736615d1b9671a5c8b9dab11737942a0 Mon Sep 17 00:00:00 2001 From: Alexander-Makaryev <40917969+Alexander-Makaryev@users.noreply.github.com> Date: Fri, 19 Mar 2021 20:16:57 +0300 Subject: [PATCH 255/259] control exported symbols (#208) --- mlir-compiler/mlir-compiler/CMakeLists.txt | 3 +-- mlir-compiler/mlir-compiler/export.txt | 4 ++++ 2 files changed, 5 insertions(+), 2 deletions(-) create mode 100644 mlir-compiler/mlir-compiler/export.txt diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt index 791d54fef3a..b3b34ed07d4 100644 --- a/mlir-compiler/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/mlir-compiler/CMakeLists.txt @@ -45,10 +45,9 @@ if (MSVC) endif () if(UNIX) - target_link_options(${PROJECT_NAME} PRIVATE "LINKER:--exclude-libs,ALL") + target_link_options(${PROJECT_NAME} PRIVATE "LINKER:--version-script=export.txt") endif() - target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) target_link_libraries(${PROJECT_NAME} PRIVATE diff --git a/mlir-compiler/mlir-compiler/export.txt b/mlir-compiler/mlir-compiler/export.txt new file mode 100644 index 00000000000..dce06e3f92c --- /dev/null +++ b/mlir-compiler/mlir-compiler/export.txt @@ -0,0 +1,4 @@ +{ + global: PyInit_mlir_compiler; + local: *; +}; From a9cc160fd9d31fa6e3799e780317e81b0c373f54 Mon Sep 17 00:00:00 2001 From: Alexander-Makaryev <40917969+Alexander-Makaryev@users.noreply.github.com> Date: Fri, 19 Mar 2021 22:40:36 +0300 Subject: [PATCH 256/259] Darwin: another format for export symbols list (#209) --- mlir-compiler/mlir-compiler/CMakeLists.txt | 6 +++++- mlir-compiler/mlir-compiler/export_darwin.txt | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 mlir-compiler/mlir-compiler/export_darwin.txt diff --git a/mlir-compiler/mlir-compiler/CMakeLists.txt b/mlir-compiler/mlir-compiler/CMakeLists.txt index b3b34ed07d4..5458d260ab8 100644 --- a/mlir-compiler/mlir-compiler/CMakeLists.txt +++ b/mlir-compiler/mlir-compiler/CMakeLists.txt @@ -44,10 +44,14 @@ if (MSVC) target_compile_options(${PROJECT_NAME} PRIVATE /EHsc) endif () -if(UNIX) +if (CMAKE_SYSTEM_NAME STREQUAL Linux) target_link_options(${PROJECT_NAME} PRIVATE "LINKER:--version-script=export.txt") endif() +if (CMAKE_SYSTEM_NAME STREQUAL Darwin) + target_link_libraries(${PROJECT_NAME} PRIVATE "-Wl,-exported_symbols_list,export_darwin.txt") +endif() + target_compile_definitions(${PROJECT_NAME} PRIVATE ${LLVM_DEFINITIONS}) target_link_libraries(${PROJECT_NAME} PRIVATE diff --git a/mlir-compiler/mlir-compiler/export_darwin.txt b/mlir-compiler/mlir-compiler/export_darwin.txt new file mode 100644 index 00000000000..a5d5900af2d --- /dev/null +++ b/mlir-compiler/mlir-compiler/export_darwin.txt @@ -0,0 +1 @@ +_PyInit_mlir_compiler From 3a0c70894a7bb4b2672b8f7bb09ccae5539855fb Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 20 Mar 2021 16:19:09 +0300 Subject: [PATCH 257/259] Refactor linalg optimizations flow (#210) --- .../src/pipelines/plier_to_linalg.cpp | 76 ++++++++++--------- 1 file changed, 41 insertions(+), 35 deletions(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index c3559085f71..476f95f722b 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -34,6 +34,7 @@ #include "plier/rewrites/force_inline.hpp" #include "plier/rewrites/index_type_propagation.hpp" #include "plier/rewrites/loop_rewrites.hpp" +#include "plier/rewrites/memory_rewrites.hpp" #include "plier/transforms/loop_utils.hpp" #include "base_pipeline.hpp" @@ -44,6 +45,29 @@ namespace { +void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& patterns, llvm::function_ref additionalOpts = nullptr) +{ + bool repeat = false; + do + { + repeat = false; + (void)mlir::applyPatternsAndFoldGreedily(op, patterns); + if (mlir::succeeded(plier::applyCSE(op.getRegion(), false))) + { + repeat = true; + } + if (mlir::succeeded(plier::promoteLoads(op.getRegion()))) + { + repeat = true; + } + if (additionalOpts && mlir::succeeded(additionalOpts(op))) + { + repeat = true; + } + } + while(repeat); +} + enum class ArrayLayout { C, @@ -900,12 +924,12 @@ void LowerLinalgPass::runOnOperation() } struct PostPlierToLinalgPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PostPlierToLinalgPass::runOnOperation() +void PostPlierToLinalgPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -916,7 +940,7 @@ void PostPlierToLinalgPass::runOnOperation() SimplifyExpandDims >(&getContext()); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyOptimizations(getFunction(), std::move(patterns)); } struct TensorFusionPass : @@ -1016,12 +1040,12 @@ void RetainArgsPass::runOnFunction() } struct PostLinalgOptPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PostLinalgOptPass::runOnOperation() +void PostLinalgOptPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -1032,37 +1056,19 @@ void PostLinalgOptPass::runOnOperation() plier::CanonicalizeReduction >(&context); - mlir::FrozenRewritePatternList frozenPatterns(std::move(patterns)); - - while (true) + applyOptimizations(getFunction(), std::move(patterns), [](mlir::FuncOp op) { - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), frozenPatterns); - bool rerun = false; - for (auto& op : getOperation().getRegion().front()) - { - if (auto func = mlir::dyn_cast(op)) - { - if (mlir::succeeded(plier::naivelyFuseParallelOps(func.getRegion()))) - { - rerun = true; - } - } - } - if (!rerun) - { - break; - } - } - + return plier::naivelyFuseParallelOps(op.getRegion()); + }); } struct PromoteParallelPass : - public mlir::PassWrapper> + public mlir::PassWrapper { - void runOnOperation() override; + void runOnFunction() override; }; -void PromoteParallelPass::runOnOperation() +void PromoteParallelPass::runOnFunction() { mlir::OwningRewritePatternList patterns; @@ -1074,13 +1080,13 @@ void PromoteParallelPass::runOnOperation() plier::PromoteToParallel // TODO >(&context); - (void)mlir::applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + applyOptimizations(getFunction(), std::move(patterns)); } void populate_plier_to_linalg_gen_pipeline(mlir::OpPassManager& pm) { pm.addPass(std::make_unique()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); } @@ -1105,9 +1111,9 @@ void populate_plier_to_linalg_opt_pipeline(mlir::OpPassManager& pm) pm.addPass(mlir::createCopyRemovalPass()); pm.addPass(std::make_unique()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); pm.addPass(mlir::createSymbolDCEPass()); - pm.addPass(std::make_unique()); + pm.addNestedPass(std::make_unique()); } } From 99de6b5b2695d805034ef0c52e1068aeea36345d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 21 Mar 2021 19:55:14 +0300 Subject: [PATCH 258/259] use new memory optimization (#211) --- mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp index 476f95f722b..67c5d2d166d 100644 --- a/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp +++ b/mlir-compiler/mlir-compiler/src/pipelines/plier_to_linalg.cpp @@ -56,7 +56,7 @@ void applyOptimizations(mlir::FuncOp op, const mlir::FrozenRewritePatternList& p { repeat = true; } - if (mlir::succeeded(plier::promoteLoads(op.getRegion()))) + if (mlir::succeeded(plier::optimizeMemoryOps(op))) { repeat = true; } From 345e70078c83e9ea51a9a9ac9c38d4b597b5a6c2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 22 Mar 2021 22:34:35 +0300 Subject: [PATCH 259/259] move mlir numba passes to mlir dir (#212) --- numba/core/compiler.py | 5 +- numba/core/typed_passes.py | 128 --------------------------------- numba/mlir/inner_compiler.py | 5 +- numba/mlir/passes.py | 136 +++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 132 deletions(-) create mode 100644 numba/mlir/passes.py diff --git a/numba/core/compiler.py b/numba/core/compiler.py index df3a4fa70c1..3b79fa2608a 100644 --- a/numba/core/compiler.py +++ b/numba/core/compiler.py @@ -28,12 +28,13 @@ NopythonRewrites, PreParforPass, ParforPass, DumpParforDiagnostics, IRLegalization, NoPythonBackend, - InlineOverloads, PreLowerStripPhis, - MlirDumpPlier, MlirBackend) + InlineOverloads, PreLowerStripPhis) from numba.core.object_mode_passes import (ObjectModeFrontEnd, ObjectModeBackEnd, CompileInterpMode) +from numba.mlir.passes import (MlirDumpPlier, MlirBackend) + class Flags(utils.ConfigOptions): # These options are all false by default, but the defaults are # different with the @jit decorator (see targets.options.TargetOptions). diff --git a/numba/core/typed_passes.py b/numba/core/typed_passes.py index 0aa0e74b8cc..98d7c7a0865 100644 --- a/numba/core/typed_passes.py +++ b/numba/core/typed_passes.py @@ -467,134 +467,6 @@ def run_pass(self, state): ) return True -import numba.mlir.settings -import numba.mlir.func_registry -import numba.core.types.functions -_mlir_last_compiled_func = None -_mlir_active_module = None - -class MlirBackendBase(FunctionPass): - - def __init__(self): - import numba.mlir.func_registry - self._get_func_name = numba.mlir.func_registry.get_func_name - FunctionPass.__init__(self) - - def run_pass(self, state): - numba.mlir.func_registry.push_active_funcs_stack() - try: - res = self.run_pass_impl(state) - finally: - numba.mlir.func_registry.pop_active_funcs_stack() - return res - - def _resolve_func_name(self, obj): - name, func = self._resolve_func_name_impl(obj) - if not (name is None or func is None): - numba.mlir.func_registry.add_active_funcs(name, func) - return name - - def _resolve_func_name_impl(self, obj): - if isinstance(obj, types.Function): - func = obj.typing_key - return (self._get_func_name(func), None) - if isinstance(obj, types.BoundFunction): - return (str(obj.typing_key), None) - if isinstance(obj, numba.core.types.functions.Dispatcher): - func = obj.dispatcher.py_func - return (func.__module__ + "." + func.__qualname__, func) - return (None, None) - - def _get_func_context(self, state): - mangler = state.targetctx.mangler - mangler = default_mangler if mangler is None else mangler - unique_name = state.func_ir.func_id.unique_name - modname = state.func_ir.func_id.func.__module__ - from numba.core.funcdesc import qualifying_prefix - qualprefix = qualifying_prefix(modname, unique_name) - fn_name = mangler(qualprefix, state.args) - - from numba.np.ufunc.parallel import get_thread_count - - ctx = {} - ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} - ctx['typemap'] = lambda op: state.typemap[op.name] - ctx['fnargs'] = lambda: state.args - ctx['restype'] = lambda: state.return_type - ctx['fnname'] = lambda: fn_name - ctx['resolve_func'] = self._resolve_func_name - ctx['fastmath'] = lambda: state.targetctx.fastmath - ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 - return ctx - -@register_pass(mutates_CFG=True, analysis_only=False) -class MlirDumpPlier(MlirBackendBase): - - _name = "mlir_dump_plier" - - def __init__(self): - MlirBackendBase.__init__(self) - - def run_pass(self, state): - import mlir_compiler - module = mlir_compiler.create_module() - ctx = self._get_func_context(state) - mlir_compiler.lower_function(ctx, module, state.func_ir) - print(mlir_compiler.module_str(module)) - return True - -def get_mlir_func(): - global _mlir_last_compiled_func - return _mlir_last_compiled_func - -@register_pass(mutates_CFG=True, analysis_only=False) -class MlirBackend(MlirBackendBase): - - _name = "mlir_backend" - - def __init__(self): - MlirBackendBase.__init__(self) - - def run_pass_impl(self, state): - import mlir_compiler - global _mlir_active_module - old_module = _mlir_active_module - - try: - module = mlir_compiler.create_module() - _mlir_active_module = module - global _mlir_last_compiled_func - ctx = self._get_func_context(state) - _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) - mod_ir = mlir_compiler.compile_module(ctx, module) - finally: - _mlir_active_module = old_module - setattr(state, 'mlir_blob', mod_ir) - _reload_parfors() - state.reload_init.append(_reload_parfors) - return True - -@register_pass(mutates_CFG=True, analysis_only=False) -class MlirBackendInner(MlirBackendBase): - - _name = "mlir_backend_inner" - - def __init__(self): - MlirBackendBase.__init__(self) - - def run_pass_impl(self, state): - import mlir_compiler - global _mlir_active_module - module = _mlir_active_module - assert not module is None - global _mlir_last_compiled_func - ctx = self._get_func_context(state) - _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) - from numba.core.compiler import compile_result - state.cr = compile_result() - return True - - @register_pass(mutates_CFG=True, analysis_only=False) class InlineOverloads(FunctionPass): """ diff --git a/numba/mlir/inner_compiler.py b/numba/mlir/inner_compiler.py index de83883a705..b74987db8fe 100644 --- a/numba/mlir/inner_compiler.py +++ b/numba/mlir/inner_compiler.py @@ -1,8 +1,9 @@ -from numba.core.typed_passes import get_mlir_func, NopythonTypeInference, AnnotateTypes, MlirBackendInner +from numba.core.typed_passes import NopythonTypeInference, AnnotateTypes from numba.core.compiler import CompilerBase, DefaultPassBuilder, DEFAULT_FLAGS, compile_extra from numba.core.compiler_machinery import PassManager from numba.core import typing, cpu -# from numba import njit + +from numba.mlir.passes import MlirBackendInner, get_mlir_func class MlirTempCompiler(CompilerBase): # custom compiler extends from CompilerBase diff --git a/numba/mlir/passes.py b/numba/mlir/passes.py new file mode 100644 index 00000000000..c14652494ec --- /dev/null +++ b/numba/mlir/passes.py @@ -0,0 +1,136 @@ +from numba.core.compiler_machinery import (FunctionPass, register_pass) +from numba.core import (types) + +import numba.mlir.settings +import numba.mlir.func_registry +import numba.core.types.functions +_mlir_last_compiled_func = None +_mlir_active_module = None + +def _reload_parfors(): + """Reloader for cached parfors + """ + # Re-initialize the parallel backend when load from cache. + from numba.np.ufunc.parallel import _launch_threads + _launch_threads() + +class MlirBackendBase(FunctionPass): + + def __init__(self): + import numba.mlir.func_registry + self._get_func_name = numba.mlir.func_registry.get_func_name + FunctionPass.__init__(self) + + def run_pass(self, state): + numba.mlir.func_registry.push_active_funcs_stack() + try: + res = self.run_pass_impl(state) + finally: + numba.mlir.func_registry.pop_active_funcs_stack() + return res + + def _resolve_func_name(self, obj): + name, func = self._resolve_func_name_impl(obj) + if not (name is None or func is None): + numba.mlir.func_registry.add_active_funcs(name, func) + return name + + def _resolve_func_name_impl(self, obj): + if isinstance(obj, types.Function): + func = obj.typing_key + return (self._get_func_name(func), None) + if isinstance(obj, types.BoundFunction): + return (str(obj.typing_key), None) + if isinstance(obj, numba.core.types.functions.Dispatcher): + func = obj.dispatcher.py_func + return (func.__module__ + "." + func.__qualname__, func) + return (None, None) + + def _get_func_context(self, state): + mangler = state.targetctx.mangler + mangler = default_mangler if mangler is None else mangler + unique_name = state.func_ir.func_id.unique_name + modname = state.func_ir.func_id.func.__module__ + from numba.core.funcdesc import qualifying_prefix + qualprefix = qualifying_prefix(modname, unique_name) + fn_name = mangler(qualprefix, state.args) + + from numba.np.ufunc.parallel import get_thread_count + + ctx = {} + ctx['compiler_settings'] = {'verify': True, 'pass_statistics': False, 'pass_timings': False, 'ir_printing': numba.mlir.settings.PRINT_IR} + ctx['typemap'] = lambda op: state.typemap[op.name] + ctx['fnargs'] = lambda: state.args + ctx['restype'] = lambda: state.return_type + ctx['fnname'] = lambda: fn_name + ctx['resolve_func'] = self._resolve_func_name + ctx['fastmath'] = lambda: state.targetctx.fastmath + ctx['max_concurrency'] = lambda: get_thread_count() if state.flags.auto_parallel.enabled else 0 + return ctx + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirDumpPlier(MlirBackendBase): + + _name = "mlir_dump_plier" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass(self, state): + import mlir_compiler + module = mlir_compiler.create_module() + ctx = self._get_func_context(state) + mlir_compiler.lower_function(ctx, module, state.func_ir) + print(mlir_compiler.module_str(module)) + return True + +def get_mlir_func(): + global _mlir_last_compiled_func + return _mlir_last_compiled_func + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackend(MlirBackendBase): + + _name = "mlir_backend" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass_impl(self, state): + import mlir_compiler + global _mlir_active_module + old_module = _mlir_active_module + + try: + module = mlir_compiler.create_module() + _mlir_active_module = module + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + mod_ir = mlir_compiler.compile_module(ctx, module) + finally: + _mlir_active_module = old_module + setattr(state, 'mlir_blob', mod_ir) + _reload_parfors() + state.reload_init.append(_reload_parfors) + return True + +@register_pass(mutates_CFG=True, analysis_only=False) +class MlirBackendInner(MlirBackendBase): + + _name = "mlir_backend_inner" + + def __init__(self): + MlirBackendBase.__init__(self) + + def run_pass_impl(self, state): + import mlir_compiler + global _mlir_active_module + module = _mlir_active_module + assert not module is None + global _mlir_last_compiled_func + ctx = self._get_func_context(state) + _mlir_last_compiled_func = mlir_compiler.lower_function(ctx, module, state.func_ir) + from numba.core.compiler import compile_result + state.cr = compile_result() + return True