|
4 | 4 | #ifndef MSCCLPP_GPU_HPP_ |
5 | 5 | #define MSCCLPP_GPU_HPP_ |
6 | 6 |
|
7 | | -#if defined(__HIP_PLATFORM_AMD__) |
| 7 | +#include <mscclpp/device.hpp> |
8 | 8 |
|
9 | | -#include <hip/hip_runtime.h> |
| 9 | +#if defined(MSCCLPP_DEVICE_HIP) |
10 | 10 |
|
11 | 11 | using cudaError_t = hipError_t; |
12 | 12 | using cudaEvent_t = hipEvent_t; |
@@ -62,6 +62,7 @@ constexpr auto CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = HIP_POINTER_ATTRIBUTE_DEVIC |
62 | 62 | #define CUDA_ERROR_DEINITIALIZED hipErrorDeinitialized |
63 | 63 | #define CUDA_ERROR_CONTEXT_IS_DESTROYED hipErrorContextIsDestroyed |
64 | 64 | #define CUDA_ERROR_LAUNCH_FAILED hipErrorLaunchFailure |
| 65 | +#define CUDA_ERROR_NOT_SUPPORTED hipErrorNotSupported |
65 | 66 | #define CUDA_ERROR_INVALID_VALUE hipErrorInvalidValue |
66 | 67 |
|
67 | 68 | #define cudaEventCreate(...) hipEventCreate(__VA_ARGS__) |
@@ -122,29 +123,29 @@ constexpr auto CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL = HIP_POINTER_ATTRIBUTE_DEVIC |
122 | 123 | #define cuMemGetAllocationGranularity(...) hipMemGetAllocationGranularity(__VA_ARGS__) |
123 | 124 | #define cuPointerGetAttribute(...) hipPointerGetAttribute(__VA_ARGS__) |
124 | 125 |
|
125 | | -#else |
| 126 | +#else // !defined(MSCCLPP_DEVICE_HIP) |
126 | 127 |
|
127 | 128 | #include <cuda.h> |
128 | 129 | #include <cuda_runtime.h> |
129 | 130 |
|
130 | | -#endif |
| 131 | +#endif // !defined(MSCCLPP_DEVICE_HIP) |
131 | 132 |
|
132 | 133 | // NVLS |
133 | | -#if !defined(__HIP_PLATFORM_AMD__) |
| 134 | +#if !defined(MSCCLPP_DEVICE_HIP) |
134 | 135 | #include <linux/version.h> |
135 | 136 | #if CUDART_VERSION < 12030 |
136 | 137 | #define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL) |
137 | 138 | #endif |
138 | 139 | // The NVLS API requires CUDA 12.3 or above and Linux kernel 5.6.0 or above |
139 | 140 | #define CUDA_NVLS_API_AVAILABLE ((CUDART_VERSION >= 12030) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) |
140 | | -#else // defined(__HIP_PLATFORM_AMD__) |
| 141 | +#else // defined(MSCCLPP_DEVICE_HIP) |
141 | 142 | #define CUDA_NVLS_API_AVAILABLE 0 |
142 | 143 | // NVLS is not supported on the AMD platform; CU_MEM_HANDLE_TYPE_FABRIC is defined only to avoid compilation errors |
143 | | -#define CU_MEM_HANDLE_TYPE_FABRIC (0x8ULL) |
144 | | -#endif // !defined(__HIP_PLATFORM_AMD__) |
| 144 | +#define CU_MEM_HANDLE_TYPE_FABRIC ((hipMemAllocationHandleType)0x8ULL) |
| 145 | +#endif // defined(MSCCLPP_DEVICE_HIP) |
145 | 146 |
|
146 | 147 | // GPU sync threads |
147 | | -#if defined(__HIP_PLATFORM_AMD__) |
| 148 | +#if defined(MSCCLPP_DEVICE_HIP) |
148 | 149 | #define __syncshm() asm volatile("s_waitcnt lgkmcnt(0) \n s_barrier"); |
149 | 150 | #else |
150 | 151 | #define __syncshm() __syncthreads(); |
|
0 commit comments