diff --git a/tests/test_shared_memory.py b/tests/test_shared_memory.py index 31156d6..b010cdd 100644 --- a/tests/test_shared_memory.py +++ b/tests/test_shared_memory.py @@ -47,7 +47,7 @@ class TestElement(xo.HybridClass): unsigned int gid = blockIdx.x*blockDim.x + threadIdx.x; // global thread ID: 0,1,2,3 // init shared memory with chunk of input array - extern __shared__ double sdata[2]; + extern __shared__ double sdata[]; sdata[tid] = input_arr[gid]; __syncthreads(); diff --git a/xobjects/context_cupy.py b/xobjects/context_cupy.py index 3838376..5ba68e6 100644 --- a/xobjects/context_cupy.py +++ b/xobjects/context_cupy.py @@ -352,15 +352,22 @@ def __invert__(self): cudaheader: List[SourceType] = [ """\ -typedef signed long long int64_t; //only_for_context cuda typedef signed int int32_t; //only_for_context cuda typedef signed short int16_t; //only_for_context cuda typedef signed char int8_t; //only_for_context cuda -typedef unsigned long long uint64_t; //only_for_context cuda typedef unsigned int uint32_t; //only_for_context cuda typedef unsigned short uint16_t; //only_for_context cuda typedef unsigned char uint8_t; //only_for_context cuda +#if defined(__CUDACC__) && !defined(__HIPCC__) +typedef signed long long int64_t; +typedef unsigned long long uint64_t; +#endif + +#ifndef NULL + #define NULL nullptr +#endif + """ ] diff --git a/xobjects/headers/atomicadd.h b/xobjects/headers/atomicadd.h index 6782705..89e4f83 100644 --- a/xobjects/headers/atomicadd.h +++ b/xobjects/headers/atomicadd.h @@ -101,7 +101,7 @@ DEF_ATOMIC_ADD(double , f64) // ------------------------------------------- #if defined(XO_CONTEXT_CUDA) // CUDA compiler may not have , so define the types if needed. - #ifdef __CUDACC_RTC__ + #if defined(__CUDACC_RTC__) && !defined(__HIPCC__) // NVRTC (CuPy RawModule default) can’t see , so detect it via __CUDACC_RTC__ typedef signed char int8_t; typedef short int16_t; @@ -111,6 +111,14 @@ DEF_ATOMIC_ADD(double , f64) typedef unsigned short uint16_t; typedef unsigned int uint32_t; typedef unsigned long long uint64_t; + #elif defined(__HIPCC__) && !defined(__CUDACC_RTC__) + // ROCm appears to have definitions for 64-bit int types + typedef signed char int8_t; + typedef short int16_t; + typedef int int32_t; + typedef unsigned char uint8_t; + typedef unsigned short uint16_t; + typedef unsigned int uint32_t; #else // Alternatively, NVCC path is fine with host headers #include