diff --git a/dace/codegen/CMakeLists.txt b/dace/codegen/CMakeLists.txt index 5482d4d30d..7d1ca4d714 100644 --- a/dace/codegen/CMakeLists.txt +++ b/dace/codegen/CMakeLists.txt @@ -141,7 +141,7 @@ if(DACE_ENABLE_CUDA) set(CMAKE_CUDA_ARCHITECTURES "${LOCAL_CUDA_ARCHITECTURES}") enable_language(CUDA) - list(APPEND DACE_LIBS CUDA::cudart) + list(APPEND DACE_LIBS CUDA::cudart CUDA::nvtx3) add_definitions(-DWITH_CUDA) if (MSVC_IDE) diff --git a/dace/codegen/targets/cpp.py b/dace/codegen/targets/cpp.py index 5c4d04c0a7..38e16e72d6 100644 --- a/dace/codegen/targets/cpp.py +++ b/dace/codegen/targets/cpp.py @@ -257,7 +257,7 @@ def ptr(name: str, desc: data.Data, sdfg: SDFG = None, framecode=None) -> str: if desc.storage == dtypes.StorageType.CPU_ThreadLocal: # Use unambiguous name for thread-local arrays return f'__{sdfg.cfg_id}_{name}' - elif not CUDACodeGen._in_device_code: # GPU kernels cannot access state + elif not CUDACodeGen._in_device_code.get(): # GPU kernels cannot access state return f'__state->__{sdfg.cfg_id}_{name}' elif (sdfg, name) in framecode.where_allocated and framecode.where_allocated[(sdfg, name)] is not sdfg: return f'__{sdfg.cfg_id}_{name}' diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index aaba068da3..cbff4954af 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -1,4 +1,5 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +import contextvars import ctypes import functools import warnings @@ -60,7 +61,7 @@ class CUDACodeGen(TargetCodeGenerator): """ GPU (CUDA/HIP) code generator. """ target_name = 'cuda' title = 'CUDA' - _in_device_code = False + _in_device_code = contextvars.ContextVar('_in_device_code') def __init__(self, frame_codegen: 'DaCeCodeGenerator', sdfg: SDFG): self._frame = frame_codegen @@ -70,7 +71,7 @@ def __init__(self, frame_codegen: 'DaCeCodeGenerator', sdfg: SDFG): self.create_grid_barrier = False self.dynamic_tbmap_type = None self.extra_nsdfg_args = [] - CUDACodeGen._in_device_code = False + CUDACodeGen._in_device_code.set(False) self._cpu_codegen: Optional['CPUCodeGen'] = None self._block_dims = None self._grid_dims = None @@ -161,35 +162,40 @@ def preprocess(self, sdfg: SDFG) -> None: nsdfg = state.parent if (e.src.desc(nsdfg).storage == dtypes.StorageType.GPU_Global and e.dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global): + + # NOTE: If possible `memlet_copy_to_absolute_strides()` will collapse a + # ND copy into a 1D copy if the memory is contiguous. In that case + # `copy_shape` will only have one element. copy_shape, src_strides, dst_strides, _, _ = memlet_copy_to_absolute_strides( None, nsdfg, state, e, e.src, e.dst) dims = len(copy_shape) # Skip supported copy types if dims == 1: + # NOTE: We do not check if the stride is `1`. See `_emit_copy()` for more. continue elif dims == 2: - if src_strides[-1] != 1 or dst_strides[-1] != 1: - # NOTE: Special case of continuous copy - # Example: dcol[0:I, 0:J, k] -> datacol[0:I, 0:J] - # with copy shape [I, J] and strides [J*K, K], [J, 1] - try: - is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1] - is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1] - except (TypeError, ValueError): - is_src_cont = False - is_dst_cont = False - if is_src_cont and is_dst_cont: - continue - else: + # Because `memlet_copy_to_absolute_strides()` handles contiguous copies + # transparently, we only have to check if we have FORTRAN or C order. + # If we do not have them, then we have to turn this into a Map. 
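+                        # Illustrative (hypothetical) strides: for a copy shape [N, M], strides of
+                        # [M, 1] on both sides mean C order and [1, N] on both sides mean FORTRAN
+                        # order; a mixed pairing such as [M, 1] -> [1, N] matches neither test and
+                        # falls through to the CopyToMap path below.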
+                        is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1
+                        is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
+                        if is_c_order or is_fortran_order:
                             continue
                     elif dims > 2:
-                        if not (src_strides[-1] != 1 or dst_strides[-1] != 1):
+                        # Any higher-dimensional copy must be in C order. If it is not,
+                        # turn it into a copy Map.
+                        if src_strides[-1] == 1 and dst_strides[-1] == 1:
                             continue
 
                     # Turn unsupported copy to a map
                     try:
-                        CopyToMap.apply_to(nsdfg, save=False, annotate=False, a=e.src, b=e.dst)
+                        CopyToMap.apply_to(nsdfg,
+                                           save=False,
+                                           annotate=False,
+                                           a=e.src,
+                                           b=e.dst,
+                                           options={"ignore_strides": True})
                     except ValueError:  # If transformation doesn't match, continue normally
                         continue
@@ -449,7 +455,7 @@ def node_dispatch_predicate(self, sdfg, state, node):
         if hasattr(node, 'schedule'):  # NOTE: Works on nodes and scopes
             if node.schedule in dtypes.GPU_SCHEDULES:
                 return True
-        if CUDACodeGen._in_device_code:
+        if CUDACodeGen._in_device_code.get():
             return True
         return False
@@ -916,7 +922,7 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St
             raise LookupError('Memlet does not point to any of the nodes')
 
         if (isinstance(src_node, nodes.AccessNode) and isinstance(dst_node, nodes.AccessNode)
-                and not CUDACodeGen._in_device_code
+                and not CUDACodeGen._in_device_code.get()
                 and (src_storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned]
                      or dst_storage in [dtypes.StorageType.GPU_Global, dtypes.StorageType.CPU_Pinned])
                 and not (src_storage in cpu_storage_types and dst_storage in cpu_storage_types)):
@@ -973,32 +979,21 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St
            copy_shape, src_strides, dst_strides, src_expr, dst_expr = (memlet_copy_to_absolute_strides(
                self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, self._cpu_codegen._packed_types))
            dims = len(copy_shape)
-            dtype = dst_node.desc(sdfg).dtype
 
-            # Handle unsupported copy types
-            if dims == 2 and (src_strides[-1] != 1 or dst_strides[-1] != 1):
-                # NOTE: Special case of continuous copy
-                # Example: dcol[0:I, 0:J, k] -> datacol[0:I, 0:J]
-                # with copy shape [I, J] and strides [J*K, K], [J, 1]
-                try:
-                    is_src_cont = src_strides[0] / src_strides[1] == copy_shape[1]
-                    is_dst_cont = dst_strides[0] / dst_strides[1] == copy_shape[1]
-                except (TypeError, ValueError):
-                    is_src_cont = False
-                    is_dst_cont = False
-                if is_src_cont and is_dst_cont:
-                    dims = 1
-                    copy_shape = [copy_shape[0] * copy_shape[1]]
-                    src_strides = [src_strides[1]]
-                    dst_strides = [dst_strides[1]]
-                else:
-                    raise NotImplementedError('2D copy only supported with one stride')
+            # In 1D there is no difference between FORTRAN and C order, thus both flags receive
+            # the same value. The value indicates whether the stride is `1`.
+            # TODO: Figure out whether this is enough for views.
+            is_fortran_order = src_strides[0] == 1 and dst_strides[0] == 1
+            is_c_order = src_strides[-1] == 1 and dst_strides[-1] == 1
 
-            # Currently we only support ND copies when they can be represented
-            # as a 1D copy or as a 2D strided copy
            if dims > 2:
-                if src_strides[-1] != 1 or dst_strides[-1] != 1:
+                # Currently we only support ND copies when they can be represented
+                # as a 1D copy or as a 2D strided copy
+                # NOTE: It is unclear whether this test is sufficient; it should probably also be
+                # checked that the strides are ordered, i.e. the largest stride is on the left.
+                if not is_c_order:
+                    # TODO: Implement the FORTRAN case. 
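+                    # NOTE (sketch, not part of the current implementation): a FORTRAN-order ND copy
+                    # could in principle be reduced to the C-order case by reversing `copy_shape`,
+                    # `src_strides` and `dst_strides` before emitting the copy; for now this is left
+                    # as the TODO above.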
raise NotImplementedError( 'GPU copies are not supported for N-dimensions if they cannot be represented by a strided copy\n' f' Nodes: src {src_node} ({src_storage}), dst {dst_node}({dst_storage})\n' @@ -1026,7 +1021,8 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St for d in range(dims - 2): callsite_stream.write("}") - if dims == 1 and not (src_strides[-1] != 1 or dst_strides[-1] != 1): + elif dims == 1 and is_c_order: + # A 1D copy, in which the stride is 1, known at code generation time. copysize = ' * '.join(_topy(copy_shape)) array_length = copysize copysize += ' * sizeof(%s)' % dtype.ctype @@ -1064,22 +1060,70 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St backend=self.backend), cfg, state_id, [src_node, dst_node]) callsite_stream.write('}') - elif dims == 1 and ((src_strides[-1] != 1 or dst_strides[-1] != 1)): + + elif dims == 1 and not is_c_order: + # This is the case that generated for expressions such as `A[::3]`, we reduce it + # to a 2D copy. + callsite_stream.write( + 'DACE_GPU_CHECK({backend}Memcpy2DAsync({dst}, {dst_stride}, {src}, {src_stride}, {width}, {height}, {kind}, {stream}));\n' + .format( + backend=self.backend, + dst=dst_expr, + dst_stride=f'({_topy(dst_strides[0])}) * sizeof({dst_node.desc(sdfg).dtype.ctype})', + src=src_expr, + src_stride=f'({sym2cpp(src_strides[0])}) * sizeof({src_node.desc(sdfg).dtype.ctype})', + width=f'sizeof({dst_node.desc(sdfg).dtype.ctype})', + height=sym2cpp(copy_shape[0]), + kind=f'{self.backend}Memcpy{src_location}To{dst_location}', + stream=cudastream, + ), + cfg, + state_id, + [src_node, dst_node], + ) + + elif dims == 2 and is_c_order: + # Copying a 2D array that are in C order, i.e. last stride is 1. callsite_stream.write( - 'DACE_GPU_CHECK(%sMemcpy2DAsync(%s, %s, %s, %s, %s, %s, %sMemcpy%sTo%s, %s));\n' % - (self.backend, dst_expr, _topy(dst_strides[0]) + ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, - src_expr, sym2cpp(src_strides[0]) + ' * sizeof(%s)' % src_node.desc(sdfg).dtype.ctype, - 'sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, sym2cpp( - copy_shape[0]), self.backend, src_location, dst_location, cudastream), cfg, state_id, - [src_node, dst_node]) - elif dims == 2: + 'DACE_GPU_CHECK({backend}Memcpy2DAsync({dst}, {dst_stride}, {src}, {src_stride}, {width}, {height}, {kind}, {stream}));\n' + .format( + backend=self.backend, + dst=dst_expr, + dst_stride=f'({_topy(dst_strides[0])}) * sizeof({dst_node.desc(sdfg).dtype.ctype})', + src=src_expr, + src_stride=f'({sym2cpp(src_strides[0])}) * sizeof({src_node.desc(sdfg).dtype.ctype})', + width=f'({sym2cpp(copy_shape[1])}) * sizeof({dst_node.desc(sdfg).dtype.ctype})', + height=sym2cpp(copy_shape[0]), + kind=f'{self.backend}Memcpy{src_location}To{dst_location}', + stream=cudastream, + ), + cfg, + state_id, + [src_node, dst_node], + ) + elif dims == 2 and is_fortran_order: + # Copying a 2D array into a 2D array that is in FORTRAN order, i.e. first stride + # is one. The CUDA API can not handle such cases directly, however, by "transposing" + # it is possible to use `Memcyp2DAsync`. 
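+                # How the "transpose" maps onto the Memcpy2D arguments below (author's reading of
+                # this code, not a statement from the CUDA documentation): the pitch is taken from
+                # strides[1], the row width from copy_shape[0] and the row count from copy_shape[1],
+                # i.e. rows and columns swap roles compared to the C-order branch above.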
callsite_stream.write( - 'DACE_GPU_CHECK(%sMemcpy2DAsync(%s, %s, %s, %s, %s, %s, %sMemcpy%sTo%s, %s));\n' % - (self.backend, dst_expr, _topy(dst_strides[0]) + ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, - src_expr, sym2cpp(src_strides[0]) + ' * sizeof(%s)' % src_node.desc(sdfg).dtype.ctype, - sym2cpp(copy_shape[1]) + ' * sizeof(%s)' % dst_node.desc(sdfg).dtype.ctype, sym2cpp( - copy_shape[0]), self.backend, src_location, dst_location, cudastream), cfg, state_id, - [src_node, dst_node]) + 'DACE_GPU_CHECK({backend}Memcpy2DAsync({dst}, {dst_stride}, {src}, {src_stride}, {width}, {height}, {kind}, {stream}));\n' + .format( + backend=self.backend, + dst=dst_expr, + dst_stride=f'({_topy(dst_strides[1])}) * sizeof({dst_node.desc(sdfg).dtype.ctype})', + src=src_expr, + src_stride=f'({sym2cpp(src_strides[1])}) * sizeof({src_node.desc(sdfg).dtype.ctype})', + width=f'({sym2cpp(copy_shape[0])}) * sizeof({dst_node.desc(sdfg).dtype.ctype})', + height=sym2cpp(copy_shape[1]), + kind=f'{self.backend}Memcpy{src_location}To{dst_location}', + stream=cudastream, + ), + cfg, + state_id, + [src_node, dst_node], + ) + else: + raise NotImplementedError("The requested copy operation is not implemented.") # Post-copy synchronization if is_sync: @@ -1126,7 +1170,6 @@ def _emit_copy(self, state_id: int, src_node: nodes.Node, src_storage: dtypes.St # Obtain copy information copy_shape, src_strides, dst_strides, src_expr, dst_expr = (memlet_copy_to_absolute_strides( self._dispatcher, sdfg, state, edge, src_node, dst_node, self._cpu_codegen._packed_types)) - dims = len(copy_shape) funcname = 'dace::%sTo%s%dD' % (_get_storagename(src_storage), _get_storagename(dst_storage), dims) @@ -1242,7 +1285,7 @@ def generate_state(self, callsite_stream: CodeIOStream, generate_state_footer: bool = False) -> None: # Two modes: device-level state and if this state has active streams - if CUDACodeGen._in_device_code: + if CUDACodeGen._in_device_code.get(): self.generate_devicelevel_state(sdfg, cfg, state, function_stream, callsite_stream) else: # Active streams found. 
Generate state normally and sync with the @@ -1467,10 +1510,9 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub outer_name = cpp.ptr(node.data, desc, nsdfg, self._frame) # Create name from within kernel - oldval = CUDACodeGen._in_device_code - CUDACodeGen._in_device_code = True + token = CUDACodeGen._in_device_code.set(True) inner_name = cpp.ptr(node.data, desc, nsdfg, self._frame) - CUDACodeGen._in_device_code = oldval + CUDACodeGen._in_device_code.reset(token) self.extra_nsdfg_args.append((desc.as_arg(name=''), inner_name, outer_name)) self._dispatcher.defined_vars.add(inner_name, @@ -1530,9 +1572,9 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub if not defined_type: defined_type, ctype = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) - CUDACodeGen._in_device_code = True + token = CUDACodeGen._in_device_code.set(True) inner_ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) - CUDACodeGen._in_device_code = False + CUDACodeGen._in_device_code.reset(token) self._dispatcher.defined_vars.add(inner_ptrname, defined_type, @@ -1549,9 +1591,9 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub dtypes.AllocationLifetime.Persistent, dtypes.AllocationLifetime.External) defined_type, ctype = self._dispatcher.defined_vars.get(ptrname, is_global=is_global) - CUDACodeGen._in_device_code = True + token = CUDACodeGen._in_device_code.set(True) inner_ptrname = cpp.ptr(aname, data_desc, sdfg, self._frame) - CUDACodeGen._in_device_code = False + CUDACodeGen._in_device_code.reset(token) self._dispatcher.defined_vars.add(inner_ptrname, defined_type, ctype, allow_shadowing=True) # Rename argument in kernel prototype as necessary @@ -2059,8 +2101,8 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, tidtype.ctype) # Dispatch internal code - assert CUDACodeGen._in_device_code is False - CUDACodeGen._in_device_code = True + assert CUDACodeGen._in_device_code.get() is False + CUDACodeGen._in_device_code.set(True) self._kernel_map = node self._kernel_state = cfg.node(state_id) self._block_dims = block_dims @@ -2113,7 +2155,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S self._block_dims = None self._kernel_map = None self._kernel_state = None - CUDACodeGen._in_device_code = False + CUDACodeGen._in_device_code.set(False) self._grid_dims = None self.dynamic_tbmap_type = None @@ -2138,7 +2180,7 @@ def get_next_scope_entries(self, dfg, scope_entry): def generate_devicelevel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSubgraphView, state_id: int, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: # Sanity check - assert CUDACodeGen._in_device_code == True + assert CUDACodeGen._in_device_code.get() == True dfg = cfg.state(state_id) scope_entry = dfg_scope.source_nodes()[0] @@ -2568,14 +2610,14 @@ def generate_node(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphVi gen(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) return - if not CUDACodeGen._in_device_code: + if not CUDACodeGen._in_device_code.get(): self._cpu_codegen.generate_node(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) return if isinstance(node, nodes.ExitNode): self._locals.clear_scope(self._code_state.indentation + 1) - if CUDACodeGen._in_device_code and isinstance(node, nodes.MapExit): + if 
CUDACodeGen._in_device_code.get() and isinstance(node, nodes.MapExit): return # skip self._cpu_codegen.generate_node(sdfg, cfg, dfg, state_id, node, function_stream, callsite_stream) @@ -2697,7 +2739,7 @@ def _generate_condition_from_location(self, name: str, index_expr: str, node: no def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphView, state_id: int, node: nodes.Tasklet, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: generated_preamble_scopes = 0 - if self._in_device_code: + if self._in_device_code.get(): # If location dictionary prescribes that the code should run on a certain group of threads/blocks, # add condition generated_preamble_scopes += self._generate_condition_from_location('gpu_thread', self._get_thread_id(), diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 33bc562f73..7a62170c3d 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -140,6 +140,8 @@ def generate_fileheader(self, sdfg: SDFG, global_stream: CodeIOStream, backend: if backend == 'frame': global_stream.write('#include "../../include/hash.h"\n', sdfg) + global_stream.write('#ifdef WITH_CUDA\n#include \n#endif\n', sdfg) + ######################################################### # Environment-based includes for env in self.environments: @@ -266,7 +268,13 @@ def generate_footer(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stre f''' DACE_EXPORTED void __program_{fname}({mangle_dace_state_struct_name(fname)} *__state{params_comma}) {{ + #ifdef WITH_CUDA + nvtxRangePushA("{fname}"); + #endif __program_{fname}_internal(__state{paramnames_comma}); + #ifdef WITH_CUDA + nvtxRangePop(); + #endif }}''', sdfg) for target in self._dispatcher.used_targets: diff --git a/dace/config_schema.yml b/dace/config_schema.yml index b5a7914018..189931ff3a 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -303,7 +303,7 @@ required: type: str title: nvcc Arguments description: Compiler argument flags for CUDA - default: '-Xcompiler -march=native --use_fast_math -Xcompiler -Wno-unused-parameter' + default: '--generate-line-info -Xcompiler -march=native -Xcompiler -Wno-unused-parameter' default_Windows: '-O3 --use_fast_math' hip_args: diff --git a/dace/data.py b/dace/data.py index 74c1e8b985..3279aff63b 100644 --- a/dace/data.py +++ b/dace/data.py @@ -210,6 +210,8 @@ def _validate(self): if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.shape): raise TypeError('Shape must be a list or tuple of integer values ' 'or symbols') + if any((shp < 0) == True for shp in self.shape): + raise TypeError(f'Found negative shape in Data, its shape was {self.shape}') return True def to_json(self): @@ -1471,12 +1473,20 @@ def validate(self): super(Array, self).validate() if len(self.strides) != len(self.shape): raise TypeError('Strides must be the same size as shape') + if len(self.offset) != len(self.shape): + raise TypeError('Offset must be the same size as shape') if any(not isinstance(s, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) for s in self.strides): raise TypeError('Strides must be a list or tuple of integer values or symbols') - - if len(self.offset) != len(self.shape): - raise TypeError('Offset must be the same size as shape') + if any(not isinstance(off, (int, symbolic.SymExpr, symbolic.symbol, symbolic.sympy.Basic)) + for off in self.offset): + raise TypeError('Offset must be a list or tuple of integer values or 
symbols') + + # Actually it would be enough to only enforce the non negativity only if the shape is larger than one. + if any((stride < 0) == True for stride in self.strides): + raise TypeError(f'Found negative strides in array, they were {self.strides}') + if (self.total_size < 0) == True: + raise TypeError(f'The total size of an array must be positive but it was negative {self.total_size}') def covers_range(self, rng): if len(rng) != len(self.shape): diff --git a/dace/memlet.py b/dace/memlet.py index 46dac51edf..090a7890fa 100644 --- a/dace/memlet.py +++ b/dace/memlet.py @@ -534,6 +534,9 @@ def dst_subset(self, new_dst_subset): def validate(self, sdfg, state): if self.data is not None and self.data not in sdfg.arrays: raise KeyError('Array "%s" not found in SDFG' % self.data) + # NOTE: We do not check here is the subsets have a negative size, because such as subset + # is valid, in certain cases, for example if an AccessNode is connected to a MapEntry, + # because the Map is not executed. Thus we do the check in the `validate_state()` function. def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]: """ diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index ccfb0adada..f501697b57 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -9,7 +9,7 @@ import networkx as nx -from dace import dtypes, subsets, symbolic +from dace import dtypes, subsets, symbolic, data from dace.dtypes import DebugInfo if TYPE_CHECKING: @@ -656,7 +656,6 @@ def validate_state(state: 'dace.sdfg.SDFGState', ) ######################################## - # Memlet checks for eid, e in enumerate(state.edges()): # Reference check if id(e) in references: @@ -680,6 +679,27 @@ def validate_state(state: 'dace.sdfg.SDFGState', except Exception as ex: raise InvalidSDFGEdgeError("Edge validation failed: " + str(ex), sdfg, state_id, eid) + # If the edge is a connection between two AccessNodes check if the subset has negative size. + # NOTE: We _should_ do this check in `Memlet.validate()` however, this is not possible, + # because the connection between am AccessNode and a MapEntry, with a negative size, is + # legal because, the Map will not run in that case. However, this constellation can not + # be tested for in the Memlet's validation function, so we have to do it here. + # NOTE: Zero size is explicitly allowed because it is essentially `memcpy(dst, src, 0)` + # which is save. + # TODO: The AN to AN connection is the most obvious one, but it should be extended. + if isinstance(e.src, nd.AccessNode) and isinstance(e.dst, nd.AccessNode): + e_memlet: dace.Memlet = e.data + if e_memlet.subset is not None: + if any((ss < 0) == True for ss in e_memlet.subset.size()): + raise InvalidSDFGEdgeError( + f'`subset` of an AccessNode to AccessNode Memlet contains a negative size; the size was {e_memlet.subset.size()}', + sdfg, state_id, eid) + if e_memlet.other_subset is not None: + if any((ss < 0) == True for ss in e_memlet.other_subset.size()): + raise InvalidSDFGEdgeError( + f'`other_subset` of an AccessNode to AccessNode Memlet contains a negative size; the size was {e_memlet.other_subset.size()}', + sdfg, state_id, eid) + # For every memlet, obtain its full path in the DFG path = state.memlet_path(e) src_node = path[0].src diff --git a/tests/codegen/cuda_memcopy_test.py b/tests/codegen/cuda_memcopy_test.py index 36c5d19f7a..34853f0adb 100644 --- a/tests/codegen/cuda_memcopy_test.py +++ b/tests/codegen/cuda_memcopy_test.py @@ -1,8 +1,10 @@ """ Tests code generation for array copy on GPU target. 
""" import dace from dace.transformation.auto import auto_optimize +from dace.sdfg import nodes as dace_nodes import pytest +import copy import re # this test requires cupy module @@ -12,6 +14,237 @@ rng = cp.random.default_rng(42) +def count_node(sdfg: dace.SDFG, node_type): + nb_nodes = 0 + for rsdfg in sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, node_type): + nb_nodes += 1 + return nb_nodes + + +def _make_2d_gpu_copy_sdfg(c_order: bool, ) -> dace.SDFG: + """The SDFG performs a copy from the input of the output, that is continuous. + + Essentially the function will generate am SDFG that performs the following + operation: + ```python + B[2:7, 3:9] = A[1:6, 2:8] + ``` + However, two arrays have a shape of `(20, 30)`. This means that this copy + can not be expressed as a continuous copy. Regardless which memory order + that is used, which can be selected by `c_order`. + """ + sdfg = dace.SDFG(f'gpu_2d_copy_{"corder" if c_order else "forder"}_copy_sdfg') + state = sdfg.add_state(is_start_block=True) + + for aname in 'AB': + sdfg.add_array( + name=aname, + shape=(20, 30), + dtype=dace.float64, + storage=dace.StorageType.GPU_Global, + transient=False, + strides=((30, 1) if c_order else (1, 20)), + ) + + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet("A[1:6, 2:8] -> [2:7, 3:9]"), + ) + sdfg.validate() + + return sdfg + + +def _perform_2d_gpu_copy_test(c_order: bool, ): + """Check 2D strided copies are handled by the `Memcpy2D` family. + """ + sdfg = _make_2d_gpu_copy_sdfg(c_order=c_order) + assert count_node(sdfg, dace_nodes.AccessNode) == 2 + assert count_node(sdfg, dace_nodes.MapEntry) == 0 + + # Now generate the code. + csdfg = sdfg.compile() + + # Ensure that the copy was not turned into a Map + assert count_node(csdfg.sdfg, dace_nodes.AccessNode) == 2 + assert count_node(csdfg.sdfg, dace_nodes.MapEntry) == 0 + + # Ensure that the correct call was issued. + # We have to look at the CPU code and not at the GPU. + code = sdfg.generate_code()[0].clean_code + m = re.search(r'(cuda|hip)Memcpy2DAsync\b', code) + assert m is not None + + # Generate input data. + ref = { + "A": cp.array(cp.random.rand(20, 30), dtype=cp.float64, order="C" if c_order else "F"), + "B": cp.array(cp.random.rand(20, 30), dtype=cp.float64, order="C" if c_order else "F"), + } + + # We can not use `deepcopy` or `.copy()` because this would set the strides to `C` order. 
+ res = {} + for name in ref.keys(): + res[name] = cp.empty_like(ref[name]) + res[name][:] = ref[name][:] + + exp_strides = (240, 8) if c_order else (8, 160) + assert all(v.strides == exp_strides for v in ref.values()) + assert all(v.strides == exp_strides for v in res.values()) + + # Now apply the operation on the reference + ref["B"][2:7, 3:9] = ref["A"][1:6, 2:8] + + # Now run the SDFG + csdfg(**res) + + assert all(cp.all(ref[k] == res[k]) for k in ref.keys()) + + +def _make_1d_gpu_copy( + src_row: bool, + dst_row: bool, +) -> dace.SDFG: + sdfg = dace.SDFG(f'gpu_1d_copy_{"row" if src_row else "col"}_{"row" if src_row else "col"}_copy_sdfg') + state = sdfg.add_state(is_start_block=True) + + for aname in 'AB': + sdfg.add_array( + name=aname, + shape=(20, 20), + dtype=dace.float64, + storage=dace.StorageType.GPU_Global, + transient=False, + ) + + src_subset = "1, 1:9" if src_row else "1:9, 2" + dst_subset = "3, 0:8" if dst_row else "0:8, 4" + + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet(f"A[{src_subset}] -> [{dst_subset}]"), + ) + sdfg.validate() + return sdfg + + +def _perform_1d_gpu_copy( + src_row: bool, + dst_row: bool, +): + sdfg = _make_1d_gpu_copy(src_row=src_row, dst_row=dst_row) + assert count_node(sdfg, dace_nodes.AccessNode) == 2 + assert count_node(sdfg, dace_nodes.MapEntry) == 0 + + # Now generate the code. + csdfg = sdfg.compile() + + # Ensure that the copy was not turned into a Map + assert count_node(csdfg.sdfg, dace_nodes.AccessNode) == 2 + assert count_node(csdfg.sdfg, dace_nodes.MapEntry) == 0 + + # It will always result in a call to `Memcpy2D` except the source and the destination + # operates on rows, then it is a simple 1D copy. + if src_row and dst_row: + code = sdfg.generate_code()[0].clean_code + m = re.search(r'(cuda|hip)MemcpyAsync\b', code) + assert m is not None + else: + code = sdfg.generate_code()[0].clean_code + m = re.search(r'(cuda|hip)Memcpy2DAsync\b', code) + assert m is not None + + # Generate input data. + ref = { + "A": cp.array(cp.random.rand(20, 20), dtype=cp.float64, order="C"), + "B": cp.array(cp.random.rand(20, 20), dtype=cp.float64, order="C"), + } + res = {k: v.copy() for k, v in ref.items()} + + # Now perform the reference operation + src_subset = ref["A"][1, 1:9] if src_row else ref["A"][1:9, 2] + if dst_row: + ref["B"][3, 0:8] = src_subset + else: + ref["B"][0:8, 4] = src_subset + + # Now run the SDFG + csdfg(**res) + + assert all(cp.all(ref[k] == res[k]) for k in ref.keys()) + + +def _make_pseudo_1d_copy_sdfg(c_order: bool, ) -> dace.SDFG: + """An SDFG that performs a 2D copy that can be turned into a 1d copy. + """ + sdfg = dace.SDFG(f'gpu_pseudo_1d_copy_{"corder" if c_order else "forder"}_sdfg') + state = sdfg.add_state(is_start_block=True) + + for aname in 'AB': + sdfg.add_array( + name=aname, + shape=(20, 30), + dtype=dace.float64, + storage=dace.StorageType.GPU_Global, + transient=False, + strides=((30, 1) if c_order else (1, 20)), + ) + + cpy_subset = "1:18, 0:30" if c_order else "0:20, 2:29" + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet(f"A[{cpy_subset}] -> [{cpy_subset}]"), + ) + sdfg.validate() + + return sdfg + + +def _perform_pseudo_1d_copy_test(c_order: bool): + sdfg = _make_pseudo_1d_copy_sdfg(c_order=c_order) + assert count_node(sdfg, dace_nodes.AccessNode) == 2 + assert count_node(sdfg, dace_nodes.MapEntry) == 0 + + # Now generate the code. 
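+    # In the C-order variant the copied region consists of complete rows, and in the F-order
+    # variant of complete columns, so `memlet_copy_to_absolute_strides()` is expected to collapse
+    # the 2D copy into a single contiguous 1D copy, which is what the `MemcpyAsync` check below
+    # verifies.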
+ csdfg = sdfg.compile() + + # Ensure that the copy was not turned into a Map + assert count_node(csdfg.sdfg, dace_nodes.AccessNode) == 2 + assert count_node(csdfg.sdfg, dace_nodes.MapEntry) == 0 + + code = sdfg.generate_code()[0].clean_code + m = re.search(r'(cuda|hip)MemcpyAsync\b', code) + assert m is not None + + # Generate input data. + ref = { + "A": cp.array(cp.random.rand(20, 30), dtype=cp.float64, order="C" if c_order else "F"), + "B": cp.array(cp.random.rand(20, 30), dtype=cp.float64, order="C" if c_order else "F"), + } + + # We can not use `deepcopy` or `.copy()` because this would set the strides to `C` order. + res = {} + for name in ref.keys(): + res[name] = cp.empty_like(ref[name]) + res[name][:] = ref[name][:] + + # Perform the reference computation. + if c_order: + ref["B"][1:18, 0:30] = ref["A"][1:18, 0:30] + else: + ref["B"][0:20, 2:29] = ref["A"][0:20, 2:29] + + # Now run the SDFG + csdfg(**res) + + assert all(cp.all(ref[k] == res[k]) for k in ref.keys()) + + @pytest.mark.gpu def test_gpu_shared_to_global_1D(): M = 32 @@ -88,6 +321,96 @@ def transpose_and_add_shared_to_global(A: dace.float64[M, N], B: dace.float64[N, assert m is not None +@pytest.mark.gpu +def test_gpu_1d_copy(): + sdfg = dace.SDFG("gpu_1d_copy_sdfg") + state = sdfg.add_state(is_start_block=True) + + for aname in 'AB': + sdfg.add_array( + name=aname, + shape=(20, ), + dtype=dace.float64, + storage=dace.StorageType.GPU_Global, + transient=False, + ) + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet("A[2:13] -> [1:12]"), + ) + sdfg.validate() + + csdfg = sdfg.compile() + assert count_node(csdfg.sdfg, dace_nodes.AccessNode) == 2 + assert count_node(csdfg.sdfg, dace_nodes.MapEntry) == 0 + + code = sdfg.generate_code()[0].clean_code + m = re.search(r'(cuda|hip)MemcpyAsync\b', code) + assert m is not None + + # Now run the sdfg. 
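+    # The inputs are CuPy device arrays; DaCe receives their device pointers directly, so
+    # `res["B"]` is updated in place by the call below (assumed CuPy interoperability of the
+    # GPU_Global storage type, as in the other tests in this file).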
+ ref = { + "A": cp.array(cp.random.rand(20), dtype=cp.float64), + "B": cp.array(cp.random.rand(20), dtype=cp.float64), + } + res = {k: v.copy() for k, v in ref.items()} + + ref["B"][1:12] = ref["A"][2:13] + csdfg(**res) + + assert all(cp.all(ref[k] == res[k]) for k in ref.keys()) + + +@pytest.mark.gpu +def test_2d_c_order_gpu_copy(): + _perform_2d_gpu_copy_test(c_order=True) + + +@pytest.mark.gpu +def test_2d_f_order_gpu_copy(): + _perform_2d_gpu_copy_test(c_order=False) + + +@pytest.mark.gpu +def test_gpu_1d_copy_row_row(): + _perform_1d_gpu_copy(src_row=True, dst_row=True) + + +@pytest.mark.gpu +def test_gpu_1d_copy_row_col(): + _perform_1d_gpu_copy(src_row=True, dst_row=False) + + +@pytest.mark.gpu +def test_gpu_1d_copy_col_col(): + _perform_1d_gpu_copy(src_row=False, dst_row=False) + + +@pytest.mark.gpu +def test_gpu_1d_copy_col_row(): + _perform_1d_gpu_copy(src_row=False, dst_row=True) + + +@pytest.mark.gpu +def test_gpu_pseudo_1d_copy_c_order(): + _perform_pseudo_1d_copy_test(c_order=True) + + +@pytest.mark.gpu +def test_gpu_pseudo_1d_copy_f_order(): + _perform_pseudo_1d_copy_test(c_order=False) + + if __name__ == '__main__': test_gpu_shared_to_global_1D() test_gpu_shared_to_global_1D_accumulate() + test_2d_c_order_copy() + test_2d_f_order_copy() + test_gpu_1d_copy_row_row() + test_gpu_1d_copy_row_col() + test_gpu_1d_copy_col_row() + test_gpu_1d_copy_col_col() + test_gpu_1d_copy() + test_gpu_pseudo_1d_copy_c_order() + test_gpu_pseudo_1d_copy_f_order() diff --git a/tests/sdfg/validation/subset_size_test.py b/tests/sdfg/validation/subset_size_test.py new file mode 100644 index 0000000000..bc01b85a12 --- /dev/null +++ b/tests/sdfg/validation/subset_size_test.py @@ -0,0 +1,83 @@ +from typing import Tuple + +import dace + +import re +import pytest +import numpy as np + + +def _make_sdfg_with_zero_sized_an_to_an_memlet() -> Tuple[dace.SDFG, dace.SDFGState]: + """Generates an SDFG that performs a copy that has a zero size. + """ + sdfg = dace.SDFG("zero_size_copy_sdfg") + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array( + name=name, + shape=(20, 20), + dtype=dace.float64, + transient=True, + ) + + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet("A[2:17, 2:2] -> [2:18, 3:3]"), + ) + + return sdfg, state + + +def test_an_to_an_memlet_with_zero_size(): + sdfg, state = _make_sdfg_with_zero_sized_an_to_an_memlet() + assert sdfg.number_of_nodes() == 1 + assert state.number_of_nodes() == 2 + + sdfg.validate() + + # This zero sized copy should be considered valid. + assert sdfg.is_valid() + + # The SDFG should be a no ops. + ref = { + "A": np.array(np.random.rand(20, 20), copy=True, order="C", dtype=np.float64), + "B": np.array(np.random.rand(20, 20), copy=True, order="C", dtype=np.float64), + } + res = {k: np.array(v, order="C", copy=True) for k, v in ref.items()} + + csdfg = sdfg.compile() + assert csdfg.sdfg.number_of_nodes() == 1 + assert csdfg.sdfg.states()[0].number_of_nodes() == 2 + csdfg(**res) + + assert all(np.all(ref[k] == res[k]) for k in ref.keys()) + + +def test_an_to_an_memlet_with_negative_size(): + """Tests if an AccessNode to AccessNode connection leads to an invalid SDFG. 
+ """ + sdfg = dace.SDFG("an_to_an_memlet_with_negative_size") + state = sdfg.add_state(is_start_block=True) + + for name in "AB": + sdfg.add_array( + name=name, + shape=(20, 20), + dtype=dace.float64, + transient=True, + ) + + state.add_nedge( + state.add_access("A"), + state.add_access("B"), + dace.Memlet("A[2:17, 13:2] -> [2:18, 14:3]"), + ) + + with pytest.raises( + expected_exception=dace.sdfg.InvalidSDFGEdgeError, + match=re.escape( + f'`subset` of an AccessNode to AccessNode Memlet contains a negative size; the size was [15, -11]'), + ): + sdfg.validate()