Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TestElement(xo.HybridClass):
unsigned int gid = blockIdx.x*blockDim.x + threadIdx.x; // global thread ID: 0,1,2,3

// init shared memory with chunk of input array
extern __shared__ double sdata[2];
extern __shared__ double sdata[];
sdata[tid] = input_arr[gid];
__syncthreads();

Expand Down
11 changes: 9 additions & 2 deletions xobjects/context_cupy.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,15 +352,22 @@ def __invert__(self):

cudaheader: List[SourceType] = [
"""\
typedef signed long long int64_t; //only_for_context cuda
typedef signed int int32_t; //only_for_context cuda
typedef signed short int16_t; //only_for_context cuda
typedef signed char int8_t; //only_for_context cuda
typedef unsigned long long uint64_t; //only_for_context cuda
typedef unsigned int uint32_t; //only_for_context cuda
typedef unsigned short uint16_t; //only_for_context cuda
typedef unsigned char uint8_t; //only_for_context cuda

#if defined(__CUDACC__) && !defined(__HIPCC__)
typedef signed long long int64_t;
typedef unsigned long long uint64_t;
#endif

#ifndef NULL
#define NULL nullptr
#endif

"""
]

Expand Down
10 changes: 9 additions & 1 deletion xobjects/headers/atomicadd.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ DEF_ATOMIC_ADD(double , f64)
// -------------------------------------------
#if defined(XO_CONTEXT_CUDA)
// CUDA compiler may not have <stdint.h>, so define the types if needed.
#ifdef __CUDACC_RTC__
#if defined(__CUDACC_RTC__) && !defined(__HIPCC__)
// NVRTC (CuPy RawModule default) can’t see <stdint.h>, so detect it via __CUDACC_RTC__
typedef signed char int8_t;
typedef short int16_t;
Expand All @@ -111,6 +111,14 @@ DEF_ATOMIC_ADD(double , f64)
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
#elif defined(__HIPCC__) && !defined(__CUDACC_RTC__)
// ROCm appears to have definitions for 64-bit int types
typedef signed char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
#else
// Alternatively, NVCC path is fine with host headers
#include <stdint.h>
Expand Down