Conversation
/build
Greptile Overview

Greptile Summary

Adds cuBLASDx integration for JIT-compiled matrix multiplication of small matrices (up to ~200 elements per dimension). The implementation introduces a new execution model with pass-through threads, in which all threads in a block must invoke the operator and bounds checking is handled at the tensor level. Key additions:
Critical issue found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant MatMulOp
participant CUDAJITExecutor
participant cuBLASDxHelper
participant NVRTC
participant Block2DKernel
participant cuBLASDx
User->>MatMulOp: matmul(A, B)
MatMulOp->>cuBLASDxHelper: Initialize (m, n, k, cc)
cuBLASDxHelper->>cuBLASDxHelper: Check size constraints
User->>CUDAJITExecutor: Exec(matmul_op)
CUDAJITExecutor->>MatMulOp: Check SUPPORTS_JIT capability
MatMulOp->>cuBLASDxHelper: CheckJITSizeAndTypeRequirements()
cuBLASDxHelper-->>MatMulOp: true/false
alt JIT Supported
CUDAJITExecutor->>MatMulOp: Get PASS_THROUGH_THREADS capability
MatMulOp-->>CUDAJITExecutor: true
CUDAJITExecutor->>MatMulOp: Get BLOCK_DIM capability
MatMulOp->>cuBLASDxHelper: GetBlockDim()
cuBLASDxHelper->>cuBLASDxHelper: GeneratePlan()
cuBLASDxHelper-->>MatMulOp: block_dims[x,y,z]
CUDAJITExecutor->>CUDAJITExecutor: get_grid_dims_block_2d()
CUDAJITExecutor->>MatMulOp: Get GENERATE_LTOIR capability
MatMulOp->>cuBLASDxHelper: GenerateLTOIR()
cuBLASDxHelper->>cuBLASDxHelper: Generate cuBLASDx function
cuBLASDxHelper-->>MatMulOp: LTOIR symbols
CUDAJITExecutor->>NVRTC: nvrtc_compile_and_run()
NVRTC->>NVRTC: Compile with Block2D kernel
NVRTC->>Block2DKernel: Launch kernel
Block2DKernel->>Block2DKernel: Compute flattened tid
Block2DKernel->>MatMulOp: operator()(batch_idx, row, col)
MatMulOp->>MatMulOp: Load A and B to shared memory
MatMulOp->>cuBLASDx: Call GEMM function
cuBLASDx->>cuBLASDx: Perform matrix multiplication
cuBLASDx-->>MatMulOp: Result in shared memory
MatMulOp-->>Block2DKernel: Return output element
Block2DKernel-->>User: Complete
else JIT Not Supported
CUDAJITExecutor->>MatMulOp: Fall back to cuBLAS
MatMulOp-->>User: Complete
end
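To make the pass-through-threads flow above concrete, here is a minimal sketch of what a Block2D-style kernel with pass-through threads might look like. The kernel name, the operator() signature, and the index mapping are assumptions for illustration; they are not the actual MatX implementation.

```cpp
// Illustrative sketch only (not the MatX kernel). Every thread in the block
// calls the operator, because the operator may issue block-wide collective
// work such as a cuBLASDx GEMM; out-of-range threads are not filtered here.
template <typename Op>
__global__ void block_2d_pass_through_kernel(Op op, int tile_cols)
{
  // Compute a flattened thread id; cuBLASDx dictates the block shape, so the
  // mapping from threads to output elements goes through a linear index.
  const int tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int row = tid / tile_cols;   // row within this block's tile (assumed mapping)
  const int col = tid % tile_cols;   // column within this block's tile
  const int batch_idx = blockIdx.z;  // one batch per z block

  // No early return: all threads enter the operator so shared-memory loads
  // and the cuBLASDx GEMM see a full block. Bounds checking is handled at the
  // tensor level when the result element is stored.
  op(batch_idx, row, col);
}
```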
Additional Comments (2)
- include/matx/core/error.h, lines 95-96 (link) logic: Missing case for matxLibMathdxError in the switch statement. This will cause the function to fall through to the default case and return "Unknown" instead of a proper error string (see the sketch after this list).
- include/matx/executors/jit_cuda.h, line 319 (link) logic: Missing initialization of pass_through_threads in the ND kernel cache parameters; it should be set to false for consistency.
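For the first comment, the fix would presumably be an explicit case for matxLibMathdxError in the error-string switch. The sketch below shows the idea; the function name and the other enum values are assumptions and do not reflect the actual contents of error.h.

```cpp
// Illustrative only: the real enum in include/matx/core/error.h has more
// values; matxLibMathdxError is the one the review says is unhandled.
enum matxError_t { matxSuccess, matxCudaError, matxLibMathdxError /* , ... */ };

// With its own case, matxLibMathdxError no longer falls through to "Unknown".
static const char *matxErrorString(matxError_t err) {
  switch (err) {
    case matxSuccess:        return "Success";
    case matxCudaError:      return "CUDA error";
    case matxLibMathdxError: return "libmathdx error";
    default:                 return "Unknown";
  }
}
```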
11 files reviewed, 11 comments
/build

/build
Additional Comments (1)
/build
Add cuBLASDx support for JIT-compiled matrix multiplication
Integrate cuBLASDx for fusion and accelerated matrix
multiplication of small matrices that fit in shared memory. This enables
significantly faster GEMM operations for sizes up to ~200 elements per
dimension (varies by data type and compute capability).
Key changes:
- parameters, size validation, and device code generation
- generation (get_jit_class_name, get_jit_op_str)
- must invoke operator() with bounds checking at the tensor level
- cuBLASDx operators with fixed block dimensions
- pass-through thread execution model
Supported types: half, bfloat16, float, double, and their complex
variants. Size limits are architecture-dependent, ranging from 36 to 196
elements per dimension based on compute capability (SM 7.0 - SM 11.0).
Requires MATX_EN_MATHDX to be enabled at compile time.
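The size/type gate described above might be expressed roughly as follows. The function name mirrors CheckJITSizeAndTypeRequirements from the sequence diagram, but the body is a hedged sketch: the real limit comes from the cuBLASDx plan for the target compute capability, and complex variants are omitted for brevity.

```cpp
#include <type_traits>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// Hedged sketch of the architecture-dependent size/type check. max_dim is
// assumed to come from the cuBLASDx plan for the current compute capability
// (somewhere in the 36-196 range per the description above).
template <typename T>
bool check_jit_size_and_type_requirements(int m, int n, int k, int max_dim) {
  constexpr bool supported_type =
      std::is_same_v<T, __half> || std::is_same_v<T, __nv_bfloat16> ||
      std::is_same_v<T, float>  || std::is_same_v<T, double>;
  // Complex variants of these types are also listed as supported; omitted here.
  return supported_type && m <= max_dim && n <= max_dim && k <= max_dim;
}
```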
Note that this is in early development. cuBLASDx has limitations that affect the MatX code base as a whole, such as dictating what the block size should be. This PR is for early adopter support and we will add more features over time.
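As a rough end-to-end illustration, usage under the JIT path might look like the sketch below. The executor type name is an assumption (the PR adds it in include/matx/executors/jit_cuda.h); the tensor setup and matmul() call follow the usual MatX pattern, and the build must have MATX_EN_MATHDX enabled as noted above.

```cpp
// Hypothetical usage sketch; build with MATX_EN_MATHDX enabled.
#include "matx.h"

int main() {
  using namespace matx;
  MATX_ENTER_HANDLER();

  // Small matrices that fit in shared memory (limits vary by type and SM).
  auto A = make_tensor<float>({64, 64});
  auto B = make_tensor<float>({64, 64});
  auto C = make_tensor<float>({64, 64});
  (A = ones<float>({64, 64})).run();
  (B = ones<float>({64, 64})).run();

  // With the JIT executor from this PR, matmul() on eligible sizes/types is
  // lowered to a cuBLASDx GEMM; otherwise it falls back to cuBLAS.
  cudaJITExecutor exec{};   // assumed name of the executor added in jit_cuda.h
  (C = matmul(A, B)).run(exec);

  cudaDeviceSynchronize();
  MATX_EXIT_HANDLER();
  return 0;
}
```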