diff --git a/numba/core/cpu_dispatcher.py b/numba/core/cpu_dispatcher.py
index f6ea02b08a7..1b712ce2aff 100644
--- a/numba/core/cpu_dispatcher.py
+++ b/numba/core/cpu_dispatcher.py
@@ -6,22 +6,8 @@ class CPUDispatcher(dispatcher.Dispatcher):
     targetdescr = cpu_target

     def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
-        if ('parallel' in targetoptions and isinstance(targetoptions['parallel'], dict) and
-            'offload' in targetoptions['parallel'] and targetoptions['parallel']['offload'] == True):
-            import numba.dppl_config as dppl_config
-            if dppl_config.dppl_present:
-                from numba.dppl.compiler import DPPLCompiler
-                dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
-                    targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPLCompiler)
-            else:
-                print("---------------------------------------------------------------------------")
-                print("WARNING : offload=True option ignored. Ensure OpenCL drivers are installed.")
-                print("---------------------------------------------------------------------------")
-                dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
-                    targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)
-        else:
-            dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
-                targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)
+        dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
+            targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)


 dispatcher_registry['cpu'] = CPUDispatcher
diff --git a/numba/core/decorators.py b/numba/core/decorators.py
index f57c427994b..9ece7ce422b 100644
--- a/numba/core/decorators.py
+++ b/numba/core/decorators.py
@@ -11,6 +11,7 @@
 from numba.core.errors import DeprecationError, NumbaDeprecationWarning
 from numba.stencils.stencil import stencil
 from numba.core import config, sigutils, registry, cpu_dispatcher
+from numba.dppl import gpu_dispatcher


 _logger = logging.getLogger(__name__)
diff --git a/numba/core/transforms.py b/numba/core/transforms.py
index d01747884ac..2f49aef93ad 100644
--- a/numba/core/transforms.py
+++ b/numba/core/transforms.py
@@ -323,8 +323,9 @@ def with_lifting(func_ir, typingctx, targetctx, flags, locals):
     """
     from numba.core import postproc

-    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
+    def dispatcher_factory(func_ir, objectmode=False, dppl_mode=False, **kwargs):
         from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith
+        from numba.dppl.withcontexts import DPPLLiftedWith

         myflags = flags.copy()
         if objectmode:
@@ -335,6 +336,8 @@ def dispatcher_factory(func_ir, objectmode=False, **kwargs):
             myflags.force_pyobject = True
             myflags.no_cpython_wrapper = False
             cls = ObjModeLiftedWith
+        elif dppl_mode:
+            cls = DPPLLiftedWith
         else:
             cls = LiftedWith
         return cls(func_ir, typingctx, targetctx, myflags, locals, **kwargs)
diff --git a/numba/dppl/gpu_dispatcher.py b/numba/dppl/gpu_dispatcher.py
new file mode 100644
index 00000000000..fffd12c863c
--- /dev/null
+++ b/numba/dppl/gpu_dispatcher.py
@@ -0,0 +1,22 @@
+from numba.core import dispatcher, compiler
+from numba.core.registry import cpu_target, dispatcher_registry
+import numba.dppl_config as dppl_config
+from numba.dppl.compiler import DPPLCompiler
+
+
+class GPUDispatcher(dispatcher.Dispatcher):
+    targetdescr = cpu_target
+
+    def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
+        if dppl_config.dppl_present:
+            dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
+                targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPLCompiler)
+        else:
+            print("---------------------------------------------------------------------")
+            print("WARNING : DPPL pipeline ignored. Ensure OpenCL drivers are installed.")
+            print("---------------------------------------------------------------------")
+            dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
+                targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)
+
+
+dispatcher_registry['gpu'] = GPUDispatcher
\ No newline at end of file
diff --git a/numba/dppl/tests/dppl/test_with_semantics.py b/numba/dppl/tests/dppl/test_with_semantics.py
new file mode 100644
index 00000000000..675034dce43
--- /dev/null
+++ b/numba/dppl/tests/dppl/test_with_semantics.py
@@ -0,0 +1,120 @@
+from numba.dppl.testing import unittest
+from numba.dppl.testing import DPPLTestCase
+from numba.dppl.withcontexts import dppl_context
+from numba.core import typing, cpu
+from numba.core.compiler import compile_ir, DEFAULT_FLAGS
+from numba.core.transforms import with_lifting
+from numba.core.registry import cpu_target
+from numba.core.bytecode import FunctionIdentity, ByteCode
+from numba.core.interpreter import Interpreter
+from numba.tests.support import captured_stdout
+from numba import njit, prange
+import numpy as np
+
+
+def get_func_ir(func):
+    func_id = FunctionIdentity.from_function(func)
+    bc = ByteCode(func_id=func_id)
+    interp = Interpreter(func_id)
+    func_ir = interp.interpret(bc)
+    return func_ir
+
+
+def liftcall1():
+    x = 1
+    print("A", x)
+    with dppl_context:
+        x += 1
+    print("B", x)
+    return x
+
+
+def liftcall2():
+    x = 1
+    print("A", x)
+    with dppl_context:
+        x += 1
+    print("B", x)
+    with dppl_context:
+        x += 10
+    print("C", x)
+    return x
+
+
+def liftcall3():
+    x = 1
+    print("A", x)
+    with dppl_context:
+        if x > 0:
+            x += 1
+    print("B", x)
+    with dppl_context:
+        for i in range(10):
+            x += i
+    print("C", x)
+    return x
+
+
+class BaseTestWithLifting(DPPLTestCase):
+    def setUp(self):
+        super(BaseTestWithLifting, self).setUp()
+        self.typingctx = typing.Context()
+        self.targetctx = cpu.CPUContext(self.typingctx)
+        self.flags = DEFAULT_FLAGS
+
+    def check_extracted_with(self, func, expect_count, expected_stdout):
+        the_ir = get_func_ir(func)
+        new_ir, extracted = with_lifting(
+            the_ir, self.typingctx, self.targetctx, self.flags,
+            locals={},
+        )
+        self.assertEqual(len(extracted), expect_count)
+        cres = self.compile_ir(new_ir)
+
+        with captured_stdout() as out:
+            cres.entry_point()
+
+        self.assertEqual(out.getvalue(), expected_stdout)
+
+    def compile_ir(self, the_ir, args=(), return_type=None):
+        typingctx = self.typingctx
+        targetctx = self.targetctx
+        flags = self.flags
+        # Register the contexts in case for nested @jit or @overload calls
+        with cpu_target.nested_context(typingctx, targetctx):
+            return compile_ir(typingctx, targetctx, the_ir, args,
+                              return_type, flags, locals={})
+
+
+class TestLiftCall(BaseTestWithLifting):
+
+    def check_same_semantic(self, func):
+        """Ensure same semantic with non-jitted code
+        """
+        jitted = njit(target="gpu")(func)
+        with captured_stdout() as got:
+            jitted()
+
+        with captured_stdout() as expect:
+            func()
+
+        self.assertEqual(got.getvalue(), expect.getvalue())
+
+    def test_liftcall1(self):
+        self.check_extracted_with(liftcall1, expect_count=1,
+                                  expected_stdout="A 1\nB 2\n")
+        self.check_same_semantic(liftcall1)
+
+    def test_liftcall2(self):
+        self.check_extracted_with(liftcall2, expect_count=2,
+                                  expected_stdout="A 1\nB 2\nC 12\n")
+        self.check_same_semantic(liftcall2)
+
+    def test_liftcall3(self):
+        self.check_extracted_with(liftcall3, expect_count=2,
+                                  expected_stdout="A 1\nB 2\nC 47\n")
+        self.check_same_semantic(liftcall3)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/numba/dppl/withcontexts.py b/numba/dppl/withcontexts.py
new file mode 100644
index 00000000000..9274e9960d3
--- /dev/null
+++ b/numba/dppl/withcontexts.py
@@ -0,0 +1,144 @@
+from numba.core import compiler, typing, types, sigutils
+from numba.core.compiler_lock import global_compiler_lock
+from numba.core.dispatcher import _DispatcherBase
+from numba.core.transforms import find_region_inout_vars
+from numba.core.withcontexts import (WithContext, _mutate_with_block_callee, _mutate_with_block_caller,
+                                     _clear_blocks)
+from numba.dppl.compiler import DPPLCompiler
+from numba.core.cpu_options import ParallelOptions
+
+
+class _DPPLContextType(WithContext):
+    def mutate_with_body(self, func_ir, blocks, blk_start, blk_end,
+                         body_blocks, dispatcher_factory, extra):
+        assert extra is None
+        vlt = func_ir.variable_lifetime
+
+        inputs, outputs = find_region_inout_vars(
+            blocks=blocks,
+            livemap=vlt.livemap,
+            callfrom=blk_start,
+            returnto=blk_end,
+            body_block_ids=set(body_blocks),
+        )
+
+        lifted_blks = {k: blocks[k] for k in body_blocks}
+        _mutate_with_block_callee(lifted_blks, blk_start, blk_end,
+                                  inputs, outputs)
+
+        # XXX: transform body-blocks to return the output variables
+        lifted_ir = func_ir.derive(
+            blocks=lifted_blks,
+            arg_names=tuple(inputs),
+            arg_count=len(inputs),
+            force_non_generator=True,
+        )
+
+        dispatcher = dispatcher_factory(lifted_ir, dppl_mode=True)
+
+        newblk = _mutate_with_block_caller(
+            dispatcher, blocks, blk_start, blk_end, inputs, outputs,
+        )
+
+        blocks[blk_start] = newblk
+        _clear_blocks(blocks, body_blocks)
+        return dispatcher
+
+
+class DPPLLiftedCode(_DispatcherBase):
+    """
+    Implementation of the hidden dispatcher objects used for lifted code
+    (a lifted loop is really compiled as a separate function).
+    """
+    _fold_args = False
+
+    def __init__(self, func_ir, typingctx, targetctx, flags, locals):
+        self.func_ir = func_ir
+        self.lifted_from = None
+
+        self.typingctx = typingctx
+        self.targetctx = targetctx
+        self.flags = flags
+        self.locals = locals
+
+        _DispatcherBase.__init__(self, self.func_ir.arg_count,
+                                 self.func_ir.func_id.func,
+                                 self.func_ir.func_id.pysig,
+                                 can_fallback=True,
+                                 exact_match_required=False)
+
+    def get_source_location(self):
+        """Return the starting line number of the loop.
+        """
+        return self.func_ir.loc.line
+
+    def _pre_compile(self, args, return_type, flags):
+        """Pre-compile actions
+        """
+        pass
+
+    @global_compiler_lock
+    def compile(self, sig):
+        # Use counter to track recursion compilation depth
+        with self._compiling_counter:
+            # XXX this is mostly duplicated from Dispatcher.
+            flags = self.flags
+            args, return_type = sigutils.normalize_signature(sig)
+
+            # Don't recompile if signature already exists
+            # (e.g. if another thread compiled it before we got the lock)
+            existing = self.overloads.get(tuple(args))
+            if existing is not None:
+                return existing.entry_point
+
+            self._pre_compile(args, return_type, flags)
+
+            # Clone IR to avoid (some of the) mutation in the rewrite pass
+            cloned_func_ir = self.func_ir.copy()
+
+            flags.auto_parallel = ParallelOptions({'offload':True})
+            cres = compiler.compile_ir(typingctx=self.typingctx,
+                                       targetctx=self.targetctx,
+                                       func_ir=cloned_func_ir,
+                                       args=args, return_type=return_type,
+                                       flags=flags, locals=self.locals,
+                                       lifted=(),
+                                       lifted_from=self.lifted_from,
+                                       is_lifted_loop=True,
+                                       pipeline_class=DPPLCompiler)
+
+            # Check typing error if object mode is used
+            if cres.typing_error is not None and not flags.enable_pyobject:
+                raise cres.typing_error
+            self.add_overload(cres)
+            return cres.entry_point
+
+
+class DPPLLiftedWith(DPPLLiftedCode):
+    @property
+    def _numba_type_(self):
+        return types.Dispatcher(self)
+
+    def get_call_template(self, args, kws):
+        """
+        Get a typing.ConcreteTemplate for this dispatcher and the given
+        *args* and *kws* types. This enables the resolving of the return type.
+
+        A (template, pysig, args, kws) tuple is returned.
+        """
+        # Ensure an overload is available
+        if self._can_compile:
+            self.compile(tuple(args))
+
+        pysig = None
+        # Create function type for typing
+        func_name = self.py_func.__name__
+        name = "CallTemplate({0})".format(func_name)
+        # The `key` isn't really used except for diagnosis here,
+        # so avoid keeping a reference to `cfunc`.
+        call_template = typing.make_concrete_template(
+            name, key=func_name, signatures=self.nopython_signatures)
+        return call_template, pysig, args, kws
+
+
+dppl_context = _DPPLContextType()
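
Usage sketch. The new `dppl_context` with-block is extracted by `with_lifting` (through `dispatcher_factory(..., dppl_mode=True)`) and compiled separately by `DPPLLiftedWith`, which forces `ParallelOptions({'offload': True})` and the `DPPLCompiler` pipeline; code outside the block goes through the dispatcher's normal pipeline. The sketch below is modelled on `liftcall1` and `check_same_semantic` from the new test file; it assumes a numba.dppl build with `dppl_config.dppl_present` True (working OpenCL drivers), otherwise `GPUDispatcher` falls back to the default pipeline with the warning shown above.

    from numba import njit
    from numba.dppl.withcontexts import dppl_context

    @njit(target="gpu")          # routed to numba.dppl.gpu_dispatcher.GPUDispatcher
    def f():
        x = 1
        with dppl_context:       # this block is lifted and compiled by DPPLCompiler
            x += 1               # runs through the offload pipeline
        return x                 # the rest of the function uses the regular pipeline

    assert f() == 2              # same result as the uncompiled Python function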
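For comparison, the removed `CPUDispatcher` branch selected the DPPL pipeline by inspecting the `parallel` target option, while this patch registers `GPUDispatcher` under the `'gpu'` target key instead. A minimal before/after sketch, inferred from the removed targetoptions check and from the test's use of `njit(target="gpu")`; the exact decorator spellings are illustrative assumptions, not taken verbatim from this patch.

    from numba import njit

    # Before: offload requested via the 'parallel' target option, which
    # CPUDispatcher checked for a {'offload': True} dict.
    @njit(parallel={'offload': True})
    def f_old(a):
        return a + 1

    # After: the DPPL pipeline is reached through the 'gpu' target key,
    # handled by GPUDispatcher from the dispatcher registry.
    @njit(target="gpu")
    def f_new(a):
        return a + 1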