133 changes: 133 additions & 0 deletions crates/tropical-gemm-cuda/kernels/tropical_gemm.cu
@@ -1084,3 +1084,136 @@ TROPICAL_BACKWARD_A(tropical_backward_a_f32, float, atomicAdd)
TROPICAL_BACKWARD_B(tropical_backward_b_f32, float, atomicAdd)
TROPICAL_BACKWARD_A(tropical_backward_a_f64, double, atomicAddDouble)
TROPICAL_BACKWARD_B(tropical_backward_b_f64, double, atomicAddDouble)

// ============================================================================
// BATCHED F32 GEMM WITH ARGMAX KERNEL MACRO
// ============================================================================
// Strided batched GEMM: processes batch_size independent GEMMs
// Uses blockIdx.z for batch index, strides for memory offsets
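//
// Illustrative host-side launch (a sketch only; the real launch lives in
// crates/tropical-gemm-cuda/src/kernels.rs and takes its block dims from
// CudaContext::block_dims_f32()):
//
//   dim3 block(16, 16);   // (BLOCK_SIZE_M/THREAD_SIZE_M) x (BLOCK_SIZE_N/THREAD_SIZE_N)
//   // grid.x folds the M- and N-tile indices; grid.z indexes the batch
//   dim3 grid(((M + 63) / 64) * ((N + 63) / 64), 1, batch_size);
//   // A, B, C are column-major per batch slice; strides are in elements
//   // (M*K, K*N and M*N for densely packed batches)
//   tropical_maxplus_f32_nn_batched_with_argmax<<<grid, block>>>(
//       d_A, d_B, d_C, d_argmax, M, N, K, M * K, K * N, M * N);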

#define TROPICAL_GEMM_BATCHED_ARGMAX_F32(KERNEL_NAME, INIT_VAL, COMPARE_OP, MUL_OP) \
extern "C" __global__ void KERNEL_NAME( \
const float* __restrict__ A, \
const float* __restrict__ B, \
float* __restrict__ C, \
int* __restrict__ argmax, \
int M, int N, int K, \
int strideA, int strideB, int strideC \
) { \
const int BLOCK_SIZE_M = 64; \
const int BLOCK_SIZE_K = 32; \
const int BLOCK_SIZE_N = 64; \
const int THREAD_SIZE_M = 4; \
const int THREAD_SIZE_N = 4; \
\
const int bszm = BLOCK_SIZE_M / THREAD_SIZE_M; \
const int bszn = BLOCK_SIZE_N / THREAD_SIZE_N; \
const int THREAD_NUM_PER_BLOCK = bszm * bszn; \
\
int batch_idx = blockIdx.z; \
int DIM_GRID_X = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M; \
int BLOCK_IDX = blockIdx.x % DIM_GRID_X; \
int BLOCK_IDY = blockIdx.x / DIM_GRID_X; \
\
const float* A_batch = A + batch_idx * strideA; \
const float* B_batch = B + batch_idx * strideB; \
float* C_batch = C + batch_idx * strideC; \
int* argmax_batch = argmax + batch_idx * strideC; \
\
const int tid = threadIdx.y * bszm + threadIdx.x; \
\
__shared__ float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; \
__shared__ float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; \
\
float accum[THREAD_SIZE_M * THREAD_SIZE_N]; \
int accum_idx[THREAD_SIZE_M * THREAD_SIZE_N]; \
float regs_a[THREAD_SIZE_M]; \
float regs_b[THREAD_SIZE_N]; \
\
_Pragma("unroll") \
for (int i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { \
accum[i] = INIT_VAL; \
accum_idx[i] = 0; \
} \
\
const int A_TILE_COL = tid / BLOCK_SIZE_M; \
const int A_TILE_ROW = tid % BLOCK_SIZE_M; \
const int B_TILE_COL = tid / BLOCK_SIZE_K; \
const int B_TILE_ROW = tid % BLOCK_SIZE_K; \
const int A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; \
const int B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; \
\
for (int tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { \
_Pragma("unroll") \
for (int i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { \
int row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; \
int col = A_TILE_COL + i + tile_idx; \
float val = INIT_VAL; \
if (row < M && col < K) { \
val = A_batch[OFFSET_COL(row, col, M)]; \
} \
As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; \
} \
\
_Pragma("unroll") \
for (int i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { \
int row = tile_idx + B_TILE_ROW; \
int col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; \
float val = INIT_VAL; \
if (row < K && col < N) { \
val = B_batch[OFFSET_COL(row, col, K)]; \
} \
Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; \
} \
\
__syncthreads(); \
\
_Pragma("unroll") \
for (int k = 0; k < BLOCK_SIZE_K; ++k) { \
int global_k = tile_idx + k; \
_Pragma("unroll") \
for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \
regs_a[tm] = As[OFFSET_COL(threadIdx.x * THREAD_SIZE_M + tm, \
k, BLOCK_SIZE_M)]; \
} \
_Pragma("unroll") \
for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \
regs_b[tn] = Bs[OFFSET_COL(k, threadIdx.y * THREAD_SIZE_N + tn,\
BLOCK_SIZE_K)]; \
} \
_Pragma("unroll") \
for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \
_Pragma("unroll") \
for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \
float prod = regs_a[tm] MUL_OP regs_b[tn]; \
int idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); \
if (prod COMPARE_OP accum[idx]) { \
accum[idx] = prod; \
accum_idx[idx] = global_k; \
} \
} \
} \
} \
__syncthreads(); \
} \
\
_Pragma("unroll") \
for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \
_Pragma("unroll") \
for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \
int row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * threadIdx.x + tm; \
int col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * threadIdx.y + tn; \
if (row < M && col < N) { \
int out_idx = OFFSET_COL(row, col, M); \
int local_idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); \
C_batch[out_idx] = accum[local_idx]; \
argmax_batch[out_idx] = accum_idx[local_idx]; \
} \
} \
} \
}

// --- Batched F32 GEMM with Argmax Kernels ---
TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_maxplus_f32_nn_batched_with_argmax, NEG_INF_F32, >, +)
TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_minplus_f32_nn_batched_with_argmax, INF_F32, <, +)
TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_maxmul_f32_nn_batched_with_argmax, 0.0f, >, *)
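//
// With the max-plus instantiation above, each batch slice b computes
//   C[b][i][j]      = max_k ( A[b][i][k] + B[b][k][j] )
//   argmax[b][i][j] = the k attaining that maximum (ties keep the smallest k,
//                     since the comparison is strict)
// The min-plus variant replaces max with min; the max-mul variant replaces + with *.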
4 changes: 4 additions & 0 deletions crates/tropical-gemm-cuda/src/context.rs
@@ -57,6 +57,10 @@ const KERNEL_NAMES: &[&str] = &[
"tropical_backward_b_f32",
"tropical_backward_a_f64",
"tropical_backward_b_f64",
// Batched GEMM with argmax kernels (f32 only)
"tropical_maxplus_f32_nn_batched_with_argmax",
"tropical_minplus_f32_nn_batched_with_argmax",
"tropical_maxmul_f32_nn_batched_with_argmax",
];

/// CUDA context for tropical GEMM operations.
81 changes: 80 additions & 1 deletion crates/tropical-gemm-cuda/src/kernels.rs
@@ -2,7 +2,9 @@

use crate::context::CudaContext;
use crate::error::Result;
use crate::memory::{ExternalGpuMatrix, GpuMatrix, GpuMatrixWithArgmax};
use crate::memory::{
ExternalGpuMatrix, ExternalGpuTensor3, GpuMatrix, GpuMatrixWithArgmax, GpuTensor3WithArgmax,
};
use cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
use tropical_gemm::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring};

@@ -469,3 +471,80 @@ pub unsafe fn launch_gemm_external_f32(
ctx.device().synchronize()?;
Ok(c)
}

// ============================================================================
// Batched External Kernels (for DLPack 3D tensors)
// ============================================================================

/// Launch a batched tropical GEMM kernel with argmax using external (DLPack) 3D tensors.
///
/// Computes C[b] = A[b] ⊗ B[b] for each batch b, where ⊗ is tropical matrix multiplication.
///
/// # Arguments
///
/// * `ctx` - CUDA context
/// * `kernel_name` - Name of the batched kernel to launch
/// * `a` - External 3D tensor A (batch, M, K) in row-major per batch
/// * `b` - External 3D tensor B (batch, K, N) in row-major per batch
/// * `batch` - Batch size
/// * `m` - Rows per matrix in A/C
/// * `k` - Columns in A / rows in B
/// * `n` - Columns per matrix in B/C
///
/// # Safety
///
/// - The input pointers must point to valid GPU memory with the specified dimensions
/// - The memory must remain valid for the duration of the kernel execution
pub unsafe fn launch_gemm_external_batched_with_argmax_f32(
ctx: &CudaContext,
kernel_name: &'static str,
a: &ExternalGpuTensor3<f32>,
b: &ExternalGpuTensor3<f32>,
batch: usize,
m: usize,
k: usize,
n: usize,
) -> Result<GpuTensor3WithArgmax<f32>> {
// Apply row-major → column-major trick: swap inputs and swap M↔N
// Same as non-batched version, but for each batch
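// A row-major (M x K) buffer reinterpreted as column-major is the K x M
// transpose, so handing the column-major kernel B^T and A^T (with M and N
// swapped) produces C^T in column-major, i.e. C in row-major, batch by batch.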
let mut c = GpuTensor3WithArgmax::<f32>::alloc(ctx, batch, m, n)?;

// Grid: (ceil(N/64) * ceil(M/64), 1, batch) with swapped M↔N
let grid_xy = ((n + 63) / 64) * ((m + 63) / 64);
let grid = (grid_xy as u32, 1, batch as u32);
let block = CudaContext::block_dims_f32();

let kernel = ctx.get_kernel(kernel_name)?;
let cfg = LaunchConfig {
grid_dim: grid,
block_dim: block,
shared_mem_bytes: 0,
};

// Compute strides before borrowing mutably
let stride_a = a.stride() as i32;
let stride_b = b.stride() as i32;
let stride_c = c.tensor.stride() as i32;

// Swap order: pass B first, then A, and swap M↔N
// Kernel signature: (A, B, C, argmax, M, N, K, strideA, strideB, strideC)
// We pass: (B, A, C, argmax, N, M, K, strideB, strideA, strideC)
kernel.launch(
cfg,
(
b.device_ptr(), // B becomes "A" in kernel
a.device_ptr(), // A becomes "B" in kernel
c.tensor.as_slice_mut(),
c.argmax.as_slice_mut(),
n as i32, // Swapped: N becomes "M"
m as i32, // Swapped: M becomes "N"
k as i32,
stride_b, // strideA (B's stride in our swap)
stride_a, // strideB (A's stride in our swap)
stride_c, // strideC
),
)?;

ctx.device().synchronize()?;
Ok(c)
}
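
// Illustrative call into the batched launcher (a sketch only: the context and
// tensor constructors named here are assumptions, not part of this diff; only
// `launch_gemm_external_batched_with_argmax_f32`, the kernel name, and the
// `tensor`/`argmax` fields come from this file):
//
// let ctx = CudaContext::new(0)?;                        // hypothetical constructor
// let a: ExternalGpuTensor3<f32> = /* wrap a DLPack (batch, M, K) tensor */;
// let b: ExternalGpuTensor3<f32> = /* wrap a DLPack (batch, K, N) tensor */;
// let c = unsafe {
//     launch_gemm_external_batched_with_argmax_f32(
//         &ctx,
//         "tropical_maxplus_f32_nn_batched_with_argmax",
//         &a, &b, batch, m, k, n,
//     )?
// };
// // c.tensor holds the (batch, M, N) results, c.argmax the winning k indices.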