diff --git a/docs_input/api/dft/fft/fft.rst b/docs_input/api/dft/fft/fft.rst index 10eae43a..75385a41 100644 --- a/docs_input/api/dft/fft/fft.rst +++ b/docs_input/api/dft/fft/fft.rst @@ -7,6 +7,10 @@ Perform a 1D FFT. Batching is supported for any tensor with a rank higher than 1 .. versionadded:: 0.6.0 +.. note:: + + FFT kernel fusion is supported by cuFFTDx if ``-DMATX_EN_MATHDX=ON`` is enabled. + .. doxygenfunction:: fft(const OpA &a, uint64_t fft_size = 0, FFTNorm norm = FFTNorm::BACKWARD) .. doxygenfunction:: fft(const OpA &a, const int32_t (&axis)[1], uint64_t fft_size = 0, FFTNorm norm = FFTNorm::BACKWARD) diff --git a/docs_input/api/linalg/matvec/matmul.rst b/docs_input/api/linalg/matvec/matmul.rst index 9c29364d..f9524cdb 100644 --- a/docs_input/api/linalg/matvec/matmul.rst +++ b/docs_input/api/linalg/matvec/matmul.rst @@ -11,6 +11,10 @@ is supported for any tensor with a rank higher than 2. .. versionadded:: 0.6.0 +.. note:: + + GEMM kernel fusion is supported by cuBLASDx if ``-DMATX_EN_MATHDX=ON`` is enabled. + .. doxygenfunction:: matmul(const OpA &A, const OpB &B, float alpha = 1.0, float beta = 0.0) .. doxygenfunction:: matmul(const OpA &A, const OpB &B, const int32_t (&axis)[2], float alpha = 1.0, float beta = 0.0) diff --git a/docs_input/basics/fusion.rst b/docs_input/basics/fusion.rst index 54d36cd9..9c9ac38a 100644 --- a/docs_input/basics/fusion.rst +++ b/docs_input/basics/fusion.rst @@ -60,13 +60,13 @@ CUDA JIT Kernel Fusion CUDA JIT kernel fusion is considered an experimental feature. There may be bugs that don't occur with JIT disabled, and new features are being added over time. MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled -for all standard MatX element-wise operators and FFT operations via MathDx. To enable fusion with MathDx, +for all standard MatX element-wise operators and FFT and GEMM operations via MathDx. To enable fusion with MathDx, the following options must be enabled: ``-DMATX_EN_MATHDX=ON``. Once enabled, the ``CUDAJITExecutor`` can be used perform JIT compilation -in supported situations. If the expression cannot be JIT compiled, the JITExecutor will fall back to the normal non-JIT path. +in supported situations. If the expression cannot be JIT compiled, the JITExecutor may throw an error. While JIT compilation can provide a large performance boost, there are two overheads that occur when using JIT compilation: -- The first pass to JIT the code takes time. The first time a ``run()`` statement is executed on a new operator, MatX identifies this and performs JIT compilation. Depending on the complexity of the operator, this could be anywhere from milliseconds to seconds to complete. Once finished, MatX will cache the compiled kernel so that subsequent runs of the same operator will not require JIT compilation. -- A lookup is done to find kernels that have already been compiled. This is a small overhead and may not be noticeable. +* The first pass to JIT the code takes time. The first time a ``run()`` statement is executed on a new operator, MatX identifies this and performs JIT compilation. Depending on the complexity of the operator, this could be anywhere from milliseconds to seconds to complete. Once finished, MatX will cache the compiled kernel so that subsequent runs of the same operator will not require JIT compilation. +* A lookup is done to find kernels that have already been compiled. This is a small overhead and may not be noticeable. As mentioned above, there is no difference in syntax between MatX statements that perform JIT compilation and those that do not. The executor is the only change, just as it would be with a host executor. For example, in the following code: @@ -76,7 +76,7 @@ is the only change, just as it would be with a host executor. For example, in th (A = B * fft(C)).run(CUDAExecutor{}); (A = B * fft(C)).run(CUDAJITExecutor{}); -When MathDx is disabled, the the first statement will execute the FFT into a temporary buffer, then the multiply will be executed. This results +The first statement will execute the FFT as a separate kernel into a temporary buffer, then the multiply will be executed. This results in a minimum of 2 kernels (one for MatX and at least one for cuFFT). The second statement will execute the FFT and multiply in a single kernel if possible. @@ -103,4 +103,27 @@ In this case the MathDx library requires at least 2 elements per thread for the only 1 element per thread. Therefore, the entire expression cannot be JIT-compiled and will fall back to the non-JIT path. Some of these restrictions may be relaxed in newer versions of MatX or the MathDx library. +MathDx Compatibility +==================== + +.. list-table:: MathDx library compatibility for CUDA JIT fusion + :header-rows: 1 + :widths: 20 14 66 + + * - Library + - Supported + - Notes + * - cuBLASDx + - Yes + - Enabled via ``-DMATX_EN_MATHDX=ON`` for GEMM fusion paths. + * - cuFFTDx + - Yes + - Enabled via ``-DMATX_EN_MATHDX=ON`` for FFT fusion paths. + * - cuSolverDx + - No + - Not supported yet by MatX CUDA JIT fusion. + * - cuRandDx + - No + - Not supported yet by MatX CUDA JIT fusion. + diff --git a/include/matx/core/capabilities.h b/include/matx/core/capabilities.h index 9961206d..31fce76c 100644 --- a/include/matx/core/capabilities.h +++ b/include/matx/core/capabilities.h @@ -71,6 +71,7 @@ namespace detail { ELEMENT_WISE, // Whether the operator is element-wise (safe with aliasing) ALIASED_MEMORY, // Whether the operator's input and output pointers alias GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level + PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level // Add more capabilities as needed }; @@ -246,6 +247,15 @@ namespace detail { static constexpr int default_value = 32; static constexpr int min_identity = 32; static constexpr int max_identity = 1; + }; + + template <> + struct capability_attributes { + using type = bool; + using input_type = VoidCapabilityType; + static constexpr bool default_value = false; // Default: operators do their own bounds checking + static constexpr bool or_identity = false; + static constexpr bool and_identity = true; }; @@ -312,6 +322,8 @@ namespace detail { return CapabilityQueryType::RANGE_QUERY; // The expression should use the minimum block size supported by all operators. case OperatorCapability::GENERATE_LTOIR: return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it. + case OperatorCapability::PASS_THROUGH_THREADS: + return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator() default: // Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped. return CapabilityQueryType::OR_QUERY; diff --git a/include/matx/core/error.h b/include/matx/core/error.h index 19613a4e..6b3d4e73 100644 --- a/include/matx/core/error.h +++ b/include/matx/core/error.h @@ -63,6 +63,7 @@ namespace matx matxInvalidSize, matxCudaError, matxCufftError, + matxLibMathdxError, matxMatMulError, matxAssertError, matxInvalidType, @@ -107,8 +108,12 @@ namespace matx return "matxInverseError"; case matxSolverError: return "matxSolverError"; + case matxLibMathdxError: + return "matxLibMathdxError"; case matxcuTensorError: - break; + return "matxcuTensorError"; + case matxInvalidExecutor: + return "matxInvalidExecutor"; default: return "Unknown"; }; diff --git a/include/matx/core/get_grid_dims.h b/include/matx/core/get_grid_dims.h index afecb785..0577e2b4 100644 --- a/include/matx/core/get_grid_dims.h +++ b/include/matx/core/get_grid_dims.h @@ -33,6 +33,7 @@ #pragma once #include "matx/core/defines.h" +#include "matx/core/error.h" #include #include #include @@ -298,5 +299,50 @@ inline bool get_grid_dims_block(dim3 &blocks, dim3 &threads, const cuda::std::ar MATX_LOG_DEBUG("Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block); return stride; } + +// For 2D block operators (e.g., cuBLASDx GEMM) where all threads in a block cooperate +// on the last 2 dimensions and blockIdx is used purely for batching +template +inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads, + const cuda::std::array &sizes, + int block_dim) { + // Threads are set to block_dim in x, y and z are 1 + // All threads cooperate via flattened thread ID in the kernel + threads.x = block_dim; + threads.y = 1; + threads.z = 1; + + // Grid covers batch dimensions only (dims 0 to RANK-3) + blocks.x = 1; + blocks.y = 1; + blocks.z = 1; + + if constexpr (RANK == 2) { + blocks.x = 1; // Single block for entire 2D output + } + else if constexpr (RANK == 3) { + blocks.x = static_cast(sizes[0]); // Batch dim + } + else if constexpr (RANK == 4) { + blocks.x = static_cast(sizes[1]); // Second-to-last batch + blocks.y = static_cast(sizes[0]); // First batch dim + } + else if constexpr (RANK > 4) { + MATX_THROW(matxNotSupported, "Block2D grid dims not supported for rank > 4"); + return true; + } + + if constexpr (RANK >= 2 && RANK <= 4) { + constexpr int kMaxGridDim = 65535; + if (blocks.x > kMaxGridDim || blocks.y > kMaxGridDim) { + MATX_THROW(matxInvalidParameter, "Block2D grid dims exceed CUDA limit (65535)"); + } + } + + MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); + + // No stride needed for now - could be extended for very large batches + return false; +} } // end namespace detail } // end namespace matx diff --git a/include/matx/core/nvrtc_helper.h b/include/matx/core/nvrtc_helper.h index 9245c666..80eed3d0 100644 --- a/include/matx/core/nvrtc_helper.h +++ b/include/matx/core/nvrtc_helper.h @@ -88,11 +88,7 @@ std::vector __MATX_HOST__ __MATX_INLINE__ get_preprocessor_options( options.push_back("-arch=sm_80"); // fallback #endif - #ifdef NVRTC_CXX_STANDARD - options.push_back("-std=c++" NVRTC_CXX_STANDARD); - #else - options.push_back("-std=c++20"); // fallback - #endif + options.push_back("-std=c++20"); return options; } @@ -170,7 +166,7 @@ inline std::string get_jit_includes_path() { } template -std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool global_kernel) { +std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool global_kernel, bool pass_through_threads = false) { if constexpr (Op::Rank() == 0) { return "matx::detail::matxOpT0Kernel"; } @@ -178,21 +174,27 @@ std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool glo return global_kernel ? "matx::detail::matxOpT1Kernel" : "matx::detail::matxOpT1KernelBlock"; } else if constexpr (Op::Rank() == 2) { - if (stride) { + if (pass_through_threads) { + return "matx::detail::matxOpT2KernelBlock2D"; + } else if (stride) { return global_kernel ? "matx::detail::matxOpT2StrideKernel" : "matx::detail::matxOpT2StrideKernelBlock"; } else { return global_kernel ? "matx::detail::matxOpT2Kernel" : "matx::detail::matxOpT2KernelBlock"; } } else if constexpr (Op::Rank() == 3) { - if (stride) { + if (pass_through_threads) { + return "matx::detail::matxOpT3KernelBlock2D"; + } else if (stride) { return global_kernel ? "matx::detail::matxOpT3StrideKernel" : "matx::detail::matxOpT3StrideKernelBlock"; } else { return global_kernel ? "matx::detail::matxOpT3Kernel" : "matx::detail::matxOpT3KernelBlock"; } } else if constexpr (Op::Rank() == 4) { - if (stride) { + if (pass_through_threads) { + return "matx::detail::matxOpT4KernelBlock2D"; + } else if (stride) { return global_kernel ? "matx::detail::matxOpT4StrideKernel" : "matx::detail::matxOpT4StrideKernelBlock"; } else { return global_kernel ? "matx::detail::matxOpT4Kernel" : "matx::detail::matxOpT4KernelBlock"; @@ -207,7 +209,7 @@ std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool glo } template -std::string generate_capability_params_string([[maybe_unused]] const Op &op, ElementsPerThread EPT, bool JIT, int osize, int block_size) { +std::string generate_capability_params_string([[maybe_unused]] const Op &op, ElementsPerThread EPT, bool JIT, int osize, int block_size, bool pass_through_threads = false) { std::string ept_str; switch (EPT) { case ElementsPerThread::ONE: @@ -235,6 +237,8 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele std::string jit_str = JIT ? "true" : "false"; + std::string pass_through_str = pass_through_threads ? "true" : "false"; + std::string final_str = "namespace matx { namespace detail {\n" "template \n" @@ -243,6 +247,7 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele " static constexpr bool jit = JIT;\n" " static constexpr int osize = " + std::to_string(osize) + ";\n" " static constexpr int block_size = " + std::to_string(block_size) + ";\n" + " static constexpr bool pass_through_threads = " + pass_through_str + ";\n" "};\n" "using CurrentCapabilities = CapabilityParams<" + ept_str + ", " + jit_str + ">;\n" "} }\n"; @@ -300,7 +305,7 @@ inline std::string qualify_jit_type_names(const std::string& type_str) { } template -auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize, bool global_kernel) { +auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize, bool global_kernel, bool pass_through_threads = false) { // Pure NVRTC implementation // Cache both module and function to prevent resource leaks // CUmodule must remain loaded for CUfunction to be valid @@ -312,10 +317,10 @@ auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, cons static std::mutex kernel_cache_mutex; const auto all_jit_classes_string = get_all_jit_classes_string(op); - auto capstr = generate_capability_params_string(op, ept, false, osize, threads.x); + auto capstr = generate_capability_params_string(op, ept, false, osize, threads.x, pass_through_threads); const auto kernel_op_type = detail::get_operator_capability(op); - std::string kernel_name = get_kernel_name(op, stride, global_kernel); + std::string kernel_name = get_kernel_name(op, stride, global_kernel, pass_through_threads); std::string cache_key = kernel_name + "_" + kernel_op_type; MATX_LOG_DEBUG("nvrtc_compile_and_run called with operator type: {}", typeid(op).name()); diff --git a/include/matx/core/tensor_impl.h b/include/matx/core/tensor_impl.h index c226df4a..4b46907f 100644 --- a/include/matx/core/tensor_impl.h +++ b/include/matx/core/tensor_impl.h @@ -221,35 +221,64 @@ class tensor_impl_t { " return GetValC(cuda::std::make_tuple(indices...));\n" + " }\n" + " }\n" + + " template \n" + + " __MATX_INLINE__ __MATX_DEVICE__ bool CheckBounds(cuda::std::tuple tup) const {\n" + + " if constexpr (I < sizeof...(Is)) {\n" + + " constexpr int EPT_int = static_cast(CapType::ept);\n" + + " if constexpr (I == sizeof...(Is) - 1 && EPT_int > 1) {\n" + + " // Last dimension with EPT > 1: check all elements [idx*EPT, idx*EPT+EPT-1] are in bounds\n" + + " if ((cuda::std::get(tup) + 1) * EPT_int > sizes_[I]) return false;\n" + + " } else {\n" + + " if (cuda::std::get(tup) >= sizes_[I]) return false;\n" + + " }\n" + + " return CheckBounds(tup);\n" + + " }\n" + + " return true;\n" + + " }\n" + " template ...>, bool> = true>\n" + - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const noexcept" + "{\n" + + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const noexcept" + "{\n" + " static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" + - " constexpr int EPT_int = static_cast(CapType::ept);\n" + - " const index_t offset = GetOffsetOptimized(indices...);\n" + - " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" + - " return ldata_[offset];\n" + - " } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" + - " return *reinterpret_cast*>(ldata_ + offset);\n" + - " } else {\n" + - " detail::Vector vec;\n" + - " vec.template load(ldata_ + offset);\n" + - " return vec;\n" + - " }\n" + + " constexpr int EPT_int = static_cast(CapType::ept);\n" + + " using ReturnType = cuda::std::conditional_t>;\n" + + " if constexpr (CapType::pass_through_threads) {\n" + + " if (!CheckBounds(cuda::std::make_tuple(indices...))) {\n" + + " return ReturnType{};\n" + + " }\n" + + " }\n" + + " const index_t offset = GetOffsetOptimized(indices...);\n" + + " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" + + " return ldata_[offset];\n" + + " } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" + + " return *reinterpret_cast*>(ldata_ + offset);\n" + + " } else {\n" + + " detail::Vector vec;\n" + + " vec.template load(ldata_ + offset);\n" + + " return vec;\n" + + " }\n" + " }\n" + " template ...>, bool> = true>\n" + - " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) noexcept\n" + + " __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) noexcept\n" + + " -> cuda::std::conditional_t(CapType::ept)>&>\n" + " {\n" + " static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" + " constexpr int EPT_int = static_cast(CapType::ept);\n" + + " if constexpr (CapType::pass_through_threads) {\n" + + " using ReturnType = cuda::std::conditional_t>;\n" + + " __align__(alignof(ReturnType)) __shared__ unsigned char dummy_storage[sizeof(ReturnType)];\n" + + " auto &dummy_ = *reinterpret_cast(dummy_storage);\n" + + " if (!CheckBounds(cuda::std::make_tuple(indices...))) {\n" + + " return dummy_;\n" + + " }\n" + + " }\n" + " const index_t offset = GetOffsetOptimized(indices...);\n" + " if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" + " return ldata_[offset];\n" + " } else {\n" + " return *reinterpret_cast*>(ldata_ + offset);\n" + - " }\n" + - " }\n" + + " }\n" + + " }\n" + " template \n" + " __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array &idx) const noexcept\n" + " {\n" + diff --git a/include/matx/executors/jit_cuda.h b/include/matx/executors/jit_cuda.h index 65b35af0..f827aa20 100644 --- a/include/matx/executors/jit_cuda.h +++ b/include/matx/executors/jit_cuda.h @@ -82,6 +82,7 @@ namespace matx dim3 threads; // Block dimensions (x, y, z threads) int osize; // Output size (last dimension) bool global_kernel; // Whether this is a global or block-level kernel + bool pass_through_threads; // Whether all threads must call operator() (bounds checking at tensor level) }; // Global cache for JIT launch parameters, keyed by operator type string from JIT_TYPE_QUERY @@ -164,8 +165,11 @@ namespace matx bool global_kernel = detail::get_operator_capability(op); + bool pass_through_threads = detail::get_operator_capability(op); if (global_kernel) { MATX_LOG_DEBUG("Operator operates on a global level"); + } else if (pass_through_threads) { + MATX_LOG_DEBUG("Operator uses pass-through threads (bounds checking at tensor level)"); } else { MATX_LOG_DEBUG("Operator operates on a block level"); } @@ -207,23 +211,38 @@ namespace matx // No cached parameters - compute them MATX_LOG_DEBUG("No cached parameters found, computing launch parameters for JIT"); - // Create kernel provider for JIT using consolidated function - auto kernel_provider = detail::create_kernel_provider(sizes, true, global_kernel); - - // Find the best launch parameters - auto result = detail::find_best_launch_params(op, kernel_provider, 0, true); - best_ept = cuda::std::get<0>(result); - shm_size = cuda::std::get<1>(result); - block_size = cuda::std::get<2>(result); - groups_per_block = cuda::std::get<3>(result); - - MATX_LOG_DEBUG("Best EPT {}, Shm size {}, Block size {}, Groups per block {}", - static_cast(best_ept), shm_size, block_size, groups_per_block); - - if (global_kernel) { - stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(best_ept), 256); + if (pass_through_threads) { + // For pass-through operators (e.g., cuBLASDx), block dimensions are fixed by the operator + auto block_dim_range = detail::get_operator_capability(op); + block_size = block_dim_range[1]; // Use the max block size + stride = detail::get_grid_dims_block_2d(blocks, threads, sizes, block_size); + + // EPT is 1 for 2D block operators - the operator handles elements internally + best_ept = detail::ElementsPerThread::ONE; + shm_size = detail::get_operator_capability(op); + groups_per_block = 1; + + MATX_LOG_DEBUG("Block2D: EPT {}, Shm size {}, Block size {}", + static_cast(best_ept), shm_size, block_size); } else { - stride = detail::get_grid_dims_block(blocks, threads, sizes, static_cast(best_ept), groups_per_block, block_size, true); + // Create kernel provider for JIT using consolidated function + auto kernel_provider = detail::create_kernel_provider(sizes, true, global_kernel); + + // Find the best launch parameters + auto result = detail::find_best_launch_params(op, kernel_provider, 0, true); + best_ept = cuda::std::get<0>(result); + shm_size = cuda::std::get<1>(result); + block_size = cuda::std::get<2>(result); + groups_per_block = cuda::std::get<3>(result); + + MATX_LOG_DEBUG("Best EPT {}, Shm size {}, Block size {}, Groups per block {}", + static_cast(best_ept), shm_size, block_size, groups_per_block); + + if (global_kernel) { + stride = detail::get_grid_dims(blocks, threads, sizes, static_cast(best_ept), 256); + } else { + stride = detail::get_grid_dims_block(blocks, threads, sizes, static_cast(best_ept), groups_per_block, block_size, true); + } } // Cache ALL parameters for future use (sizes are encoded in type string) @@ -237,6 +256,7 @@ namespace matx params_to_cache.threads = threads; params_to_cache.osize = op.Rank() == 0 ? 1 : static_cast(op.Size(op.Rank() - 1)); params_to_cache.global_kernel = global_kernel; + params_to_cache.pass_through_threads = pass_through_threads; { std::lock_guard lock(detail::jit_launch_params_mutex); @@ -244,10 +264,10 @@ namespace matx } } - MATX_LOG_DEBUG("Shm size {}, Stride {}, estimated EPT {}, blocks {}x{}x{} threads {}x{}x{}", - shm_size, stride, static_cast(best_ept), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z); + MATX_LOG_DEBUG("Shm size {}, Stride {}, estimated EPT {}, blocks {}x{}x{} threads {}x{}x{}, pass_through_threads {}", + shm_size, stride, static_cast(best_ept), blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, pass_through_threads); const int osize = op.Rank() == 0 ? 1 : static_cast(op.Size(op.Rank() - 1)); - detail::nvrtc_compile_and_run("output.cu", op, sizes, blocks, threads, best_ept, stride, shm_size, osize, global_kernel); + detail::nvrtc_compile_and_run("output.cu", op, sizes, blocks, threads, best_ept, stride, shm_size, osize, global_kernel, pass_through_threads); } else { // ND kernel support for ranks > 4 (JIT path) @@ -296,7 +316,7 @@ namespace matx params_to_cache.threads = threads; params_to_cache.osize = op.Rank() == 0 ? 1 : static_cast(op.Size(op.Rank() - 1)); params_to_cache.global_kernel = true; - + params_to_cache.pass_through_threads = pass_through_threads; { std::lock_guard lock(detail::jit_launch_params_mutex); detail::jit_launch_params_cache[kernel_op_type] = params_to_cache; diff --git a/include/matx/executors/jit_kernel.h b/include/matx/executors/jit_kernel.h index 621f0eb5..a3317bf5 100644 --- a/include/matx/executors/jit_kernel.h +++ b/include/matx/executors/jit_kernel.h @@ -321,6 +321,45 @@ namespace matx {\n\ }\n\ }\n\ }\n\ + \n\ + template \n\ + __global__ void matxOpT2KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1) {\n\ + int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ + matx::index_t idx = tid % size1;\n\ + matx::index_t idy = tid / size1;\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idy, idx);\n\ + } else {\n\ + op.template operator()(idy, idx);\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT3KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2) {\n\ + int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ + matx::index_t idx = tid % size2;\n\ + matx::index_t idy = tid / size2;\n\ + matx::index_t idz = blockIdx.x;\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idz, idy, idx);\n\ + }\n\ + }\n\ + \n\ + template \n\ + __global__ void matxOpT4KernelBlock2D(Op op, matx::index_t size0, matx::index_t size1, matx::index_t size2, matx::index_t size3) {\n\ + int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;\n\ + matx::index_t idx = tid % size3;\n\ + matx::index_t idy = tid / size3;\n\ + matx::index_t idz = blockIdx.x;\n\ + matx::index_t idw = blockIdx.y;\n\ + if constexpr (cuda::std::is_pointer_v) {\n\ + (*op).template operator()(idw, idz, idy, idx);\n\ + } else {\n\ + op.template operator()(idw, idz, idy, idx);\n\ + }\n\ + }\n\ }\n\ }"; #else diff --git a/include/matx/operators/binary_operators.h b/include/matx/operators/binary_operators.h index 02f632c7..bfea2776 100644 --- a/include/matx/operators/binary_operators.h +++ b/include/matx/operators/binary_operators.h @@ -180,13 +180,11 @@ namespace matx " typename detail::inner_storage_or_self_t> op_;\n" + " template \n" + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const\n" + - " {\n" + - " if ((threadIdx.x * static_cast(CapType::ept)) > Size(Rank() - 1)) {\n" + - " return detail::GetJitSentinelValue();\n" + - " }\n" + + " {\n" + " auto i1 = get_value(in1_, indices...);\n" + " auto i2 = get_value(in2_, indices...);\n" + - " return op_.template operator()(i1, i2);\n" + + " auto result = op_.template operator()(i1, i2);\n" + + " return result;\n" + " }\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + " {\n" + diff --git a/include/matx/operators/fft.h b/include/matx/operators/fft.h index a825bdf3..c3063abf 100644 --- a/include/matx/operators/fft.h +++ b/include/matx/operators/fft.h @@ -186,11 +186,7 @@ namespace matx dx_fft_helper_.set_fft_type(DeduceFFTTransformType::ctype, value_type>()); dx_fft_helper_.set_direction(Direction); dx_fft_helper_.set_cc(cc); - // if (fft_size_ <= 32) { - // dx_fft_helper_.set_method(cuFFTDxMethod::REGISTER); - // } else { - dx_fft_helper_.set_method(cuFFTDxMethod::SHARED); - //} + dx_fft_helper_.set_method(cuFFTDxMethod::SHARED); bool contiguous = false; if constexpr (is_tensor_view_v) { diff --git a/include/matx/operators/matmul.h b/include/matx/operators/matmul.h index a341840e..6c20ac12 100644 --- a/include/matx/operators/matmul.h +++ b/include/matx/operators/matmul.h @@ -35,6 +35,8 @@ #include "matx/core/type_utils.h" #include "matx/operators/base_operator.h" +#include "matx/core/operator_options.h" +#include "matx/core/log.h" #include "matx/transforms/matmul/matmul_cuda.h" #include "matx/transforms/matmul/matmul_cusparse.h" #ifdef MATX_EN_CPU_MATMUL @@ -42,6 +44,10 @@ #endif #include +#if defined(MATX_EN_MATHDX) && defined (__CUDACC__) + #include "matx/transforms/matmul/matmul_cublasdx.h" +#endif + namespace matx { namespace detail { @@ -59,7 +65,10 @@ namespace matx // This should be tensor_impl_t, but need to work around issues with temp types returned in matmul mutable detail::tensor_impl_t::value_type, out_rank> tmp_out_; mutable typename remove_cvref_t::value_type *ptr = nullptr; - mutable bool prerun_done_ = false; + mutable bool prerun_done_ = false; +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + mutable cuBLASDxHelper dx_gemm_helper_; +#endif public: using matxop = bool; @@ -71,6 +80,71 @@ namespace matx return "matmul(" + get_type_str(a_) + "," + get_type_str(b_) + ")"; } +#ifdef MATX_EN_JIT + struct JIT_Storage { + typename detail::inner_storage_or_self_t> a_; + typename detail::inner_storage_or_self_t> b_; + }; + + JIT_Storage ToJITStorage() const { + return JIT_Storage{detail::to_jit_storage(a_), detail::to_jit_storage(b_)}; + } +#endif + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + __MATX_INLINE__ std::string get_jit_class_name() const { + std::string symbol_name = "JITMatMulOp_"; + symbol_name += std::to_string(dx_gemm_helper_.get_m()); + symbol_name += "_"; + symbol_name += std::to_string(dx_gemm_helper_.get_n()); + symbol_name += "_"; + symbol_name += std::to_string(dx_gemm_helper_.get_k()); + symbol_name += dx_gemm_helper_.get_is_complex() ? "_C" : "_R"; + return symbol_name; + } + + __MATX_INLINE__ auto get_jit_op_str() const { + const std::string class_name = get_jit_class_name(); + const std::string gemm_func_name = std::string(GEMM_DX_FUNC_PREFIX) + "_" + dx_gemm_helper_.GetSymbolName(); + + return cuda::std::make_tuple( + class_name, + std::string( + " extern \"C\" __device__ void " + gemm_func_name + "(" + + detail::type_to_string() + "*, " + + detail::type_to_string() + "*, " + + detail::type_to_string() + "*, " + + detail::type_to_string() + "*, " + + detail::type_to_string() + "*);\n" + + " template struct " + class_name + " {\n" + + " using input_type = typename OpA::value_type;\n" + + " using matxop = bool;\n" + + " using value_type = input_type;\n" + + " typename detail::inner_storage_or_self_t> a_;\n" + + " typename detail::inner_storage_or_self_t> b_;\n" + + " constexpr static cuda::std::array out_dims_ = { " + + detail::array_to_string(out_dims_) + " };\n" + + " static constexpr index_t m_ = " + std::to_string(dx_gemm_helper_.get_m()) + ";\n" + + " static constexpr index_t n_ = " + std::to_string(dx_gemm_helper_.get_n()) + ";\n" + + " static constexpr index_t k_ = " + std::to_string(dx_gemm_helper_.get_k()) + ";\n" + + " template \n" + + " __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const\n" + + " {\n" + + " " + dx_gemm_helper_.GetFuncStr(gemm_func_name, alpha_, beta_) + "\n" + + " }\n" + + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + + " {\n" + + " return " + std::to_string(Rank()) + ";\n" + + " }\n" + + " constexpr __MATX_INLINE__ __MATX_DEVICE__ index_t Size(int dim) const\n" + + " {\n" + + " return out_dims_[dim];\n" + + " }\n" + + "};\n") + ); + } +#endif + __MATX_INLINE__ MatMulOp(const OpA &a, const OpB &b, float alpha, float beta, PermDims perm) : a_(a), b_(b), alpha_(alpha), beta_(beta), perm_(perm) { MATX_LOG_TRACE("{} constructor: alpha={}, beta={}", str(), alpha, beta); @@ -95,6 +169,27 @@ namespace matx out_dims_[Rank() - 2] = a_.Size(OpA::Rank() - 2); out_dims_[Rank() - 1] = b_.Size(OpB::Rank() - 1); } + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + // Initialize cuBLASDx helper with matrix dimensions + // For GEMM: C(m x n) = A(m x k) * B(k x n) + // m = rows of output (from A's second-to-last dim) + // n = cols of output (from B's last dim) + // k = inner dimension (A's last dim = B's second-to-last dim) + int major = 0; + int minor = 0; + int device; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + int cc = major * 100 + minor; // Compute capability as three digits (e.g., 890 for SM 8.9) + + dx_gemm_helper_.set_m(a_.Size(OpA::Rank() - 2)); + dx_gemm_helper_.set_n(b_.Size(OpB::Rank() - 1)); + dx_gemm_helper_.set_k(a_.Size(OpA::Rank() - 1)); + dx_gemm_helper_.set_cc(cc); + dx_gemm_helper_.set_is_complex(is_complex_v); +#endif } template @@ -111,16 +206,129 @@ namespace matx template __MATX_INLINE__ __MATX_HOST__ auto get_capability([[maybe_unused]] InType& in) const { +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + // Branch with cuBLASDx support + if constexpr (Cap == OperatorCapability::ALIASED_MEMORY) { + auto in_copy = in; + in_copy.permutes_input_output = true; + return combine_capabilities(detail::get_operator_capability(a_, in_copy), detail::get_operator_capability(b_, in_copy)); + } + else if constexpr (Cap == OperatorCapability::ELEMENTS_PER_THREAD) { +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) + auto result = ElementsPerThread::ONE; + MATX_LOG_DEBUG("cuBLASDx ELEMENTS_PER_THREAD: {}", static_cast(result)); + return cuda::std::array{result, result}; +#else + return combine_capabilities(capability_attributes::default_value, detail::get_operator_capability(a_, in), detail::get_operator_capability(b_, in)); +#endif + } + else if constexpr (Cap == OperatorCapability::DYN_SHM_SIZE) { + auto result = combine_capabilities(dx_gemm_helper_.GetShmRequired(), + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + MATX_LOG_DEBUG("cuBLASDx DYN_SHM_SIZE: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = dx_gemm_helper_.template CheckJITSizeAndTypeRequirements() && + dx_gemm_helper_.IsSupported(); + + auto result = combine_capabilities(supported, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + MATX_LOG_DEBUG("cuBLASDx SUPPORTS_JIT: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_CLASS_QUERY) { + // Get the capability string and add to map + const auto [key, value] = get_jit_op_str(); + + // Insert into the map if the key doesn't exist + if (in.find(key) == in.end()) { + in[key] = value; + } + + // Also handle child operators + detail::get_operator_capability(a_, in); + detail::get_operator_capability(b_, in); + + MATX_LOG_DEBUG("cuBLASDx JIT_CLASS_QUERY: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::BLOCK_DIM) { + auto block_dims = dx_gemm_helper_.GetBlockDim(); + MATX_LOG_DEBUG("cuBLASDx block dim: {} {} {}", block_dims[0], block_dims[1], block_dims[2]); + // Use the first dimension as the primary block size (similar to FFT) + const auto my_block = cuda::std::array{block_dims[0], block_dims[0]}; + return combine_capabilities(my_block, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + } + else if constexpr (Cap == OperatorCapability::GENERATE_LTOIR) { + auto result = combine_capabilities( + dx_gemm_helper_.GenerateLTOIR(in.ltoir_symbols), + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + MATX_LOG_DEBUG("cuBLASDx GENERATE_LTOIR: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + // No need to use combine_capabilities here since we're just returning a string. + const auto inner_op_a_jit_name = detail::get_operator_capability(a_, in); + const auto inner_op_b_jit_name = detail::get_operator_capability(b_, in); + auto result = get_jit_class_name() + "<" + inner_op_a_jit_name + ", " + inner_op_b_jit_name + ">"; + MATX_LOG_DEBUG("cuBLASDx JIT_TYPE_QUERY: {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::GLOBAL_KERNEL) { + // If MathDx is enabled we always return false. Other checks on size and type may prevent JIT compilation. + MATX_LOG_DEBUG("cuBLASDx GLOBAL_KERNEL: false"); + return false; + } + else if constexpr (Cap == OperatorCapability::PASS_THROUGH_THREADS) { + // cuBLASDx needs all threads to call operator() for block-level cooperation + MATX_LOG_DEBUG("cuBLASDx PASS_THROUGH_THREADS: true"); + return true; + } + else if constexpr (Cap == OperatorCapability::GROUPS_PER_BLOCK) { + // 2D block operators only support one group per block + const auto my_cap = cuda::std::array{1, 1}; + return combine_capabilities(my_cap, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + } + else { + auto self_has_cap = capability_attributes::default_value; + return combine_capabilities(self_has_cap, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + } +#else + // Branch without cuBLASDx support if constexpr (Cap == OperatorCapability::ALIASED_MEMORY) { auto in_copy = in; in_copy.permutes_input_output = true; return combine_capabilities(detail::get_operator_capability(a_, in_copy), detail::get_operator_capability(b_, in_copy)); - } else { + } + else if constexpr (Cap == OperatorCapability::SUPPORTS_JIT) { + bool supported = false; + auto result = combine_capabilities(supported, + detail::get_operator_capability(a_, in), + detail::get_operator_capability(b_, in)); + MATX_LOG_DEBUG("SUPPORTS_JIT (no cuBLASDx): {}", result); + return result; + } + else if constexpr (Cap == OperatorCapability::JIT_TYPE_QUERY) { + MATX_LOG_DEBUG("JIT_TYPE_QUERY (no cuBLASDx): \"\""); + return ""; + } + else { auto self_has_cap = capability_attributes::default_value; return combine_capabilities(self_has_cap, detail::get_operator_capability(a_, in), detail::get_operator_capability(b_, in)); } +#endif } static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank() diff --git a/include/matx/operators/set.h b/include/matx/operators/set.h index acc52df9..43feb81f 100644 --- a/include/matx/operators/set.h +++ b/include/matx/operators/set.h @@ -175,22 +175,17 @@ class set : public BaseOp> { "remove_cvref_t(op_, indices...))>\n" + " {\n" + " using in_val_type = remove_cvref_t(op_, indices...))>;\n" + - " if ((threadIdx.x * static_cast(CapType::ept)) >= Size(Rank() - 1)) {\n" + - " return in_val_type{};\n" + - " }\n" + " auto in_val = detail::get_value(op_, indices...);\n" + " using out_type = decltype(out_.template operator()(indices...));\n" + - " if (out_.Rank() == 0 || threadIdx.x < out_.Size(out_.Rank() - 1)) {\n" + - " if constexpr (!is_vector_v && is_vector_v) {\n" + - " Vector, static_cast(CapType::ept)> vec{in_val};\n" + - " out_.template operator()(indices...) = vec;\n" + - " }\n" + - " else {\n" + - " out_.template operator()(indices...) = in_val;\n" + - " }\n" + - " }\n" + - " return in_val;\n" + - " }\n" + + " if constexpr (!is_vector_v && is_vector_v) {\n" + + " Vector, static_cast(CapType::ept)> vec{in_val};\n" + + " out_.template operator()(indices...) = vec;\n" + + " }\n" + + " else {\n" + + " out_.template operator()(indices...) = in_val;\n" + + " }\n" + + " return in_val;\n" + + " }\n" + " static __MATX_INLINE__ constexpr __MATX_DEVICE__ int32_t Rank()\n" + " {\n" + " return T::Rank();\n" + diff --git a/include/matx/transforms/matmul/matmul_cublasdx.h b/include/matx/transforms/matmul/matmul_cublasdx.h new file mode 100644 index 00000000..1d4d7159 --- /dev/null +++ b/include/matx/transforms/matmul/matmul_cublasdx.h @@ -0,0 +1,469 @@ +//////////////////////////////////////////////////////////////////////////////// +// BSD 3-Clause License +// +// Copyright (c) 2025, NVIDIA Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// 3. Neither the name of the copyright holder nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +///////////////////////////////////////////////////////////////////////////////// + +#pragma once + +#include "matx/core/operator_options.h" +#include "matx/core/capabilities.h" +#include "matx/core/log.h" + +#include "matx/core/half_complex.h" +#include "matx/core/half.h" + +#include +#include +#define GEMM_DX_FUNC_PREFIX "gemm_cublasdx_func" + +#if defined(MATX_EN_MATHDX) && defined(__CUDACC__) +#include + +#define LIBCUBLASDX_CHECK(ans) \ + do { \ + commondxStatusType result = (ans); \ + MATX_ASSERT_STR_EXP(result, commondxStatusType::COMMONDX_SUCCESS, matxLibMathdxError, "cuBLASDx failed");\ + } while (0) + +namespace matx { + namespace detail { + +// Returns true if the matrix size and data type are supported by cuBLASDx for the given compute capability. + // Based on table from cuBLASDx documentation: + // https://docs.nvidia.com/cuda/cublasdx/requirements_func.html#supported-maximal-sizes-with-non-overlapping-a-and-b + template + __MATX_INLINE__ bool IscuBLASDxSupported(index_t m, index_t n, index_t k, int compute_capability) + { + // Using "Restricted AB with C in Shared" column from documentation + int max_size = 0; + + // Real, half or bfloat16 + if constexpr (std::is_same_v || std::is_same_v) { + if (compute_capability == 700 || compute_capability == 720) max_size = 128; + else if (compute_capability == 750) max_size = 104; + else if (compute_capability == 800 || compute_capability == 870) max_size = 166; + else if (compute_capability == 860 || compute_capability == 890 || compute_capability == 1200 || compute_capability == 1210) max_size = 129; + else if (compute_capability == 900 || compute_capability == 1000 || compute_capability == 1010 || compute_capability == 1030 || compute_capability == 1100) max_size = 196; + } + // Real, float OR Complex, half/bf16 + else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v>) { + if (compute_capability == 700 || compute_capability == 720) max_size = 90; + else if (compute_capability == 750) max_size = 73; + else if (compute_capability == 800 || compute_capability == 870) max_size = 117; + else if (compute_capability == 860 || compute_capability == 890 || compute_capability == 1200 || compute_capability == 1210) max_size = 91; + else if (compute_capability == 900 || compute_capability == 1000 || compute_capability == 1010 || compute_capability == 1030 || compute_capability == 1100) max_size = 139; + } + // Real, double OR Complex, float + else if constexpr (std::is_same_v || std::is_same_v>) { + if (compute_capability == 700 || compute_capability == 720) max_size = 64; + else if (compute_capability == 750) max_size = 52; + else if (compute_capability == 800 || compute_capability == 870) max_size = 83; + else if (compute_capability == 860 || compute_capability == 890 || compute_capability == 1200 || compute_capability == 1210) max_size = 64; + else if (compute_capability == 900 || compute_capability == 1000 || compute_capability == 1010 || compute_capability == 1030 || compute_capability == 1100) max_size = 98; + } + // Complex, double + else if constexpr (std::is_same_v>) { + if (compute_capability == 700 || compute_capability == 720) max_size = 45; + else if (compute_capability == 750) max_size = 36; + else if (compute_capability == 800 || compute_capability == 870) max_size = 58; + else if (compute_capability == 860 || compute_capability == 890 || compute_capability == 1200 || compute_capability == 1210) max_size = 45; + else if (compute_capability == 900 || compute_capability == 1000 || compute_capability == 1010 || compute_capability == 1030 || compute_capability == 1100) max_size = 69; + } + + if (max_size == 0) { + // If we made it here, the type or architecture is unsupported -- throw an error. + MATX_THROW(matxNotSupported, "IscuBLASDxSupported: Combination of data type and compute capability not supported by cuBLASDx"); + } + + const auto max_shm = static_cast(max_size) * static_cast(max_size) * sizeof(T) * 2; // Most SHM both A and B can use + const auto req_shm = sizeof(T) * (static_cast(m) * static_cast(k) + static_cast(k) * static_cast(n)); + + // All dimensions must fit in shared memory + return req_shm <= max_shm; +} + + template + class cuBLASDxHelper { + private: + index_t m_; // Output rows (A rows) + index_t n_; // Output cols (B cols) + index_t k_; // Inner dimension (A cols = B rows) + int cc_; // Compute capability + bool is_complex_; // Whether the type is complex + + template + static std::string FormatScalarLiteral(Scalar value) { + return std::format( + "{:.{}g}", + value, + std::numeric_limits::max_digits10); + } + + public: + // Constructor + cuBLASDxHelper() = default; + + // Getters + index_t get_m() const { return m_; } + index_t get_n() const { return n_; } + index_t get_k() const { return k_; } + int get_cc() const { return cc_; } + bool get_is_complex() const { return is_complex_; } + + // Setters + void set_m(index_t m) { m_ = m; } + void set_n(index_t n) { n_ = n; } + void set_k(index_t k) { k_ = k; } + void set_cc(int cc) { cc_ = cc; } + void set_is_complex(bool is_complex) { is_complex_ = is_complex; } + + cublasdxDescriptor GeneratePlan() const { + cublasdxDescriptor h_; + LIBCUBLASDX_CHECK(cublasdxCreateDescriptor(&h_)); + + // Set the GEMM function + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64(h_, CUBLASDX_OPERATOR_FUNCTION, CUBLASDX_FUNCTION_MM)); + + // Set problem size (M, N, K) + long long int sizes[3] = {static_cast(m_), static_cast(n_), static_cast(k_)}; + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64s(h_, CUBLASDX_OPERATOR_SIZE, 3, sizes)); + + // Set API type - use shared memory API + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64(h_, CUBLASDX_OPERATOR_API, CUBLASDX_API_SMEM)); + + // Set execution mode - block level execution + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64(h_, CUBLASDX_OPERATOR_EXECUTION, COMMONDX_EXECUTION_BLOCK)); + + // Set precision for A, B, C matrices (all same precision for now) + commondxPrecision precision; + if constexpr (std::is_same_v || std::is_same_v) { + precision = COMMONDX_PRECISION_F16; + } else if constexpr (std::is_same_v || std::is_same_v>) { + precision = COMMONDX_PRECISION_BF16; + } else if constexpr (std::is_same_v || std::is_same_v>) { + precision = COMMONDX_PRECISION_F32; + } else if constexpr (std::is_same_v || std::is_same_v>) { + precision = COMMONDX_PRECISION_F64; + } else { + MATX_THROW(matxInvalidParameter, "Unsupported input type for cuBLASDx"); + } + + long long int precisions[3] = {static_cast(precision), + static_cast(precision), + static_cast(precision)}; + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64s(h_, CUBLASDX_OPERATOR_PRECISION, 3, precisions)); + + // Set type (real or complex) + cublasdxType type = is_complex_ ? CUBLASDX_TYPE_COMPLEX : CUBLASDX_TYPE_REAL; + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64(h_, CUBLASDX_OPERATOR_TYPE, type)); + + // Set target compute capability + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64(h_, CUBLASDX_OPERATOR_SM, cc_)); + + // Set arrangement - row major for all matrices (MatX default) + long long int arrangements[3] = {CUBLASDX_ARRANGEMENT_ROW_MAJOR, + CUBLASDX_ARRANGEMENT_ROW_MAJOR, + CUBLASDX_ARRANGEMENT_ROW_MAJOR}; + LIBCUBLASDX_CHECK(cublasdxSetOperatorInt64s(h_, CUBLASDX_OPERATOR_ARRANGEMENT, 3, arrangements)); + + return h_; + } + + std::string GetSymbolName() const { + std::string symbol_name; + symbol_name += std::to_string(m_); + symbol_name += "_"; + symbol_name += std::to_string(n_); + symbol_name += "_"; + symbol_name += std::to_string(k_); + symbol_name += "_T"; + symbol_name += is_complex_ ? "C" : "R"; + symbol_name += "_CC"; + symbol_name += std::to_string(cc_); + + // Add precision identifier + if constexpr (std::is_same_v || std::is_same_v) { + symbol_name += "_F16"; + } else if constexpr (std::is_same_v || std::is_same_v>) { + symbol_name += "_BF16"; + } else if constexpr (std::is_same_v || std::is_same_v>) { + symbol_name += "_F32"; + } else if constexpr (std::is_same_v || std::is_same_v>) { + symbol_name += "_F64"; + } + + // Add CUDA version to the symbol name +#if defined(CUDART_VERSION) + symbol_name += "_CUDA"; + symbol_name += std::to_string(CUDART_VERSION); +#else + symbol_name += "_CUDAUNKNOWN"; +#endif + + return symbol_name; + } + + void PrintMembers() const { + std::cout << "m_ = " << m_ << std::endl; + std::cout << "n_ = " << n_ << std::endl; + std::cout << "k_ = " << k_ << std::endl; + std::cout << "cc_ = " << cc_ << std::endl; + std::cout << "is_complex_ = " << is_complex_ << std::endl; + } + + bool IsSupported() const { + // Check basic size requirements + if (!IscuBLASDxSupported(m_, n_, k_, cc_)) { + MATX_LOG_DEBUG("cuBLASDx not supported: matrix size too large for shared memory"); + return false; + } + + // For now, only support float and double (and their complex variants) + // Half and bf16 support can be added later + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>) { + return true; + } + + return false; + } + + template + bool CheckJITSizeAndTypeRequirements() const { + using OpAType = typename OpA::value_type; + using OpBType = typename OpB::value_type; + + // A and B must have same type + if constexpr (!std::is_same_v) { + return false; + } + + // Check supported types for JIT + if constexpr (!(std::is_same_v || + std::is_same_v || + std::is_same_v> || + std::is_same_v>)) { + return false; + } + + // Check size constraints + return IscuBLASDxSupported(m_, n_, k_, cc_); + } + + int GetShmRequired() const { + // For GEMM, we need shared memory for A (m x k), B (k x n), and C (m x n) + size_t a_size = static_cast(m_) * static_cast(k_) * sizeof(InputType); + size_t b_size = static_cast(k_) * static_cast(n_) * sizeof(InputType); + size_t c_size = static_cast(m_) * static_cast(n_) * sizeof(InputType); + + // Total shared memory requirement + size_t total_shm = a_size + b_size + c_size; + + MATX_LOG_DEBUG("cuBLASDx shared memory: A={}, B={}, C={}, Total={}", a_size, b_size, c_size, total_shm); + return static_cast(total_shm); + } + + cuda::std::array GetBlockDim() const { + auto handle = GeneratePlan(); + cuda::std::array block_dim = {0, 0, 0}; + + LIBCUBLASDX_CHECK( + cublasdxGetTraitInt64s(handle, CUBLASDX_TRAIT_SUGGESTED_BLOCK_DIM, 3, block_dim.data())); + MATX_LOG_DEBUG("cuBLASDx suggested block dim: {} {} {}", block_dim[0], block_dim[1], block_dim[2]); + + cublasdxDestroyDescriptor(handle); + + return cuda::std::array{static_cast(block_dim[0]), + static_cast(block_dim[1]), + static_cast(block_dim[2])}; + } + + cuda::std::array GetLeadingDimensions() const { + auto handle = GeneratePlan(); + cuda::std::array ld = {0, 0, 0}; + + LIBCUBLASDX_CHECK( + cublasdxGetTraitInt64s(handle, CUBLASDX_TRAIT_SUGGESTED_LEADING_DIMENSION, 3, ld.data())); + MATX_LOG_DEBUG("cuBLASDx suggested leading dimensions: {} {} {}", ld[0], ld[1], ld[2]); + + cublasdxDestroyDescriptor(handle); + + return cuda::std::array{static_cast(ld[0]), + static_cast(ld[1]), + static_cast(ld[2])}; + } + + bool GenerateLTOIR(std::set <oir_symbols) { + LTOIRData ltoir; + const auto symbol_name = std::string(GEMM_DX_FUNC_PREFIX) + "_" + GetSymbolName(); + ltoir_symbols.insert(symbol_name); + + if (detail::GetCache().GetLTOIRCachedBytes(symbol_name) != nullptr) { + MATX_LOG_DEBUG("cuBLASDx LTOIR found in cache with size {}", detail::GetCache().GetLTOIRCachedBytes(symbol_name)->length); + return true; + } + + auto handle = GeneratePlan(); + + LIBCUBLASDX_CHECK(cublasdxSetOptionStr(handle, COMMONDX_OPTION_SYMBOL_NAME, symbol_name.c_str())); + + commondxCode code; + LIBCUBLASDX_CHECK(commondxCreateCode(&code)); + + LIBCUBLASDX_CHECK(commondxSetCodeOptionInt64(code, COMMONDX_OPTION_TARGET_SM, cc_)); + LIBCUBLASDX_CHECK(cublasdxFinalizeCode(code, handle)); + + LIBCUBLASDX_CHECK(commondxGetCodeLTOIRSize(code, <oir.length)); + ltoir.data = static_cast(malloc(ltoir.length)); + MATX_ASSERT_STR(ltoir.data != nullptr, matxInvalidParameter, "Failed to allocate LTO IR data"); + + LIBCUBLASDX_CHECK(commondxGetCodeLTOIR(code, ltoir.length, ltoir.data)); + + MATX_LOG_DEBUG("cuBLASDx Function {}", symbol_name); + MATX_LOG_DEBUG("cuBLASDx LTOIR size {}", ltoir.length); + + if (!detail::GetCache().StoreLTOIRCachedBytes(symbol_name, ltoir.data, ltoir.length)) { + free(ltoir.data); + MATX_LOG_ERROR("Failed to store cuBLASDx LTOIR cached bytes for: {}", symbol_name); + return false; + } + + // CRITICAL: Set to nullptr after transferring ownership to cache to prevent double-free + ltoir.data = nullptr; + ltoir.length = 0; + + LIBCUBLASDX_CHECK(commondxDestroyCode(code)); + LIBCUBLASDX_CHECK(cublasdxDestroyDescriptor(handle)); + + return true; + } + + std::string GetFuncStr(const std::string &gemm_func_name, float alpha, float beta) const { + std::string result = R"( + using value_type = )"; + result += detail::type_to_string(); + result += R"(; + + // cuBLASDx requires block-level cooperation, so all threads in the block + // must participate in loading data and executing the GEMM + extern __shared__ __align__(16) char smem[]; + + // Partition shared memory for A, B, C matrices + constexpr size_t a_size = )"; + result += std::to_string(static_cast(m_ * k_)); + result += R"( * sizeof(value_type); + constexpr size_t b_size = )"; + result += std::to_string(static_cast(k_ * n_)); + + result += R"( * sizeof(value_type); + + value_type* smem_a = reinterpret_cast(smem); + value_type* smem_b = reinterpret_cast(smem + a_size); + value_type* smem_c = reinterpret_cast(smem + a_size + b_size); + + // Cooperatively load A and B from global to shared memory using operator() + // Batch indices are already preset in the operators, so we only need 2D matrix indices + const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y; + const int total_threads = blockDim.x * blockDim.y * blockDim.z; + // Load A matrix (m x k) - each thread loads multiple elements strided by total_threads + constexpr index_t a_cols = )"; + result += std::to_string(static_cast(k_)); + result += R"(; + for (int i = tid; i < )"; + result += std::to_string(static_cast(m_ * k_)); + result += R"(; i += total_threads) { + const index_t row = i / a_cols; + const index_t col = i % a_cols; + smem_a[row * a_cols + col] = a_.template operator()(row, col); + } + + // Load B matrix (k x n) - each thread loads multiple elements strided by total_threads + constexpr index_t b_cols = )"; + result += std::to_string(static_cast(n_)); + result += R"(; + for (int i = tid; i < )"; + result += std::to_string(static_cast(k_ * n_)); + result += R"(; i += total_threads) { + const index_t row = i / b_cols; + const index_t col = i % b_cols; + smem_b[row * b_cols + col] = b_.template operator()(row, col); + } + + __syncthreads(); + + // Call the cuBLASDx generated GEMM function + // Signature: void func(value_type* alpha, value_type* a, value_type* b, value_type* beta, value_type* c) + )"; + using literal_type = cuda::std::conditional_t< + std::is_same_v || std::is_same_v>, + double, + float>; + result += "value_type alpha_val = static_cast(" + FormatScalarLiteral(static_cast(alpha)) + ");\n"; + result += "value_type beta_val = static_cast(" + FormatScalarLiteral(static_cast(beta)) + ");\n"; + result += gemm_func_name; + result += R"((&alpha_val, smem_a, smem_b, &beta_val, smem_c); + + __syncthreads(); + + // Each thread returns its portion of the result + // For vectorized execution, return a Vector; for scalar, return scalar + static_assert(CapType::ept == ElementsPerThread::ONE, "cuBLASDx only supports ONE elements per thread"); + if constexpr (CapType::ept == ElementsPerThread::ONE) { + const int output_idx = threadIdx.x; + return smem_c[output_idx]; + } + )"; + + return result; + } + }; + + } // namespace detail +} // namespace matx + +#else // !MATX_EN_MATHDX || !__CUDACC__ + +namespace matx { + namespace detail { + + // Stub implementation when MathDx is not enabled + template + __MATX_INLINE__ bool IscuBLASDxSupported([[maybe_unused]] index_t m, [[maybe_unused]] index_t n, + [[maybe_unused]] index_t k, [[maybe_unused]] int compute_capability) + { + return false; + } + + } // namespace detail +} // namespace matx + +#endif // MATX_EN_MATHDX && __CUDACC__