cuBLASDx support #1122

Merged
cliffburdick merged 6 commits into main from cublasdx on Jan 27, 2026
Conversation

@cliffburdick (Collaborator)

Add cuBLASDx support for JIT-compiled matrix multiplication

Integrate cuBLASDx for fusion and accelerated matrix
multiplication of small matrices that fit in shared memory. This enables
significantly faster GEMM operations for sizes up to ~200 elements per
dimension (varies by data type and compute capability).

Key changes:

  • Add matmul_cublasdx.h with cuBLASDxHelper class for managing GEMM
    parameters, size validation, and device code generation
  • Extend MatMulOp with JIT storage support and cuBLASDx-specific code
    generation (get_jit_class_name, get_jit_op_str)
  • Add PASS_THROUGH_THREADS capability for operators where all threads
    must invoke operator() with bounds checking at the tensor level
  • Update JIT executor to handle 2D block launch configuration for
    cuBLASDx operators with fixed block dimensions
  • Add Block2D kernel variants (matxOpT{2,3,4}KernelBlock2D) for
    pass-through thread execution model

Supported types: half, bfloat16, float, double, and their complex
variants. Size limits are architecture-dependent, ranging from 36 to 196
elements per dimension depending on compute capability (SM 7.0 through SM 11.0).
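The size gate can be pictured with a toy host-side check. The cutoff values and function name below are illustrative assumptions, not the actual per-type table in cuBLASDxHelper:

```cpp
#include <cassert>

// Hypothetical size gate, modeled on the description above. The real
// cuBLASDxHelper table varies by data type and compute capability
// (36-196 elements per dimension); the cutoffs here are placeholders.
bool fits_cublasdx(int m, int n, int k, int cc) {
  const int limit = (cc >= 80) ? 196 : 36;  // illustrative, not the real table
  return m <= limit && n <= limit && k <= limit;
}
```

When the gate fails, the transform falls back to the regular cuBLAS path, as shown later in the sequence diagram.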

Requires MATX_EN_MATHDX to be enabled at compile time.

Note that this is in early development. cuBLASDx has limitations that affect the MatX code base as a whole, such as dictating what the block size should be. This PR is for early-adopter support; more features will be added over time.

@copy-pr-bot bot commented Jan 23, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cliffburdick (Collaborator, Author)

/build

@greptile-apps bot commented Jan 23, 2026

Greptile Summary

Adds cuBLASDx integration for JIT-compiled matrix multiplication of small matrices (up to ~200 elements per dimension). The implementation introduces a new execution model with pass-through threads where all threads in a block must invoke the operator, with bounds checking handled at the tensor level.

Key additions:

  • New matmul_cublasdx.h with cuBLASDxHelper class for GEMM parameter management and device code generation
  • PASS_THROUGH_THREADS capability for block-level cooperative operators
  • Block2D kernel variants for 2D block execution model with flattened thread indexing
  • JIT launch parameter caching to avoid recomputation
  • Bounds checking in tensor_impl.h for pass-through threads
  • Support for float, double, and complex variants (half/bfloat16 infrastructure present but disabled)

Critical issue found:

  • matmul_cublasdx.h:441-442 has incorrect output indexing logic that only returns smem_c[threadIdx.x] instead of properly mapping the flattened thread ID to the 2D output matrix layout

Confidence Score: 2/5

  • This PR has a critical logic error in the output indexing that will produce incorrect results
  • The output indexing bug in matmul_cublasdx.h:441-442 is a fundamental logic error that breaks correctness: the Block2D kernel computes flattened thread indices and converts them to 2D coordinates, but the cuBLASDx operator ignores these and indexes the output with threadIdx.x alone, producing incorrect results for any matrix where m*n != blockDim.x
  • Critical attention needed for include/matx/transforms/matmul/matmul_cublasdx.h - the output indexing logic must be fixed before merge
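The conversion the review says is missing can be modeled on the host. The names below are illustrative, not the actual kernel code:

```cpp
#include <cassert>

// Host-side model of the missing step: a flattened in-block thread id
// mapped to the (row, col) of an m x n output tile, rather than
// indexing the output with threadIdx.x alone.
struct Coord { int row; int col; };

Coord flat_to_2d(int tid, int n) {
  // Row-major: consecutive tids walk across one n-wide row of the tile.
  return Coord{ tid / n, tid % n };
}
```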

Important Files Changed

Filename Overview
include/matx/transforms/matmul/matmul_cublasdx.h New file adding cuBLASDx support for JIT-compiled small matrix multiplication. Contains critical output indexing bug in Block2D kernel return logic.
include/matx/operators/matmul.h Extended MatMulOp with cuBLASDx JIT support, added capability checks, block dimension handling, and LTOIR generation.
include/matx/executors/jit_cuda.h Added support for pass-through threads execution model with Block2D grid computation and caching of launch parameters.
include/matx/executors/jit_kernel.h Added Block2D kernel variants (matxOpT2/3/4KernelBlock2D) for 2D block execution model with flattened thread indexing.
include/matx/core/get_grid_dims.h Added get_grid_dims_block_2d function for computing grid dimensions for 2D block operators with batch dimension handling.
include/matx/core/tensor_impl.h Added bounds checking for pass-through threads execution model with dummy storage for out-of-bounds returns.
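For illustration, the computation get_grid_dims_block_2d presumably performs can be sketched on the host. The signature and batching layout below are assumptions, not the actual implementation:

```cpp
#include <cassert>

// Hypothetical sketch: ceil-divide the output extents by the tile one
// 2D block covers, and fold the batch count into grid.z.
struct Dim3 { unsigned x, y, z; };

static unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

Dim3 grid_dims_block_2d(unsigned rows, unsigned cols, unsigned batches,
                        unsigned block_rows, unsigned block_cols) {
  return Dim3{ ceil_div(cols, block_cols), ceil_div(rows, block_rows), batches };
}
```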

Sequence Diagram

sequenceDiagram
    participant User
    participant MatMulOp
    participant CUDAJITExecutor
    participant cuBLASDxHelper
    participant NVRTC
    participant Block2DKernel
    participant cuBLASDx

    User->>MatMulOp: matmul(A, B)
    MatMulOp->>cuBLASDxHelper: Initialize (m, n, k, cc)
    cuBLASDxHelper->>cuBLASDxHelper: Check size constraints
    
    User->>CUDAJITExecutor: Exec(matmul_op)
    CUDAJITExecutor->>MatMulOp: Check SUPPORTS_JIT capability
    MatMulOp->>cuBLASDxHelper: CheckJITSizeAndTypeRequirements()
    cuBLASDxHelper-->>MatMulOp: true/false
    
    alt JIT Supported
        CUDAJITExecutor->>MatMulOp: Get PASS_THROUGH_THREADS capability
        MatMulOp-->>CUDAJITExecutor: true
        CUDAJITExecutor->>MatMulOp: Get BLOCK_DIM capability
        MatMulOp->>cuBLASDxHelper: GetBlockDim()
        cuBLASDxHelper->>cuBLASDxHelper: GeneratePlan()
        cuBLASDxHelper-->>MatMulOp: block_dims[x,y,z]
        
        CUDAJITExecutor->>CUDAJITExecutor: get_grid_dims_block_2d()
        CUDAJITExecutor->>MatMulOp: Get GENERATE_LTOIR capability
        MatMulOp->>cuBLASDxHelper: GenerateLTOIR()
        cuBLASDxHelper->>cuBLASDxHelper: Generate cuBLASDx function
        cuBLASDxHelper-->>MatMulOp: LTOIR symbols
        
        CUDAJITExecutor->>NVRTC: nvrtc_compile_and_run()
        NVRTC->>NVRTC: Compile with Block2D kernel
        NVRTC->>Block2DKernel: Launch kernel
        
        Block2DKernel->>Block2DKernel: Compute flattened tid
        Block2DKernel->>MatMulOp: operator()(batch_idx, row, col)
        MatMulOp->>MatMulOp: Load A and B to shared memory
        MatMulOp->>cuBLASDx: Call GEMM function
        cuBLASDx->>cuBLASDx: Perform matrix multiplication
        cuBLASDx-->>MatMulOp: Result in shared memory
        MatMulOp-->>Block2DKernel: Return output element
        Block2DKernel-->>User: Complete
    else JIT Not Supported
        CUDAJITExecutor->>MatMulOp: Fall back to cuBLAS
        MatMulOp-->>User: Complete
    end

@greptile-apps bot left a comment
Additional Comments (2)

  1. include/matx/core/error.h, line 95-96 (link)

    logic: Missing case for matxLibMathdxError in the switch statement. This will cause the function to fall through to the default case and return "Unknown" instead of a proper error string.

  2. include/matx/executors/jit_cuda.h, line 319 (link)

    logic: Missing initialization of pass_through_threads in ND kernel cache parameters - should be set to false for consistency
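The first fix amounts to adding a case to the switch. A self-contained sketch, in which the enum values and strings are illustrative rather than MatX's actual error.h:

```cpp
#include <cassert>
#include <cstring>

// Illustrative enum; the real matxError_t lives in include/matx/core/error.h.
enum matxError_t { matxSuccess, matxCudaError, matxLibMathdxError };

const char* error_to_str(matxError_t e) {
  switch (e) {
    case matxSuccess:        return "Success";
    case matxCudaError:      return "CUDA error";
    case matxLibMathdxError: return "libmathdx error";  // the case the review flags as missing
    default:                 return "Unknown";
  }
}
```

Without the added case, the enum value would fall through to the default branch and report "Unknown".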

11 files reviewed, 11 comments


@cliffburdick (Collaborator, Author)

/build

@cliffburdick (Collaborator, Author)

/build

@greptile-apps bot left a comment

4 files reviewed, 4 comments


@greptile-apps bot commented Jan 26, 2026

Additional Comments (1)

include/matx/executors/jit_cuda.h
Global static variables without initialization order guarantees. Consider using a function-local static or singleton pattern to ensure proper initialization before first use across translation units.
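The suggested function-local static pattern can be sketched as follows; the cache type and name here are hypothetical:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Sketch of the suggestion: a namespace-scope static has no
// cross-translation-unit initialization order guarantee, but a
// function-local static is initialized on first use, and thread-safely
// since C++11 ("magic statics").
std::unordered_map<std::string, int>& launch_param_cache() {
  static std::unordered_map<std::string, int> cache;  // constructed on first call
  return cache;
}
```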

@greptile-apps bot left a comment

4 files reviewed, 4 comments


@greptile-apps bot left a comment

3 files reviewed, 3 comments


@cliffburdick (Collaborator, Author)

/build

@greptile-apps bot left a comment

1 file reviewed, 1 comment


cliffburdick merged commit 6a4a698 into main on Jan 27, 2026
cliffburdick deleted the cublasdx branch on January 27, 2026 19:13