4 changes: 4 additions & 0 deletions docs_input/api/dft/fft/fft.rst
@@ -7,6 +7,10 @@ Perform a 1D FFT. Batching is supported for any tensor with a rank higher than 1

.. versionadded:: 0.6.0

.. note::

FFT kernel fusion is supported by cuFFTDx if ``-DMATX_EN_MATHDX=ON`` is enabled.

.. doxygenfunction:: fft(const OpA &a, uint64_t fft_size = 0, FFTNorm norm = FFTNorm::BACKWARD)
.. doxygenfunction:: fft(const OpA &a, const int32_t (&axis)[1], uint64_t fft_size = 0, FFTNorm norm = FFTNorm::BACKWARD)

4 changes: 4 additions & 0 deletions docs_input/api/linalg/matvec/matmul.rst
@@ -11,6 +11,10 @@ is supported for any tensor with a rank higher than 2.

.. versionadded:: 0.6.0

.. note::

GEMM kernel fusion is supported by cuBLASDx if ``-DMATX_EN_MATHDX=ON`` is enabled.

.. doxygenfunction:: matmul(const OpA &A, const OpB &B, float alpha = 1.0, float beta = 0.0)
.. doxygenfunction:: matmul(const OpA &A, const OpB &B, const int32_t (&axis)[2], float alpha = 1.0, float beta = 0.0)

33 changes: 28 additions & 5 deletions docs_input/basics/fusion.rst
@@ -60,13 +60,13 @@ CUDA JIT Kernel Fusion
CUDA JIT kernel fusion is considered an experimental feature. There may be bugs that don't occur with JIT disabled, and new features are being added over time.

MatX supports CUDA JIT kernel fusion that compiles the entire expression into a single kernel. Currently this is enabled
for all standard MatX element-wise operators and FFT operations via MathDx. To enable fusion with MathDx,
for all standard MatX element-wise operators and FFT and GEMM operations via MathDx. To enable fusion with MathDx,
the following option must be enabled: ``-DMATX_EN_MATHDX=ON``. Once enabled, the ``CUDAJITExecutor`` can be used to perform JIT compilation
in supported situations. If the expression cannot be JIT compiled, the JITExecutor will fall back to the normal non-JIT path.
in supported situations. If the expression cannot be JIT compiled, the JITExecutor may throw an error.

While JIT compilation can provide a large performance boost, there are two overheads that occur when using JIT compilation:
- The first pass to JIT the code takes time. The first time a ``run()`` statement is executed on a new operator, MatX identifies this and performs JIT compilation. Depending on the complexity of the operator, this could be anywhere from milliseconds to seconds to complete. Once finished, MatX will cache the compiled kernel so that subsequent runs of the same operator will not require JIT compilation.
- A lookup is done to find kernels that have already been compiled. This is a small overhead and may not be noticeable.
* The first pass to JIT the code takes time. The first time a ``run()`` statement is executed on a new operator, MatX identifies this and performs JIT compilation. Depending on the complexity of the operator, this could be anywhere from milliseconds to seconds to complete. Once finished, MatX will cache the compiled kernel so that subsequent runs of the same operator will not require JIT compilation.
* A lookup is done to find kernels that have already been compiled. This is a small overhead and may not be noticeable.

As mentioned above, there is no difference in syntax between MatX statements that perform JIT compilation and those that do not. The executor
is the only change, just as it would be with a host executor. For example, in the following code:
@@ -76,7 +76,7 @@ is the only change, just as it would be with a host executor. For example, in th
(A = B * fft(C)).run(CUDAExecutor{});
(A = B * fft(C)).run(CUDAJITExecutor{});

When MathDx is disabled, the the first statement will execute the FFT into a temporary buffer, then the multiply will be executed. This results
The first statement will execute the FFT as a separate kernel into a temporary buffer, then the multiply will be executed. This results
in a minimum of 2 kernels (one for MatX and at least one for cuFFT). The second statement will execute the FFT and multiply in a single kernel if
possible.
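
As a quick illustration of the caching behavior described above (editor's sketch, not part of this PR or the MatX docs), the fused expression from the example might be driven from user code as follows, assuming MatX is built with ``-DMATX_EN_MATHDX=ON`` and the shapes/types are supported by cuFFTDx:

    #include <matx.h>

    int main() {
      using namespace matx;
      auto A = make_tensor<cuda::std::complex<float>>({16, 1024});
      auto B = make_tensor<cuda::std::complex<float>>({16, 1024});
      auto C = make_tensor<cuda::std::complex<float>>({16, 1024});

      // First run: MatX JIT-compiles the fused multiply+FFT kernel
      // (milliseconds to seconds) and caches it for this operator type.
      (A = B * fft(C)).run(CUDAJITExecutor{});

      // Later runs of the same expression hit the kernel cache; only a
      // small lookup overhead remains.
      (A = B * fft(C)).run(CUDAJITExecutor{});

      return 0;
    }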

@@ -103,4 +103,27 @@ In this case the MathDx library requires at least 2 elements per thread for the
only 1 element per thread. Therefore, the entire expression cannot be JIT-compiled and will fall back to the non-JIT path. Some of
these restrictions may be relaxed in newer versions of MatX or the MathDx library.

MathDx Compatibility
====================

.. list-table:: MathDx library compatibility for CUDA JIT fusion
:header-rows: 1
:widths: 20 14 66

* - Library
- Supported
- Notes
* - cuBLASDx
- Yes
- Enabled via ``-DMATX_EN_MATHDX=ON`` for GEMM fusion paths.
* - cuFFTDx
- Yes
- Enabled via ``-DMATX_EN_MATHDX=ON`` for FFT fusion paths.
* - cuSolverDx
- No
- Not supported yet by MatX CUDA JIT fusion.
* - cuRandDx
- No
- Not supported yet by MatX CUDA JIT fusion.


12 changes: 12 additions & 0 deletions include/matx/core/capabilities.h
@@ -71,6 +71,7 @@ namespace detail {
ELEMENT_WISE, // Whether the operator is element-wise (safe with aliasing)
ALIASED_MEMORY, // Whether the operator's input and output pointers alias
GLOBAL_KERNEL, // Kernel operates entirely on a global level per chunk of data. False when at least one operator works on a block level
PASS_THROUGH_THREADS, // All threads must call operator() on nested operators; bounds checking done at tensor level
// Add more capabilities as needed
};

@@ -246,6 +247,15 @@ namespace detail {
static constexpr int default_value = 32;
static constexpr int min_identity = 32;
static constexpr int max_identity = 1;
};

template <>
struct capability_attributes<OperatorCapability::PASS_THROUGH_THREADS> {
using type = bool;
using input_type = VoidCapabilityType;
static constexpr bool default_value = false; // Default: operators do their own bounds checking
static constexpr bool or_identity = false;
static constexpr bool and_identity = true;
};


@@ -312,6 +322,8 @@ namespace detail {
return CapabilityQueryType::RANGE_QUERY; // The expression should use the minimum block size supported by all operators.
case OperatorCapability::GENERATE_LTOIR:
return CapabilityQueryType::AND_QUERY; // The expression should generate LTOIR code if all its children generate it.
case OperatorCapability::PASS_THROUGH_THREADS:
return CapabilityQueryType::OR_QUERY; // If ANY operator needs pass-through, all threads must call operator()
default:
// Default to OR_QUERY or handle as an error/assertion if a capability isn't mapped.
return CapabilityQueryType::OR_QUERY;
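
For context, the OR_QUERY rule added for PASS_THROUGH_THREADS means the expression-wide value is the logical OR of every operator's answer, seeded with ``or_identity = false``. A standalone sketch of that reduction (illustrative only, not MatX's actual query machinery):

    #include <vector>

    // If ANY operator in the expression needs pass-through threads, the whole
    // expression must launch with every thread calling operator() and rely on
    // tensor-level bounds checking instead of per-thread early exits.
    bool reduce_pass_through(const std::vector<bool>& per_operator_answers) {
      bool result = false;  // or_identity
      for (bool needs_pass_through : per_operator_answers) {
        result = result || needs_pass_through;
      }
      return result;
    }
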
7 changes: 6 additions & 1 deletion include/matx/core/error.h
@@ -63,6 +63,7 @@ namespace matx
matxInvalidSize,
matxCudaError,
matxCufftError,
matxLibMathdxError,
matxMatMulError,
matxAssertError,
matxInvalidType,
@@ -107,8 +108,12 @@ namespace matx
return "matxInverseError";
case matxSolverError:
return "matxSolverError";
case matxLibMathdxError:
return "matxLibMathdxError";
case matxcuTensorError:
break;
return "matxcuTensorError";
case matxInvalidExecutor:
return "matxInvalidExecutor";
default:
return "Unknown";
};
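
A hypothetical sketch (not in this PR) of how the new error code might be raised when a libmathdx call fails; treating a zero status as success is an assumption for illustration, not a documented libmathdx contract:

    #include "matx/core/error.h"

    namespace matx {
    // Hypothetical helper: convert a failing libmathdx status into the new
    // matxLibMathdxError code using the MATX_THROW macro seen elsewhere in this diff.
    inline void check_mathdx_status(long long status) {
      if (status != 0) {
        MATX_THROW(matxLibMathdxError, "libmathdx call failed");
      }
    }
    } // namespace matx
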
46 changes: 46 additions & 0 deletions include/matx/core/get_grid_dims.h
@@ -33,6 +33,7 @@
#pragma once

#include "matx/core/defines.h"
#include "matx/core/error.h"
#include <cuda/std/array>
#include <cuda/std/functional>
#include <cuda/std/__numeric/accumulate.h>
@@ -298,5 +299,50 @@ inline bool get_grid_dims_block(dim3 &blocks, dim3 &threads, const cuda::std::ar
MATX_LOG_DEBUG("Blocks {}x{}x{} Threads {}x{}x{} groups_per_block={}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z, groups_per_block);
return stride;
}

// For 2D block operators (e.g., cuBLASDx GEMM) where all threads in a block cooperate
// on the last 2 dimensions and blockIdx is used purely for batching
template <int RANK>
inline bool get_grid_dims_block_2d(dim3 &blocks, dim3 &threads,
const cuda::std::array<index_t, RANK> &sizes,
int block_dim) {
// Threads are set to block_dim in x, y and z are 1
// All threads cooperate via flattened thread ID in the kernel
threads.x = block_dim;
threads.y = 1;
threads.z = 1;

// Grid covers batch dimensions only (dims 0 to RANK-3)
blocks.x = 1;
blocks.y = 1;
blocks.z = 1;

if constexpr (RANK == 2) {
blocks.x = 1; // Single block for entire 2D output
}
else if constexpr (RANK == 3) {
blocks.x = static_cast<int>(sizes[0]); // Batch dim
}
else if constexpr (RANK == 4) {
blocks.x = static_cast<int>(sizes[1]); // Second-to-last batch
blocks.y = static_cast<int>(sizes[0]); // First batch dim
}
else if constexpr (RANK > 4) {
MATX_THROW(matxNotSupported, "Block2D grid dims not supported for rank > 4");
return true;
}

if constexpr (RANK >= 2 && RANK <= 4) {
constexpr int kMaxGridDim = 65535;
if (blocks.x > kMaxGridDim || blocks.y > kMaxGridDim) {
MATX_THROW(matxInvalidParameter, "Block2D grid dims exceed CUDA limit (65535)");
}
}

MATX_LOG_DEBUG("Block2D: Blocks {}x{}x{} Threads {}x{}x{}", blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z);

// No stride needed for now - could be extended for very large batches
return false;
}
} // end namespace detail
} // end namespace matx
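
A usage sketch (not from the PR) of the new helper for a rank-3 batched operator; the 8x64x64 shape and block_dim of 128 are arbitrary illustration values:

    #include "matx/core/get_grid_dims.h"

    // Compute launch dimensions for 8 batches of a 64x64 output.
    void example_launch_config() {
      dim3 blocks, threads;
      cuda::std::array<matx::index_t, 3> sizes{8, 64, 64};
      matx::detail::get_grid_dims_block_2d<3>(blocks, threads, sizes, /*block_dim=*/128);
      // threads = {128, 1, 1}; blocks = {8, 1, 1} -- blockIdx.x indexes the batch dim
    }
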
31 changes: 18 additions & 13 deletions include/matx/core/nvrtc_helper.h
@@ -88,11 +88,7 @@ std::vector<std::string> __MATX_HOST__ __MATX_INLINE__ get_preprocessor_options(
options.push_back("-arch=sm_80"); // fallback
#endif

#ifdef NVRTC_CXX_STANDARD
options.push_back("-std=c++" NVRTC_CXX_STANDARD);
#else
options.push_back("-std=c++20"); // fallback
#endif
options.push_back("-std=c++20");

return options;
}
@@ -170,29 +166,35 @@ inline std::string get_jit_includes_path() {
}

template <typename Op>
std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool global_kernel) {
std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool global_kernel, bool pass_through_threads = false) {
if constexpr (Op::Rank() == 0) {
return "matx::detail::matxOpT0Kernel";
}
else if constexpr (Op::Rank() == 1) {
return global_kernel ? "matx::detail::matxOpT1Kernel" : "matx::detail::matxOpT1KernelBlock";
}
else if constexpr (Op::Rank() == 2) {
if (stride) {
if (pass_through_threads) {
return "matx::detail::matxOpT2KernelBlock2D";
} else if (stride) {
return global_kernel ? "matx::detail::matxOpT2StrideKernel" : "matx::detail::matxOpT2StrideKernelBlock";
} else {
return global_kernel ? "matx::detail::matxOpT2Kernel" : "matx::detail::matxOpT2KernelBlock";
}
}
else if constexpr (Op::Rank() == 3) {
if (stride) {
if (pass_through_threads) {
return "matx::detail::matxOpT3KernelBlock2D";
} else if (stride) {
return global_kernel ? "matx::detail::matxOpT3StrideKernel" : "matx::detail::matxOpT3StrideKernelBlock";
} else {
return global_kernel ? "matx::detail::matxOpT3Kernel" : "matx::detail::matxOpT3KernelBlock";
}
}
else if constexpr (Op::Rank() == 4) {
if (stride) {
if (pass_through_threads) {
return "matx::detail::matxOpT4KernelBlock2D";
} else if (stride) {
return global_kernel ? "matx::detail::matxOpT4StrideKernel" : "matx::detail::matxOpT4StrideKernelBlock";
} else {
return global_kernel ? "matx::detail::matxOpT4Kernel" : "matx::detail::matxOpT4KernelBlock";
@@ -207,7 +209,7 @@ std::string get_kernel_name([[maybe_unused]] const Op &op, bool stride, bool glo
}

template <typename Op>
std::string generate_capability_params_string([[maybe_unused]] const Op &op, ElementsPerThread EPT, bool JIT, int osize, int block_size) {
std::string generate_capability_params_string([[maybe_unused]] const Op &op, ElementsPerThread EPT, bool JIT, int osize, int block_size, bool pass_through_threads = false) {
std::string ept_str;
switch (EPT) {
case ElementsPerThread::ONE:
@@ -235,6 +237,8 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele

std::string jit_str = JIT ? "true" : "false";

std::string pass_through_str = pass_through_threads ? "true" : "false";

std::string final_str =
"namespace matx { namespace detail {\n"
"template <ElementsPerThread EPT, bool JIT>\n"
@@ -243,6 +247,7 @@ std::string generate_capability_params_string([[maybe_unused]] const Op &op, Ele
" static constexpr bool jit = JIT;\n"
" static constexpr int osize = " + std::to_string(osize) + ";\n"
" static constexpr int block_size = " + std::to_string(block_size) + ";\n"
" static constexpr bool pass_through_threads = " + pass_through_str + ";\n"
"};\n"
"using CurrentCapabilities = CapabilityParams<" + ept_str + ", " + jit_str + ">;\n"
"} }\n";
@@ -300,7 +305,7 @@ inline std::string qualify_jit_type_names(const std::string& type_str) {
}

template <typename Op, typename SizeArray>
auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize, bool global_kernel) {
auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, const SizeArray &sa, dim3 &blocks, dim3 &threads, ElementsPerThread ept, bool stride, int dynamic_shmem_size, int osize, bool global_kernel, bool pass_through_threads = false) {
// Pure NVRTC implementation
// Cache both module and function to prevent resource leaks
// CUmodule must remain loaded for CUfunction to be valid
@@ -312,10 +317,10 @@ auto nvrtc_compile_and_run([[maybe_unused]] const std::string &name, Op op, cons
static std::mutex kernel_cache_mutex;

const auto all_jit_classes_string = get_all_jit_classes_string(op);
auto capstr = generate_capability_params_string(op, ept, false, osize, threads.x);
auto capstr = generate_capability_params_string(op, ept, false, osize, threads.x, pass_through_threads);
const auto kernel_op_type = detail::get_operator_capability<OperatorCapability::JIT_TYPE_QUERY>(op);

std::string kernel_name = get_kernel_name(op, stride, global_kernel);
std::string kernel_name = get_kernel_name(op, stride, global_kernel, pass_through_threads);
std::string cache_key = kernel_name + "_" + kernel_op_type;

MATX_LOG_DEBUG("nvrtc_compile_and_run called with operator type: {}", typeid(op).name());
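
For reference, a rough reconstruction of the header that ``generate_capability_params_string()`` emits, for a call with ``EPT = ElementsPerThread::ONE``, ``JIT = false``, ``osize = 4``, ``block_size = 128``, and ``pass_through_threads = true``. The struct's opening lines and its ``ept`` member sit outside the visible hunk and are assumed from how ``CapType::ept`` is used elsewhere in this diff:

    // Approximate generated JIT header (illustrative values only)
    namespace matx { namespace detail {
    template <ElementsPerThread EPT, bool JIT>
    struct CapabilityParams {
      static constexpr ElementsPerThread ept = EPT;  // assumed member (referenced as CapType::ept)
      static constexpr bool jit = JIT;
      static constexpr int osize = 4;
      static constexpr int block_size = 128;
      static constexpr bool pass_through_threads = true;
    };
    using CurrentCapabilities = CapabilityParams<ElementsPerThread::ONE, false>;
    } }
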
59 changes: 44 additions & 15 deletions include/matx/core/tensor_impl.h
@@ -221,35 +221,64 @@ class tensor_impl_t {
" return GetValC<EPT, 0, Is...>(cuda::std::make_tuple(indices...));\n" +
" }\n" +
" }\n" +
" template <typename CapType, int I = 0, typename... Is>\n" +
" __MATX_INLINE__ __MATX_DEVICE__ bool CheckBounds(cuda::std::tuple<Is...> tup) const {\n" +
" if constexpr (I < sizeof...(Is)) {\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" if constexpr (I == sizeof...(Is) - 1 && EPT_int > 1) {\n" +
" // Last dimension with EPT > 1: check all elements [idx*EPT, idx*EPT+EPT-1] are in bounds\n" +
" if ((cuda::std::get<I>(tup) + 1) * EPT_int > sizes_[I]) return false;\n" +
" } else {\n" +
" if (cuda::std::get<I>(tup) >= sizes_[I]) return false;\n" +
" }\n" +
" return CheckBounds<CapType, I+1>(tup);\n" +
" }\n" +
" return true;\n" +
" }\n" +
" template <typename CapType, int M = RANK, typename... Is,\n" +
" cuda::std::enable_if_t<cuda::std::conjunction_v<cuda::std::is_integral<Is>...>, bool> = true>\n" +
" __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) const noexcept" + "{\n" +
" __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) const noexcept" + "{\n" +
" static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
" return ldata_[offset];\n" +
" } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" +
" return *reinterpret_cast<detail::Vector<T, EPT_int>*>(ldata_ + offset);\n" +
" } else {\n" +
" detail::Vector<T, EPT_int> vec;\n" +
" vec.template load<EPT_int>(ldata_ + offset);\n" +
" return vec;\n" +
" }\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" using ReturnType = cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE, T, detail::Vector<T, EPT_int>>;\n" +
" if constexpr (CapType::pass_through_threads) {\n" +
" if (!CheckBounds<CapType, 0>(cuda::std::make_tuple(indices...))) {\n" +
" return ReturnType{};\n" +
" }\n" +
" }\n" +
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
" return ldata_[offset];\n" +
" } else if constexpr (EPT_int * sizeof(T) <= MAX_VEC_WIDTH_BYTES ) {\n" +
" return *reinterpret_cast<detail::Vector<T, EPT_int>*>(ldata_ + offset);\n" +
" } else {\n" +
" detail::Vector<T, EPT_int> vec;\n" +
" vec.template load<EPT_int>(ldata_ + offset);\n" +
" return vec;\n" +
" }\n" +
" }\n" +
" template <typename CapType, int M = RANK, typename... Is,\n" +
" cuda::std::enable_if_t<cuda::std::conjunction_v<cuda::std::is_integral<Is>...>, bool> = true>\n" +
" __MATX_INLINE__ __MATX_DEVICE__ decltype(auto) operator()(Is... indices) noexcept\n" +
" __MATX_INLINE__ __MATX_DEVICE__ auto operator()(Is... indices) noexcept\n" +
" -> cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE, T&, detail::Vector<T, static_cast<int>(CapType::ept)>&>\n" +
" {\n" +
" static_assert(sizeof...(Is) == M, \"Number of indices of operator() must match rank of tensor\");\n" +
" constexpr int EPT_int = static_cast<int>(CapType::ept);\n" +
" if constexpr (CapType::pass_through_threads) {\n" +
" using ReturnType = cuda::std::conditional_t<CapType::ept == detail::ElementsPerThread::ONE, T, detail::Vector<T, EPT_int>>;\n" +
" __align__(alignof(ReturnType)) __shared__ unsigned char dummy_storage[sizeof(ReturnType)];\n" +
" auto &dummy_ = *reinterpret_cast<ReturnType *>(dummy_storage);\n" +
" if (!CheckBounds<CapType, 0>(cuda::std::make_tuple(indices...))) {\n" +
" return dummy_;\n" +
" }\n" +
" }\n" +
" const index_t offset = GetOffsetOptimized<CapType::ept>(indices...);\n" +
" if constexpr (CapType::ept == detail::ElementsPerThread::ONE) {\n" +
" return ldata_[offset];\n" +
" } else {\n" +
" return *reinterpret_cast<detail::Vector<T, EPT_int>*>(ldata_ + offset);\n" +
" }\n" +
" }\n" +
" }\n" +
" }\n" +
" template <typename CapType>\n" +
" __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) operator()(const cuda::std::array<index_t, RANK> &idx) const noexcept\n" +
" {\n" +
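
The ``CheckBounds`` logic embedded in the JIT string above can be hard to read through the escaping; a standalone, host-side restatement of the same rule (illustrative, not the generated code) is:

    #include <array>
    #include <cstdint>

    // With EPT elements per thread, the last-dimension index is a vector index:
    // the whole span [idx*EPT, idx*EPT + EPT - 1] must be in bounds. Every other
    // dimension uses a plain per-index check.
    template <int EPT, std::size_t RANK>
    bool check_bounds(const std::array<std::int64_t, RANK>& idx,
                      const std::array<std::int64_t, RANK>& sizes) {
      for (std::size_t i = 0; i < RANK; i++) {
        if (i == RANK - 1 && EPT > 1) {
          if ((idx[i] + 1) * EPT > sizes[i]) return false;
        } else {
          if (idx[i] >= sizes[i]) return false;
        }
      }
      return true;
    }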