diff --git a/crates/tropical-gemm-cuda/kernels/tropical_gemm.cu b/crates/tropical-gemm-cuda/kernels/tropical_gemm.cu index 7818384..db9a8ff 100644 --- a/crates/tropical-gemm-cuda/kernels/tropical_gemm.cu +++ b/crates/tropical-gemm-cuda/kernels/tropical_gemm.cu @@ -1084,3 +1084,136 @@ TROPICAL_BACKWARD_A(tropical_backward_a_f32, float, atomicAdd) TROPICAL_BACKWARD_B(tropical_backward_b_f32, float, atomicAdd) TROPICAL_BACKWARD_A(tropical_backward_a_f64, double, atomicAddDouble) TROPICAL_BACKWARD_B(tropical_backward_b_f64, double, atomicAddDouble) + +// ============================================================================ +// BATCHED F32 GEMM WITH ARGMAX KERNEL MACRO +// ============================================================================ +// Strided batched GEMM: processes batch_size independent GEMMs +// Uses blockIdx.z for batch index, strides for memory offsets + +#define TROPICAL_GEMM_BATCHED_ARGMAX_F32(KERNEL_NAME, INIT_VAL, COMPARE_OP, MUL_OP) \ +extern "C" __global__ void KERNEL_NAME( \ + const float* __restrict__ A, \ + const float* __restrict__ B, \ + float* __restrict__ C, \ + int* __restrict__ argmax, \ + int M, int N, int K, \ + int strideA, int strideB, int strideC \ +) { \ + const int BLOCK_SIZE_M = 64; \ + const int BLOCK_SIZE_K = 32; \ + const int BLOCK_SIZE_N = 64; \ + const int THREAD_SIZE_M = 4; \ + const int THREAD_SIZE_N = 4; \ + \ + const int bszm = BLOCK_SIZE_M / THREAD_SIZE_M; \ + const int bszn = BLOCK_SIZE_N / THREAD_SIZE_N; \ + const int THREAD_NUM_PER_BLOCK = bszm * bszn; \ + \ + int batch_idx = blockIdx.z; \ + int DIM_GRID_X = (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M; \ + int BLOCK_IDX = blockIdx.x % DIM_GRID_X; \ + int BLOCK_IDY = blockIdx.x / DIM_GRID_X; \ + \ + const float* A_batch = A + batch_idx * strideA; \ + const float* B_batch = B + batch_idx * strideB; \ + float* C_batch = C + batch_idx * strideC; \ + int* argmax_batch = argmax + batch_idx * strideC; \ + \ + const int tid = threadIdx.y * bszm + threadIdx.x; \ + \ + __shared__ float As[BLOCK_SIZE_M * BLOCK_SIZE_K]; \ + __shared__ float Bs[BLOCK_SIZE_K * BLOCK_SIZE_N]; \ + \ + float accum[THREAD_SIZE_M * THREAD_SIZE_N]; \ + int accum_idx[THREAD_SIZE_M * THREAD_SIZE_N]; \ + float regs_a[THREAD_SIZE_M]; \ + float regs_b[THREAD_SIZE_N]; \ + \ + _Pragma("unroll") \ + for (int i = 0; i < THREAD_SIZE_M * THREAD_SIZE_N; ++i) { \ + accum[i] = INIT_VAL; \ + accum_idx[i] = 0; \ + } \ + \ + const int A_TILE_COL = tid / BLOCK_SIZE_M; \ + const int A_TILE_ROW = tid % BLOCK_SIZE_M; \ + const int B_TILE_COL = tid / BLOCK_SIZE_K; \ + const int B_TILE_ROW = tid % BLOCK_SIZE_K; \ + const int A_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_M; \ + const int B_TILE_COL_STRIDE = THREAD_NUM_PER_BLOCK / BLOCK_SIZE_K; \ + \ + for (int tile_idx = 0; tile_idx < K; tile_idx += BLOCK_SIZE_K) { \ + _Pragma("unroll") \ + for (int i = 0; i < BLOCK_SIZE_K; i += A_TILE_COL_STRIDE) { \ + int row = BLOCK_SIZE_M * BLOCK_IDX + A_TILE_ROW; \ + int col = A_TILE_COL + i + tile_idx; \ + float val = INIT_VAL; \ + if (row < M && col < K) { \ + val = A_batch[OFFSET_COL(row, col, M)]; \ + } \ + As[OFFSET_COL(A_TILE_ROW, i + A_TILE_COL, BLOCK_SIZE_M)] = val; \ + } \ + \ + _Pragma("unroll") \ + for (int i = 0; i < BLOCK_SIZE_N; i += B_TILE_COL_STRIDE) { \ + int row = tile_idx + B_TILE_ROW; \ + int col = BLOCK_SIZE_N * BLOCK_IDY + i + B_TILE_COL; \ + float val = INIT_VAL; \ + if (row < K && col < N) { \ + val = B_batch[OFFSET_COL(row, col, K)]; \ + } \ + Bs[OFFSET_COL(B_TILE_ROW, i + B_TILE_COL, BLOCK_SIZE_K)] = val; \ + } \ + \ + 
__syncthreads(); \ + \ + _Pragma("unroll") \ + for (int k = 0; k < BLOCK_SIZE_K; ++k) { \ + int global_k = tile_idx + k; \ + _Pragma("unroll") \ + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \ + regs_a[tm] = As[OFFSET_COL(threadIdx.x * THREAD_SIZE_M + tm, \ + k, BLOCK_SIZE_M)]; \ + } \ + _Pragma("unroll") \ + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \ + regs_b[tn] = Bs[OFFSET_COL(k, threadIdx.y * THREAD_SIZE_N + tn,\ + BLOCK_SIZE_K)]; \ + } \ + _Pragma("unroll") \ + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \ + _Pragma("unroll") \ + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \ + float prod = regs_a[tm] MUL_OP regs_b[tn]; \ + int idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); \ + if (prod COMPARE_OP accum[idx]) { \ + accum[idx] = prod; \ + accum_idx[idx] = global_k; \ + } \ + } \ + } \ + } \ + __syncthreads(); \ + } \ + \ + _Pragma("unroll") \ + for (int tm = 0; tm < THREAD_SIZE_M; ++tm) { \ + _Pragma("unroll") \ + for (int tn = 0; tn < THREAD_SIZE_N; ++tn) { \ + int row = BLOCK_SIZE_M * BLOCK_IDX + THREAD_SIZE_M * threadIdx.x + tm; \ + int col = BLOCK_SIZE_N * BLOCK_IDY + THREAD_SIZE_N * threadIdx.y + tn; \ + if (row < M && col < N) { \ + int out_idx = OFFSET_COL(row, col, M); \ + int local_idx = OFFSET_COL(tm, tn, THREAD_SIZE_M); \ + C_batch[out_idx] = accum[local_idx]; \ + argmax_batch[out_idx] = accum_idx[local_idx]; \ + } \ + } \ + } \ +} + +// --- Batched F32 GEMM with Argmax Kernels --- +TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_maxplus_f32_nn_batched_with_argmax, NEG_INF_F32, >, +) +TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_minplus_f32_nn_batched_with_argmax, INF_F32, <, +) +TROPICAL_GEMM_BATCHED_ARGMAX_F32(tropical_maxmul_f32_nn_batched_with_argmax, 0.0f, >, *) diff --git a/crates/tropical-gemm-cuda/src/context.rs b/crates/tropical-gemm-cuda/src/context.rs index fda6ecd..17c6da7 100644 --- a/crates/tropical-gemm-cuda/src/context.rs +++ b/crates/tropical-gemm-cuda/src/context.rs @@ -57,6 +57,10 @@ const KERNEL_NAMES: &[&str] = &[ "tropical_backward_b_f32", "tropical_backward_a_f64", "tropical_backward_b_f64", + // Batched GEMM with argmax kernels (f32 only) + "tropical_maxplus_f32_nn_batched_with_argmax", + "tropical_minplus_f32_nn_batched_with_argmax", + "tropical_maxmul_f32_nn_batched_with_argmax", ]; /// CUDA context for tropical GEMM operations. diff --git a/crates/tropical-gemm-cuda/src/kernels.rs b/crates/tropical-gemm-cuda/src/kernels.rs index 5190618..cdb2fb7 100644 --- a/crates/tropical-gemm-cuda/src/kernels.rs +++ b/crates/tropical-gemm-cuda/src/kernels.rs @@ -2,7 +2,9 @@ use crate::context::CudaContext; use crate::error::Result; -use crate::memory::{ExternalGpuMatrix, GpuMatrix, GpuMatrixWithArgmax}; +use crate::memory::{ + ExternalGpuMatrix, ExternalGpuTensor3, GpuMatrix, GpuMatrixWithArgmax, GpuTensor3WithArgmax, +}; use cudarc::driver::{DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits}; use tropical_gemm::types::{TropicalMaxMul, TropicalMaxPlus, TropicalMinPlus, TropicalSemiring}; @@ -469,3 +471,80 @@ pub unsafe fn launch_gemm_external_f32( ctx.device().synchronize()?; Ok(c) } + +// ============================================================================ +// Batched External Kernels (for DLPack 3D tensors) +// ============================================================================ + +/// Launch a batched tropical GEMM kernel with argmax using external (DLPack) 3D tensors. +/// +/// Computes C[b] = A[b] ⊗ B[b] for each batch b, where ⊗ is tropical matrix multiplication. 
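+///
+/// A minimal usage sketch (illustrative only; `a_ptr`, `b_ptr`, `batch`, `m`, `k`, and `n`
+/// are assumed to describe contiguous, row-major f32 tensors already resident on the GPU,
+/// e.g. obtained from a DLPack import):
+///
+/// ```ignore
+/// let ctx = get_global_context()?;
+/// // SAFETY: a_ptr / b_ptr are CUdeviceptr values that stay valid for the whole call.
+/// let a = unsafe { ExternalGpuTensor3::<f32>::from_raw_contiguous(a_ptr, batch, m, k) };
+/// let b = unsafe { ExternalGpuTensor3::<f32>::from_raw_contiguous(b_ptr, batch, k, n) };
+/// let c = unsafe {
+///     launch_gemm_external_batched_with_argmax_f32(
+///         ctx,
+///         "tropical_maxplus_f32_nn_batched_with_argmax",
+///         &a, &b, batch, m, k, n,
+///     )?
+/// };
+/// let values = c.tensor_to_host(ctx)?;  // batch * m * n tropical products
+/// let argmax = c.argmax_to_host(ctx)?;  // batch * m * n winning k-indices
+/// ```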
+/// +/// # Arguments +/// +/// * `ctx` - CUDA context +/// * `kernel_name` - Name of the batched kernel to launch +/// * `a` - External 3D tensor A (batch, M, K) in row-major per batch +/// * `b` - External 3D tensor B (batch, K, N) in row-major per batch +/// * `batch` - Batch size +/// * `m` - Rows per matrix in A/C +/// * `k` - Columns in A / rows in B +/// * `n` - Columns per matrix in B/C +/// +/// # Safety +/// +/// - The input pointers must point to valid GPU memory with the specified dimensions +/// - The memory must remain valid for the duration of the kernel execution +pub unsafe fn launch_gemm_external_batched_with_argmax_f32( + ctx: &CudaContext, + kernel_name: &'static str, + a: &ExternalGpuTensor3, + b: &ExternalGpuTensor3, + batch: usize, + m: usize, + k: usize, + n: usize, +) -> Result> { + // Apply row-major → column-major trick: swap inputs and swap M↔N + // Same as non-batched version, but for each batch + let mut c = GpuTensor3WithArgmax::::alloc(ctx, batch, m, n)?; + + // Grid: (ceil(N/64) * ceil(M/64), 1, batch) with swapped M↔N + let grid_xy = ((n + 63) / 64) * ((m + 63) / 64); + let grid = (grid_xy as u32, 1, batch as u32); + let block = CudaContext::block_dims_f32(); + + let kernel = ctx.get_kernel(kernel_name)?; + let cfg = LaunchConfig { + grid_dim: grid, + block_dim: block, + shared_mem_bytes: 0, + }; + + // Compute strides before borrowing mutably + let stride_a = a.stride() as i32; + let stride_b = b.stride() as i32; + let stride_c = c.tensor.stride() as i32; + + // Swap order: pass B first, then A, and swap M↔N + // Kernel signature: (A, B, C, argmax, M, N, K, strideA, strideB, strideC) + // We pass: (B, A, C, argmax, N, M, K, strideB, strideA, strideC) + kernel.launch( + cfg, + ( + b.device_ptr(), // B becomes "A" in kernel + a.device_ptr(), // A becomes "B" in kernel + c.tensor.as_slice_mut(), + c.argmax.as_slice_mut(), + n as i32, // Swapped: N becomes "M" + m as i32, // Swapped: M becomes "N" + k as i32, + stride_b, // strideA (B's stride in our swap) + stride_a, // strideB (A's stride in our swap) + stride_c, // strideC + ), + )?; + + ctx.device().synchronize()?; + Ok(c) +} diff --git a/crates/tropical-gemm-cuda/src/lib.rs b/crates/tropical-gemm-cuda/src/lib.rs index 0867417..17c62c4 100644 --- a/crates/tropical-gemm-cuda/src/lib.rs +++ b/crates/tropical-gemm-cuda/src/lib.rs @@ -48,54 +48,100 @@ mod gpu_mat; mod kernels; mod memory; +use cudarc::driver::CudaDevice; use once_cell::sync::OnceCell; +use std::collections::HashMap; use std::sync::Mutex; -/// Global CUDA context for convenience functions. -/// Lazily initialized on first use, persists for process lifetime. -static GLOBAL_CONTEXT: OnceCell = OnceCell::new(); +/// Per-device CUDA context cache, dynamically sized based on available devices. +static DEVICE_CONTEXTS: OnceCell>> = OnceCell::new(); -/// Mutex to ensure only one thread initializes the context. +/// Mutex to ensure thread-safe initialization. static INIT_MUTEX: Mutex<()> = Mutex::new(()); -/// Get or initialize the global CUDA context. +/// Get the number of available CUDA devices. +pub fn cuda_device_count() -> Result { + Ok(CudaDevice::count()? as usize) +} + +/// Get or initialize the CUDA context for a specific device. +/// +/// This function is thread-safe and will only initialize the context once per device. +/// Subsequent calls return the cached context for that device. +/// +/// # Arguments /// -/// This function is thread-safe and will only initialize the context once. -/// Subsequent calls return the cached context. 
+/// * `device_id` - The CUDA device ordinal /// /// # Errors /// -/// Returns an error if CUDA initialization fails (no device, driver issues, etc.) -pub fn get_global_context() -> Result<&'static CudaContext> { - // Fast path: already initialized - if let Some(ctx) = GLOBAL_CONTEXT.get() { - return Ok(ctx); +/// Returns an error if CUDA initialization fails or device_id is invalid. +pub fn get_context_for_device(device_id: usize) -> Result<&'static CudaContext> { + // Check device count + let device_count = cuda_device_count()?; + if device_id >= device_count { + return Err(CudaError::DimensionMismatch(format!( + "Device {} not available (only {} CUDA device(s) found)", + device_id, device_count + ))); + } + + // Initialize the map if needed + let contexts = DEVICE_CONTEXTS.get_or_init(|| Mutex::new(HashMap::new())); + + // Fast path: check if context exists + { + let map = contexts.lock().unwrap(); + if let Some(&ctx) = map.get(&device_id) { + return Ok(ctx); + } } // Slow path: need to initialize let _lock = INIT_MUTEX.lock().unwrap(); // Double-check after acquiring lock - if let Some(ctx) = GLOBAL_CONTEXT.get() { - return Ok(ctx); + { + let map = contexts.lock().unwrap(); + if let Some(&ctx) = map.get(&device_id) { + return Ok(ctx); + } } - // Initialize and store - let ctx = CudaContext::new()?; - let _ = GLOBAL_CONTEXT.set(ctx); + // Create context and leak it for 'static lifetime + let ctx = CudaContext::new_on_device(device_id)?; + let ctx_static: &'static CudaContext = Box::leak(Box::new(ctx)); + + // Store in map + { + let mut map = contexts.lock().unwrap(); + map.insert(device_id, ctx_static); + } - Ok(GLOBAL_CONTEXT.get().unwrap()) + Ok(ctx_static) +} + +/// Get or initialize the global CUDA context (device 0). +/// +/// This is a convenience function equivalent to `get_context_for_device(0)`. +/// +/// # Errors +/// +/// Returns an error if CUDA initialization fails (no device, driver issues, etc.) +pub fn get_global_context() -> Result<&'static CudaContext> { + get_context_for_device(0) } pub use context::CudaContext; pub use error::{CudaError, Result}; pub use gpu_mat::{GpuMat, GpuMatWithArgmax}; pub use kernels::{ - launch_gemm_external_f32, launch_gemm_external_with_argmax_f32, CudaKernel, - CudaKernelWithArgmax, + launch_gemm_external_batched_with_argmax_f32, launch_gemm_external_f32, + launch_gemm_external_with_argmax_f32, CudaKernel, CudaKernelWithArgmax, }; pub use memory::{ - ArgmaxIndex, ExternalGpuMatrix, ExternalGpuMemory, GpuMatrix, GpuMatrixWithArgmax, + ArgmaxIndex, ExternalGpuMatrix, ExternalGpuMemory, ExternalGpuTensor3, GpuMatrix, + GpuMatrixWithArgmax, GpuTensor3, GpuTensor3WithArgmax, }; use cudarc::driver::{DeviceRepr, ValidAsZeroBits}; @@ -125,7 +171,27 @@ fn validate_gemm_input(a: &[T], b: &[T], m: usize, k: usize, n: usize) -> Res /// One-shot tropical matrix multiplication on GPU. /// /// This function handles all GPU memory management automatically. -/// For repeated operations, use `tropical_gemm_gpu` with a persistent context. +/// +/// # Performance Note +/// +/// This function performs host-to-device (H2D) transfers for inputs and +/// device-to-host (D2H) transfer for output on every call. 
For repeated +/// operations, use [`GpuMatrix`] with [`tropical_gemm_gpu`] instead to +/// keep data on GPU between operations: +/// +/// ```ignore +/// let ctx = get_global_context()?; +/// let a_gpu = GpuMatrix::from_host(ctx, &a, m, k)?; +/// let b_gpu = GpuMatrix::from_host(ctx, &b, k, n)?; +/// let mut c_gpu = GpuMatrix::alloc(ctx, m, n)?; +/// +/// // Repeated operations without H2D/D2H transfers +/// for _ in 0..iterations { +/// tropical_gemm_gpu::>(ctx, &a_gpu, &b_gpu, &mut c_gpu)?; +/// } +/// +/// let c = c_gpu.to_host(ctx)?; // Single D2H at the end +/// ``` /// /// # Arguments /// @@ -252,6 +318,12 @@ where /// which k-index produced each C[i,j]. This is essential for backward /// propagation in tropical neural networks. /// +/// # Performance Note +/// +/// This function performs H2D transfers for inputs and D2H transfers for outputs +/// on every call. For repeated operations, use [`GpuMatrixWithArgmax`] with +/// [`tropical_gemm_gpu_with_argmax`] to keep data on GPU between operations. +/// /// # Arguments /// /// * `a` - Matrix A in **column-major** order, dimensions m×k diff --git a/crates/tropical-gemm-cuda/src/memory.rs b/crates/tropical-gemm-cuda/src/memory.rs index da7fea5..67bb625 100644 --- a/crates/tropical-gemm-cuda/src/memory.rs +++ b/crates/tropical-gemm-cuda/src/memory.rs @@ -54,7 +54,13 @@ impl GpuMatrix { /// /// Use this when interfacing with row-major data sources (e.g., C arrays). /// For column-major data, use `from_host` instead for better performance. - #[deprecated(since = "0.4.0", note = "use from_host with column-major data instead")] + /// + /// # Performance Warning + /// + /// This method performs an O(rows×cols) transpose on the CPU before uploading to GPU. + /// For performance-critical code, provide data in column-major order and use + /// [`from_host`] instead. + #[deprecated(since = "0.4.0", note = "use from_host with column-major data instead; this method has O(m×n) transpose overhead")] pub fn from_host_row_major( ctx: &CudaContext, data: &[T], @@ -112,7 +118,13 @@ impl GpuMatrix { /// /// Use this when interfacing with row-major data consumers. /// For column-major data, use `to_host` instead for better performance. - #[deprecated(since = "0.4.0", note = "use to_host for column-major data instead")] + /// + /// # Performance Warning + /// + /// This method performs an O(rows×cols) transpose on the CPU after downloading from GPU. + /// For performance-critical code, use [`to_host`] and handle the column-major layout + /// in your application. + #[deprecated(since = "0.4.0", note = "use to_host for column-major data instead; this method has O(m×n) transpose overhead")] pub fn to_host_row_major(&self, ctx: &CudaContext) -> Result> { let col_major = ctx.device().dtoh_sync_copy(&self.data)?; // Transpose from column-major to row-major @@ -156,6 +168,17 @@ impl GpuMatrix { pub fn as_slice_mut(&mut self) -> &mut CudaSlice { &mut self.data } + + /// Get the raw device pointer (for DLPack export). + pub fn device_ptr(&self) -> CUdeviceptr { + use cudarc::driver::DevicePtr; + *self.data.device_ptr() + } + + /// Consume self and return the inner CudaSlice (for ownership transfer). + pub fn into_inner(self) -> CudaSlice { + self.data + } } /// A GPU matrix paired with argmax indices (for backward propagation). @@ -223,6 +246,14 @@ impl GpuMatrixWithArgmax { pub fn argmax_to_host_col_major(&self, ctx: &CudaContext) -> Result> { self.argmax_to_host(ctx) } + + /// Consume self and return the matrix and argmax separately. 
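+    /// The first element of the returned tuple is the value matrix `C`; the second holds
+    /// the argmax k-indices recorded during the forward pass.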
+ /// + /// This is useful for DLPack export where each tensor needs to be wrapped + /// independently for ownership transfer. + pub fn into_parts(self) -> (GpuMatrix, GpuMatrix) { + (self.matrix, self.argmax) + } } // ============================================================================ @@ -326,3 +357,217 @@ impl ExternalGpuMatrix { self.memory.is_empty() } } + +/// A 3D tensor view into external GPU memory for batched operations. +/// +/// This represents a batch of matrices stored contiguously in row-major order +/// (as PyTorch tensors are). Shape is (batch, rows, cols) with stride between batches. +/// +/// The actual data is not copied - we just store metadata and a pointer. +pub struct ExternalGpuTensor3 { + device_ptr: CUdeviceptr, + batch: usize, + rows: usize, + cols: usize, + stride: usize, // elements per batch (typically rows * cols for contiguous) + _marker: PhantomData, +} + +impl ExternalGpuTensor3 { + /// Create a new ExternalGpuTensor3 from a raw device pointer. + /// + /// # Safety + /// + /// - `device_ptr` must point to valid GPU memory containing at least `batch * stride` elements + /// - The memory must be in row-major (C-contiguous) order per batch + /// - The memory must remain valid for the lifetime of this struct + pub unsafe fn from_raw( + device_ptr: CUdeviceptr, + batch: usize, + rows: usize, + cols: usize, + stride: usize, + ) -> Self { + Self { + device_ptr, + batch, + rows, + cols, + stride, + _marker: PhantomData, + } + } + + /// Create from contiguous 3D tensor (stride = rows * cols). + /// + /// # Safety + /// + /// Same requirements as `from_raw`. + pub unsafe fn from_raw_contiguous( + device_ptr: CUdeviceptr, + batch: usize, + rows: usize, + cols: usize, + ) -> Self { + Self::from_raw(device_ptr, batch, rows, cols, rows * cols) + } + + /// Get the raw device pointer. + pub fn device_ptr(&self) -> CUdeviceptr { + self.device_ptr + } + + /// Get the batch size. + pub fn batch(&self) -> usize { + self.batch + } + + /// Get the number of rows per matrix. + pub fn rows(&self) -> usize { + self.rows + } + + /// Get the number of columns per matrix. + pub fn cols(&self) -> usize { + self.cols + } + + /// Get the stride (elements between batches). + pub fn stride(&self) -> usize { + self.stride + } + + /// Get the total number of elements. + pub fn len(&self) -> usize { + self.batch * self.stride + } + + /// Check if the tensor is empty. + pub fn is_empty(&self) -> bool { + self.batch == 0 || self.rows == 0 || self.cols == 0 + } + + /// Check if the tensor is contiguous (stride == rows * cols). + pub fn is_contiguous(&self) -> bool { + self.stride == self.rows * self.cols + } +} + +/// A batched GPU matrix result with owned memory. +/// +/// Stores batch_size matrices of shape (rows, cols) contiguously on GPU. +pub struct GpuTensor3 { + data: CudaSlice, + batch: usize, + rows: usize, + cols: usize, +} + +impl GpuTensor3 { + /// Allocate a zeroed batched GPU tensor. + pub fn alloc(ctx: &CudaContext, batch: usize, rows: usize, cols: usize) -> Result { + let gpu_data = ctx.device().alloc_zeros::(batch * rows * cols)?; + Ok(Self { + data: gpu_data, + batch, + rows, + cols, + }) + } + + /// Copy GPU data back to host as a flat vector (batch × rows × cols elements). + pub fn to_host(&self, ctx: &CudaContext) -> Result> { + Ok(ctx.device().dtoh_sync_copy(&self.data)?) + } + + /// Get the batch size. + pub fn batch(&self) -> usize { + self.batch + } + + /// Get rows per matrix. + pub fn rows(&self) -> usize { + self.rows + } + + /// Get columns per matrix. 
+ pub fn cols(&self) -> usize { + self.cols + } + + /// Get stride (elements per batch). + pub fn stride(&self) -> usize { + self.rows * self.cols + } + + /// Get the underlying CUDA slice. + pub fn as_slice(&self) -> &CudaSlice { + &self.data + } + + /// Get a mutable reference to the underlying CUDA slice. + pub fn as_slice_mut(&mut self) -> &mut CudaSlice { + &mut self.data + } + + /// Get the raw device pointer (for DLPack export). + pub fn device_ptr(&self) -> CUdeviceptr { + use cudarc::driver::DevicePtr; + *self.data.device_ptr() + } + + /// Consume self and return the inner CudaSlice (for ownership transfer). + pub fn into_inner(self) -> CudaSlice { + self.data + } +} + +/// A batched GPU tensor with argmax indices for backward propagation. +pub struct GpuTensor3WithArgmax { + /// The result tensor C (batch × rows × cols). + pub tensor: GpuTensor3, + /// The argmax indices (batch × rows × cols). + pub argmax: GpuTensor3, +} + +impl GpuTensor3WithArgmax { + /// Allocate a zeroed batched GPU tensor with argmax. + pub fn alloc(ctx: &CudaContext, batch: usize, rows: usize, cols: usize) -> Result { + let tensor = GpuTensor3::alloc(ctx, batch, rows, cols)?; + let argmax = GpuTensor3::alloc(ctx, batch, rows, cols)?; + Ok(Self { tensor, argmax }) + } + + /// Get batch size. + pub fn batch(&self) -> usize { + self.tensor.batch() + } + + /// Get rows per matrix. + pub fn rows(&self) -> usize { + self.tensor.rows() + } + + /// Get cols per matrix. + pub fn cols(&self) -> usize { + self.tensor.cols() + } + + /// Copy the result tensor back to host. + pub fn tensor_to_host(&self, ctx: &CudaContext) -> Result> { + self.tensor.to_host(ctx) + } + + /// Copy the argmax indices back to host. + pub fn argmax_to_host(&self, ctx: &CudaContext) -> Result> { + self.argmax.to_host(ctx) + } + + /// Consume self and return the tensor and argmax components separately. + /// + /// This is useful for DLPack export where each tensor needs to be wrapped + /// independently for ownership transfer. 
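+    ///
+    /// A minimal sketch (assuming `result` is the `GpuTensor3WithArgmax` returned by the
+    /// batched launcher; the variable names are illustrative):
+    ///
+    /// ```ignore
+    /// let (values, argmax) = result.into_parts();
+    /// let values_slice = values.into_inner();  // owned CudaSlice, batch * rows * cols values
+    /// let argmax_slice = argmax.into_inner();  // owned CudaSlice, matching argmax k-indices
+    /// // each slice can now be wrapped and exported independently
+    /// ```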
+ pub fn into_parts(self) -> (GpuTensor3, GpuTensor3) { + (self.tensor, self.argmax) + } +} diff --git a/crates/tropical-gemm-python/python/tropical_gemm/__init__.py b/crates/tropical-gemm-python/python/tropical_gemm/__init__.py index b7ff9e5..dc23d61 100644 --- a/crates/tropical-gemm-python/python/tropical_gemm/__init__.py +++ b/crates/tropical-gemm-python/python/tropical_gemm/__init__.py @@ -58,6 +58,22 @@ maxplus_matmul_i64, minplus_matmul_i64, maxmul_matmul_i64, + # 2D output variants (f32) + maxplus_matmul_2d, + minplus_matmul_2d, + maxmul_matmul_2d, + # 2D output variants (f64) + maxplus_matmul_2d_f64, + minplus_matmul_2d_f64, + maxmul_matmul_2d_f64, + # 2D output variants (i32) + maxplus_matmul_2d_i32, + minplus_matmul_2d_i32, + maxmul_matmul_2d_i32, + # 2D output variants (i64) + maxplus_matmul_2d_i64, + minplus_matmul_2d_i64, + maxmul_matmul_2d_i64, # CUDA availability cuda_available, ) @@ -73,7 +89,7 @@ maxmul_matmul_gpu_with_argmax, ) -__version__ = "0.1.0" +__version__ = "0.2.0" __all__ = [ # f32 operations @@ -107,6 +123,22 @@ "maxplus_matmul_i64", "minplus_matmul_i64", "maxmul_matmul_i64", + # 2D output variants (f32) + "maxplus_matmul_2d", + "minplus_matmul_2d", + "maxmul_matmul_2d", + # 2D output variants (f64) + "maxplus_matmul_2d_f64", + "minplus_matmul_2d_f64", + "maxmul_matmul_2d_f64", + # 2D output variants (i32) + "maxplus_matmul_2d_i32", + "minplus_matmul_2d_i32", + "maxmul_matmul_2d_i32", + # 2D output variants (i64) + "maxplus_matmul_2d_i64", + "minplus_matmul_2d_i64", + "maxmul_matmul_2d_i64", # CUDA "cuda_available", ] diff --git a/crates/tropical-gemm-python/python/tropical_gemm/pytorch.py b/crates/tropical-gemm-python/python/tropical_gemm/pytorch.py index 728cf61..066ebfe 100644 --- a/crates/tropical-gemm-python/python/tropical_gemm/pytorch.py +++ b/crates/tropical-gemm-python/python/tropical_gemm/pytorch.py @@ -37,89 +37,31 @@ GPU_AVAILABLE = tropical_gemm.cuda_available() -# =========================================================================== -# Helper: Use Rust CPU backend as fallback for GPU tensors without DLPack -# =========================================================================== - - -def _rust_cpu_maxplus_with_argmax( - a: torch.Tensor, b: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Use optimized Rust CPU backend for tropical max-plus matmul. - - This is used as fallback when DLPack is not available. Transfers data - to CPU, uses Rust SIMD-optimized backend, then transfers back to device. 
- """ - m, k = a.shape - n = b.shape[1] - device = a.device - dtype = a.dtype - - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) - - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) - - c_flat, argmax_flat = tropical_gemm.maxplus_matmul_with_argmax(a_np, b_np) - - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(device=device, dtype=dtype) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(device) - - return c, argmax - - -def _rust_cpu_minplus_with_argmax( - a: torch.Tensor, b: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Use optimized Rust CPU backend for tropical min-plus matmul.""" - m, k = a.shape - n = b.shape[1] - device = a.device - dtype = a.dtype - - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) - - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) - - c_flat, argmax_flat = tropical_gemm.minplus_matmul_with_argmax(a_np, b_np) - - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(device=device, dtype=dtype) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(device) - - return c, argmax - - -def _rust_cpu_maxmul_with_argmax( - a: torch.Tensor, b: torch.Tensor -) -> tuple[torch.Tensor, torch.Tensor]: - """Use optimized Rust CPU backend for tropical max-mul matmul.""" - m, k = a.shape - n = b.shape[1] - device = a.device - dtype = a.dtype - - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) - - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) - - c_flat, argmax_flat = tropical_gemm.maxmul_matmul_with_argmax(a_np, b_np) - - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(device=device, dtype=dtype) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(device) - - return c, argmax +def _get_dtype_funcs(dtype: torch.dtype): + """Get the appropriate Rust functions for the given dtype.""" + if dtype == torch.float64: + return { + "maxplus": tropical_gemm.maxplus_matmul_with_argmax_f64, + "minplus": tropical_gemm.minplus_matmul_with_argmax_f64, + "maxmul": tropical_gemm.maxmul_matmul_with_argmax_f64, + "backward_a": tropical_gemm.backward_a_f64, + "backward_b": tropical_gemm.backward_b_f64, + "maxmul_backward_a": tropical_gemm.maxmul_backward_a_f64, + "maxmul_backward_b": tropical_gemm.maxmul_backward_b_f64, + "np_dtype": np.float64, + } + else: + # Default to f32 for float32 and other types + return { + "maxplus": tropical_gemm.maxplus_matmul_with_argmax, + "minplus": tropical_gemm.minplus_matmul_with_argmax, + "maxmul": tropical_gemm.maxmul_matmul_with_argmax, + "backward_a": tropical_gemm.backward_a, + "backward_b": tropical_gemm.backward_b, + "maxmul_backward_a": tropical_gemm.maxmul_backward_a, + "maxmul_backward_b": tropical_gemm.maxmul_backward_b, + "np_dtype": np.float32, + } # =========================================================================== @@ -137,6 +79,8 @@ class TropicalMaxPlusMatmul(torch.autograd.Function): For each output C[i,j], the gradient flows back to: - A[i, argmax[i,j]] - B[argmax[i,j], j] + + Supports both float32 and float64 inputs (uses appropriate Rust backend). 
""" @staticmethod @@ -153,10 +97,17 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: """ m, k = a.shape n = b.shape[1] + funcs = _get_dtype_funcs(a.dtype) + + # Convert to contiguous numpy arrays (preserving dtype) + a_np = a.detach().cpu().numpy() + b_np = b.detach().cpu().numpy() - # Convert to contiguous numpy arrays - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) + # Cast to appropriate dtype if needed (e.g., float16 -> float32) + if a_np.dtype != funcs["np_dtype"]: + a_np = a_np.astype(funcs["np_dtype"]) + if b_np.dtype != funcs["np_dtype"]: + b_np = b_np.astype(funcs["np_dtype"]) # Ensure contiguous layout if not a_np.flags["C_CONTIGUOUS"]: @@ -164,18 +115,19 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: if not b_np.flags["C_CONTIGUOUS"]: b_np = np.ascontiguousarray(b_np) - # Call the optimized Rust implementation (returns flattened arrays) - c_flat, argmax_flat = tropical_gemm.maxplus_matmul_with_argmax(a_np, b_np) + # Call the optimized Rust implementation (returns flattened numpy arrays) + c_flat, argmax_flat = funcs["maxplus"](a_np, b_np) - # Reshape to 2D - c_np = np.array(c_flat).reshape(m, n) - argmax_np = np.array(argmax_flat).reshape(m, n) + # Reshape to 2D (c_flat and argmax_flat are already numpy arrays) + c_np = np.asarray(c_flat).reshape(m, n) + argmax_np = np.asarray(argmax_flat).reshape(m, n) - # Save argmax for backward pass + # Save for backward pass ctx.save_for_backward(torch.from_numpy(argmax_np)) ctx.k = k ctx.m = m ctx.n = n + ctx.dtype = a.dtype return torch.from_numpy(c_np).to(a.device) @@ -191,23 +143,28 @@ def backward(ctx, grad_c: torch.Tensor): k = ctx.k m = ctx.m n = ctx.n + funcs = _get_dtype_funcs(ctx.dtype) - # Ensure contiguous numpy arrays - grad_c_np = grad_c.cpu().numpy().astype(np.float32) + # Convert to numpy (preserving dtype for grad_c) + grad_c_np = grad_c.cpu().numpy() argmax_np = argmax.numpy().astype(np.int32) + # Cast to appropriate dtype if needed + if grad_c_np.dtype != funcs["np_dtype"]: + grad_c_np = grad_c_np.astype(funcs["np_dtype"]) + if not grad_c_np.flags["C_CONTIGUOUS"]: grad_c_np = np.ascontiguousarray(grad_c_np) if not argmax_np.flags["C_CONTIGUOUS"]: argmax_np = np.ascontiguousarray(argmax_np) - # Compute gradients using the Rust backend (returns flattened arrays) - grad_a_flat = tropical_gemm.backward_a(grad_c_np, argmax_np, k) - grad_b_flat = tropical_gemm.backward_b(grad_c_np, argmax_np, k) + # Compute gradients using the Rust backend (returns flattened numpy arrays) + grad_a_flat = funcs["backward_a"](grad_c_np, argmax_np, k) + grad_b_flat = funcs["backward_b"](grad_c_np, argmax_np, k) - # Reshape to 2D - grad_a = torch.from_numpy(np.array(grad_a_flat).reshape(m, k)).to(grad_c.device) - grad_b = torch.from_numpy(np.array(grad_b_flat).reshape(k, n)).to(grad_c.device) + # Reshape to 2D (grad_*_flat are already numpy arrays) + grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k)).to(grad_c.device) + grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k, n)).to(grad_c.device) return grad_a, grad_b @@ -219,30 +176,38 @@ class TropicalMinPlusMatmul(torch.autograd.Function): Forward: C[i,j] = min_k(A[i,k] + B[k,j]) Useful for shortest path computations in graphs. + Supports both float32 and float64 inputs. 
""" @staticmethod def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m, k = a.shape n = b.shape[1] + funcs = _get_dtype_funcs(a.dtype) + + a_np = a.detach().cpu().numpy() + b_np = b.detach().cpu().numpy() - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) + if a_np.dtype != funcs["np_dtype"]: + a_np = a_np.astype(funcs["np_dtype"]) + if b_np.dtype != funcs["np_dtype"]: + b_np = b_np.astype(funcs["np_dtype"]) if not a_np.flags["C_CONTIGUOUS"]: a_np = np.ascontiguousarray(a_np) if not b_np.flags["C_CONTIGUOUS"]: b_np = np.ascontiguousarray(b_np) - c_flat, argmax_flat = tropical_gemm.minplus_matmul_with_argmax(a_np, b_np) + c_flat, argmax_flat = funcs["minplus"](a_np, b_np) - c_np = np.array(c_flat).reshape(m, n) - argmax_np = np.array(argmax_flat).reshape(m, n) + c_np = np.asarray(c_flat).reshape(m, n) + argmax_np = np.asarray(argmax_flat).reshape(m, n) ctx.save_for_backward(torch.from_numpy(argmax_np)) ctx.k = k ctx.m = m ctx.n = n + ctx.dtype = a.dtype return torch.from_numpy(c_np).to(a.device) @@ -252,20 +217,24 @@ def backward(ctx, grad_c: torch.Tensor): k = ctx.k m = ctx.m n = ctx.n + funcs = _get_dtype_funcs(ctx.dtype) - grad_c_np = grad_c.cpu().numpy().astype(np.float32) + grad_c_np = grad_c.cpu().numpy() argmax_np = argmax.numpy().astype(np.int32) + if grad_c_np.dtype != funcs["np_dtype"]: + grad_c_np = grad_c_np.astype(funcs["np_dtype"]) + if not grad_c_np.flags["C_CONTIGUOUS"]: grad_c_np = np.ascontiguousarray(grad_c_np) if not argmax_np.flags["C_CONTIGUOUS"]: argmax_np = np.ascontiguousarray(argmax_np) - grad_a_flat = tropical_gemm.backward_a(grad_c_np, argmax_np, k) - grad_b_flat = tropical_gemm.backward_b(grad_c_np, argmax_np, k) + grad_a_flat = funcs["backward_a"](grad_c_np, argmax_np, k) + grad_b_flat = funcs["backward_b"](grad_c_np, argmax_np, k) - grad_a = torch.from_numpy(np.array(grad_a_flat).reshape(m, k)).to(grad_c.device) - grad_b = torch.from_numpy(np.array(grad_b_flat).reshape(k, n)).to(grad_c.device) + grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k)).to(grad_c.device) + grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k, n)).to(grad_c.device) return grad_a, grad_b @@ -282,25 +251,32 @@ class TropicalMaxMulMatmul(torch.autograd.Function): - grad_B[k,j] = grad_C[i,j] * A[i,k] if k == argmax[i,j] Useful for max-probability computations (non-log space). + Supports both float32 and float64 inputs. 
""" @staticmethod def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m, k = a.shape n = b.shape[1] + funcs = _get_dtype_funcs(a.dtype) - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) + a_np = a.detach().cpu().numpy() + b_np = b.detach().cpu().numpy() + + if a_np.dtype != funcs["np_dtype"]: + a_np = a_np.astype(funcs["np_dtype"]) + if b_np.dtype != funcs["np_dtype"]: + b_np = b_np.astype(funcs["np_dtype"]) if not a_np.flags["C_CONTIGUOUS"]: a_np = np.ascontiguousarray(a_np) if not b_np.flags["C_CONTIGUOUS"]: b_np = np.ascontiguousarray(b_np) - c_flat, argmax_flat = tropical_gemm.maxmul_matmul_with_argmax(a_np, b_np) + c_flat, argmax_flat = funcs["maxmul"](a_np, b_np) - c_np = np.array(c_flat).reshape(m, n) - argmax_np = np.array(argmax_flat).reshape(m, n) + c_np = np.asarray(c_flat).reshape(m, n) + argmax_np = np.asarray(argmax_flat).reshape(m, n) # Save inputs and argmax for backward (needed for multiplicative gradient) ctx.save_for_backward( @@ -311,6 +287,7 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: ctx.k = k ctx.m = m ctx.n = n + ctx.dtype = a.dtype return torch.from_numpy(c_np).to(a.device) @@ -320,11 +297,15 @@ def backward(ctx, grad_c: torch.Tensor): k_dim = ctx.k m = ctx.m n = ctx.n + funcs = _get_dtype_funcs(ctx.dtype) - grad_c_np = grad_c.cpu().numpy().astype(np.float32) + grad_c_np = grad_c.cpu().numpy() argmax_np = argmax.numpy().astype(np.int32) - a_np = a.numpy().astype(np.float32) - b_np = b.numpy().astype(np.float32) + a_np = a.numpy() + b_np = b.numpy() + + if grad_c_np.dtype != funcs["np_dtype"]: + grad_c_np = grad_c_np.astype(funcs["np_dtype"]) if not grad_c_np.flags["C_CONTIGUOUS"]: grad_c_np = np.ascontiguousarray(grad_c_np) @@ -332,13 +313,13 @@ def backward(ctx, grad_c: torch.Tensor): argmax_np = np.ascontiguousarray(argmax_np) # Use multiplicative backward rule - grad_a_flat = tropical_gemm.maxmul_backward_a(grad_c_np, argmax_np, b_np) - grad_b_flat = tropical_gemm.maxmul_backward_b(grad_c_np, argmax_np, a_np) + grad_a_flat = funcs["maxmul_backward_a"](grad_c_np, argmax_np, b_np) + grad_b_flat = funcs["maxmul_backward_b"](grad_c_np, argmax_np, a_np) - grad_a = torch.from_numpy(np.array(grad_a_flat).reshape(m, k_dim)).to( + grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k_dim)).to( grad_c.device ) - grad_b = torch.from_numpy(np.array(grad_b_flat).reshape(k_dim, n)).to( + grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k_dim, n)).to( grad_c.device ) @@ -351,15 +332,16 @@ def backward(ctx, grad_c: torch.Tensor): # Check if DLPack functions are available (CUDA build) _DLPACK_AVAILABLE = hasattr(tropical_gemm, "maxplus_matmul_dlpack") +_BATCHED_DLPACK_AVAILABLE = hasattr(tropical_gemm, "maxplus_matmul_batched_dlpack") class TropicalMaxPlusMatmulGPU(torch.autograd.Function): """ GPU-accelerated MaxPlus tropical matrix multiplication. - Uses the optimized Rust CUDA backend via DLPack for zero-copy GPU tensor exchange. - Input tensors stay on GPU; only the forward pass output requires a D2H transfer - (Phase 1). The backward pass runs entirely on GPU using PyTorch scatter operations. + Uses the optimized Rust CUDA backend via DLPack for true zero-copy GPU tensor exchange. + Both inputs and outputs stay on GPU - no host transfers in the forward pass. + The backward pass runs entirely on GPU using PyTorch scatter operations. 
""" @staticmethod @@ -367,19 +349,32 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m, k = a.shape n = b.shape[1] - if _DLPACK_AVAILABLE and a.is_cuda: - # Use Rust CUDA backend via DLPack (zero-copy for inputs) - a_contig = a.detach().contiguous() - b_contig = b.detach().contiguous() - - c_flat, argmax_flat = tropical_gemm.maxplus_matmul_dlpack(a_contig, b_contig) - - # Reshape results (numpy arrays from Rust) - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(a.device) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(a.device) - else: - # Fallback to optimized Rust CPU backend (still O(M*K + K*N) memory) - c, argmax = _rust_cpu_maxplus_with_argmax(a, b) + # GPU functions require CUDA tensors and DLPack support + if not a.is_cuda: + raise RuntimeError( + "TropicalMaxPlusMatmulGPU requires CUDA tensors. " + "Use TropicalMaxPlusMatmul for CPU tensors." + ) + if not _DLPACK_AVAILABLE: + raise RuntimeError( + "DLPack support not available. Rebuild tropical-gemm with CUDA support " + "or use TropicalMaxPlusMatmul for CPU execution." + ) + + # Use Rust CUDA backend via DLPack (zero-copy for inputs and outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.maxplus_matmul_dlpack(a_contig, b_contig) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule).reshape(m, n) + argmax = torch.from_dlpack(argmax_capsule).reshape(m, n).to(torch.int64) ctx.save_for_backward(argmax) ctx.k = k @@ -420,18 +415,32 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m, k = a.shape n = b.shape[1] - if _DLPACK_AVAILABLE and a.is_cuda: - # Use Rust CUDA backend via DLPack - a_contig = a.detach().contiguous() - b_contig = b.detach().contiguous() - - c_flat, argmax_flat = tropical_gemm.minplus_matmul_dlpack(a_contig, b_contig) - - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(a.device) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(a.device) - else: - # Fallback to optimized Rust CPU backend - c, argmax = _rust_cpu_minplus_with_argmax(a, b) + # GPU functions require CUDA tensors and DLPack support + if not a.is_cuda: + raise RuntimeError( + "TropicalMinPlusMatmulGPU requires CUDA tensors. " + "Use TropicalMinPlusMatmul for CPU tensors." + ) + if not _DLPACK_AVAILABLE: + raise RuntimeError( + "DLPack support not available. Rebuild tropical-gemm with CUDA support " + "or use TropicalMinPlusMatmul for CPU execution." 
+ ) + + # Use Rust CUDA backend via DLPack (zero-copy for inputs and outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.minplus_matmul_dlpack(a_contig, b_contig) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule).reshape(m, n) + argmax = torch.from_dlpack(argmax_capsule).reshape(m, n).to(torch.int64) ctx.save_for_backward(argmax) ctx.k = k @@ -473,22 +482,35 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: m, k = a.shape n = b.shape[1] - if _DLPACK_AVAILABLE and a.is_cuda: - # Use Rust CUDA backend via DLPack - a_contig = a.detach().contiguous() - b_contig = b.detach().contiguous() - - c_flat, argmax_flat = tropical_gemm.maxmul_matmul_dlpack(a_contig, b_contig) - - c = torch.from_numpy(np.array(c_flat).reshape(m, n)).to(a.device) - argmax = torch.from_numpy(np.array(argmax_flat).reshape(m, n)).to(a.device) - - # Save original tensors for multiplicative backward - ctx.save_for_backward(a.detach(), b.detach(), argmax) - else: - # Fallback to optimized Rust CPU backend - c, argmax = _rust_cpu_maxmul_with_argmax(a, b) - ctx.save_for_backward(a.detach(), b.detach(), argmax) + # GPU functions require CUDA tensors and DLPack support + if not a.is_cuda: + raise RuntimeError( + "TropicalMaxMulMatmulGPU requires CUDA tensors. " + "Use TropicalMaxMulMatmul for CPU tensors." + ) + if not _DLPACK_AVAILABLE: + raise RuntimeError( + "DLPack support not available. Rebuild tropical-gemm with CUDA support " + "or use TropicalMaxMulMatmul for CPU execution." + ) + + # Use Rust CUDA backend via DLPack (zero-copy for inputs and outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.maxmul_matmul_dlpack(a_contig, b_contig) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule).reshape(m, n) + argmax = torch.from_dlpack(argmax_capsule).reshape(m, n).to(torch.int64) + + # Save original tensors for multiplicative backward + ctx.save_for_backward(a.detach(), b.detach(), argmax) ctx.k = k ctx.m = m @@ -550,27 +572,62 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: batch, m, k = a.shape n = b.shape[2] - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) + # Check if either tensor is on CUDA + if a.is_cuda or b.is_cuda: + # Validate both tensors are on the same device + if a.device != b.device: + raise RuntimeError( + f"Tensors must be on the same device: a is on {a.device}, b is on {b.device}" + ) + # CUDA tensors require batched DLPack support + if not _BATCHED_DLPACK_AVAILABLE: + raise RuntimeError( + "Batched GPU operations require CUDA support. " + "Rebuild tropical-gemm with CUDA support or move tensors to CPU." 
+ ) + # Use Rust CUDA backend via DLPack (zero-copy for inputs AND outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.maxplus_matmul_batched_dlpack( + a_contig, b_contig + ) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule) + # Cast to int64 for PyTorch scatter operations + argmax = torch.from_dlpack(argmax_capsule).to(torch.int64) + else: + # CPU path + a_np = a.detach().cpu().numpy().astype(np.float32) + b_np = b.detach().cpu().numpy().astype(np.float32) - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) + if not a_np.flags["C_CONTIGUOUS"]: + a_np = np.ascontiguousarray(a_np) + if not b_np.flags["C_CONTIGUOUS"]: + b_np = np.ascontiguousarray(b_np) - c_flat, argmax_flat = tropical_gemm.maxplus_matmul_batched_with_argmax(a_np, b_np) + c_flat, argmax_flat = tropical_gemm.maxplus_matmul_batched_with_argmax( + a_np, b_np + ) - c_np = np.array(c_flat).reshape(batch, m, n) - argmax_np = np.array(argmax_flat).reshape(batch, m, n) + c = torch.from_numpy(np.asarray(c_flat).reshape(batch, m, n)).to(a.device) + argmax = torch.from_numpy(np.asarray(argmax_flat).reshape(batch, m, n)).to( + device=a.device, dtype=torch.int64 + ) - # Save argmax on the same device as input for backward pass - ctx.save_for_backward(torch.from_numpy(argmax_np).to(a.device)) + ctx.save_for_backward(argmax) ctx.k = k ctx.batch = batch ctx.m = m ctx.n = n - return torch.from_numpy(c_np).to(a.device) + return c @staticmethod def backward(ctx, grad_c: torch.Tensor): @@ -611,27 +668,62 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: batch, m, k = a.shape n = b.shape[2] - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) + # Check if either tensor is on CUDA + if a.is_cuda or b.is_cuda: + # Validate both tensors are on the same device + if a.device != b.device: + raise RuntimeError( + f"Tensors must be on the same device: a is on {a.device}, b is on {b.device}" + ) + # CUDA tensors require batched DLPack support + if not _BATCHED_DLPACK_AVAILABLE: + raise RuntimeError( + "Batched GPU operations require CUDA support. " + "Rebuild tropical-gemm with CUDA support or move tensors to CPU." 
+ ) + # Use Rust CUDA backend via DLPack (zero-copy for inputs AND outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.minplus_matmul_batched_dlpack( + a_contig, b_contig + ) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule) + # Cast to int64 for PyTorch scatter operations + argmax = torch.from_dlpack(argmax_capsule).to(torch.int64) + else: + # CPU path + a_np = a.detach().cpu().numpy().astype(np.float32) + b_np = b.detach().cpu().numpy().astype(np.float32) - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) + if not a_np.flags["C_CONTIGUOUS"]: + a_np = np.ascontiguousarray(a_np) + if not b_np.flags["C_CONTIGUOUS"]: + b_np = np.ascontiguousarray(b_np) - c_flat, argmax_flat = tropical_gemm.minplus_matmul_batched_with_argmax(a_np, b_np) + c_flat, argmax_flat = tropical_gemm.minplus_matmul_batched_with_argmax( + a_np, b_np + ) - c_np = np.array(c_flat).reshape(batch, m, n) - argmax_np = np.array(argmax_flat).reshape(batch, m, n) + c = torch.from_numpy(np.asarray(c_flat).reshape(batch, m, n)).to(a.device) + argmax = torch.from_numpy(np.asarray(argmax_flat).reshape(batch, m, n)).to( + device=a.device, dtype=torch.int64 + ) - # Save argmax on the same device as input for backward pass - ctx.save_for_backward(torch.from_numpy(argmax_np).to(a.device)) + ctx.save_for_backward(argmax) ctx.k = k ctx.batch = batch ctx.m = m ctx.n = n - return torch.from_numpy(c_np).to(a.device) + return c @staticmethod def backward(ctx, grad_c: torch.Tensor): @@ -676,32 +768,69 @@ def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: batch, m, k = a.shape n = b.shape[2] - a_np = a.detach().cpu().numpy().astype(np.float32) - b_np = b.detach().cpu().numpy().astype(np.float32) - - if not a_np.flags["C_CONTIGUOUS"]: - a_np = np.ascontiguousarray(a_np) - if not b_np.flags["C_CONTIGUOUS"]: - b_np = np.ascontiguousarray(b_np) - - c_flat, argmax_flat = tropical_gemm.maxmul_matmul_batched_with_argmax(a_np, b_np) - - c_np = np.array(c_flat).reshape(batch, m, n) - argmax_np = np.array(argmax_flat).reshape(batch, m, n) + # Check if either tensor is on CUDA + if a.is_cuda or b.is_cuda: + # Validate both tensors are on the same device + if a.device != b.device: + raise RuntimeError( + f"Tensors must be on the same device: a is on {a.device}, b is on {b.device}" + ) + # CUDA tensors require batched DLPack support + if not _BATCHED_DLPACK_AVAILABLE: + raise RuntimeError( + "Batched GPU operations require CUDA support. " + "Rebuild tropical-gemm with CUDA support or move tensors to CPU." 
+ ) + # Use Rust CUDA backend via DLPack (zero-copy for inputs AND outputs) + a_contig = a.detach() + if not a_contig.is_contiguous(): + a_contig = a_contig.contiguous() + b_contig = b.detach() + if not b_contig.is_contiguous(): + b_contig = b_contig.contiguous() + + # Returns DLPack capsules - data stays on GPU + c_capsule, argmax_capsule = tropical_gemm.maxmul_matmul_batched_dlpack( + a_contig, b_contig + ) + + # Convert DLPack capsules to PyTorch tensors (zero-copy on GPU) + c = torch.from_dlpack(c_capsule) + # Cast to int64 for PyTorch scatter operations + argmax = torch.from_dlpack(argmax_capsule).to(torch.int64) + # Save tensors for backward pass + ctx.save_for_backward(a.detach(), b.detach(), argmax) + else: + # CPU path + a_np = a.detach().cpu().numpy().astype(np.float32) + b_np = b.detach().cpu().numpy().astype(np.float32) + + if not a_np.flags["C_CONTIGUOUS"]: + a_np = np.ascontiguousarray(a_np) + if not b_np.flags["C_CONTIGUOUS"]: + b_np = np.ascontiguousarray(b_np) + + c_flat, argmax_flat = tropical_gemm.maxmul_matmul_batched_with_argmax( + a_np, b_np + ) + + c = torch.from_numpy(np.asarray(c_flat).reshape(batch, m, n)).to(a.device) + argmax = torch.from_numpy(np.asarray(argmax_flat).reshape(batch, m, n)).to( + device=a.device, dtype=torch.int64 + ) + # Save tensors on the same device as input for backward pass + ctx.save_for_backward( + torch.from_numpy(a_np).to(a.device), + torch.from_numpy(b_np).to(a.device), + argmax, + ) - # Save tensors on the same device as input for backward pass - device = a.device - ctx.save_for_backward( - torch.from_numpy(a_np).to(device), - torch.from_numpy(b_np).to(device), - torch.from_numpy(argmax_np).to(device), - ) ctx.k = k ctx.batch = batch ctx.m = m ctx.n = n - return torch.from_numpy(c_np).to(device) + return c @staticmethod def backward(ctx, grad_c: torch.Tensor): diff --git a/crates/tropical-gemm-python/src/lib.rs b/crates/tropical-gemm-python/src/lib.rs index 4c2d24e..b08fe5b 100644 --- a/crates/tropical-gemm-python/src/lib.rs +++ b/crates/tropical-gemm-python/src/lib.rs @@ -7,7 +7,8 @@ //! //! 
- `cuda`: Enable GPU acceleration via CUDA -use numpy::{IntoPyArray, PyArray1, PyReadonlyArray2, PyReadonlyArray3, PyUntypedArrayMethods}; +use numpy::ndarray::Array2; +use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray2, PyReadonlyArray3, PyUntypedArrayMethods, ToPyArray}; use pyo3::prelude::*; // Use fully qualified path to avoid naming conflict with the pymodule @@ -43,17 +44,17 @@ fn maxplus_matmul<'py>( ))); } - // Get contiguous data - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - // Perform tropical matmul - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); - // Extract scalar values from semiring wrapper - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); - - // Create output array + // Create output array (requires GIL) Ok(c_scalars.into_pyarray(py)) } @@ -84,12 +85,15 @@ fn minplus_matmul<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; - - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -123,14 +127,18 @@ fn maxplus_matmul_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, k, b_data, n); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); let c_result = c_scalars.into_pyarray(py); let argmax_result = argmax_i32.into_pyarray(py); @@ -167,14 +175,18 @@ fn minplus_matmul_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, k, b_data, n); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); let c_result = c_scalars.into_pyarray(py); let argmax_result = 
argmax_i32.into_pyarray(py); @@ -207,21 +219,24 @@ fn backward_a<'py>( let m = shape[0]; let n = shape[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - - // Compute gradient w.r.t. A - let mut grad_a = vec![0.0f32; m * k]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_a[i * k + k_idx] += grad_c_data[idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_a = py.allow_threads(|| { + let mut grad_a = vec![0.0f32; m * k]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_a[i * k + k_idx] += grad_c_data[idx]; + } } } - } + grad_a + }); Ok(grad_a.into_pyarray(py)) } @@ -251,21 +266,24 @@ fn backward_b<'py>( let m = shape[0]; let n = shape[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - - // Compute gradient w.r.t. B - let mut grad_b = vec![0.0f32; k * n]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_b[k_idx * n + j] += grad_c_data[idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_b = py.allow_threads(|| { + let mut grad_b = vec![0.0f32; k * n]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_b[k_idx * n + j] += grad_c_data[idx]; + } } } - } + grad_b + }); Ok(grad_b.into_pyarray(py)) } @@ -274,6 +292,125 @@ fn backward_b<'py>( // MaxMul operations (f32) // ============================================================================ +// ============================================================================ +// 2D output variants (f32) +// ============================================================================ + +/// Tropical MaxPlus matrix multiplication returning 2D array: C[i,j] = max_k(A[i,k] + B[k,j]) +/// +/// Args: +/// a: Input matrix A of shape (M, K) +/// b: Input matrix B of shape (K, N) +/// +/// Returns: +/// Result matrix C of shape (M, N) as a 2D array +#[pyfunction] +fn maxplus_matmul_2d<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f32>, + b: PyReadonlyArray2<'py, f32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + // Create 2D output array + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MinPlus matrix multiplication returning 2D array: C[i,j] = min_k(A[i,k] + B[k,j]) +#[pyfunction] +fn minplus_matmul_2d<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f32>, + b: 
PyReadonlyArray2<'py, f32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + // Create 2D output array + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MaxMul matrix multiplication returning 2D array: C[i,j] = max_k(A[i,k] * B[k,j]) +#[pyfunction] +fn maxmul_matmul_2d<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f32>, + b: PyReadonlyArray2<'py, f32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + // Create 2D output array + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + /// Tropical MaxMul matrix multiplication: C[i,j] = max_k(A[i,k] * B[k,j]) #[pyfunction] fn maxmul_matmul<'py>( @@ -294,11 +431,15 @@ fn maxmul_matmul<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -323,14 +464,18 @@ fn maxmul_matmul_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, k, b_data, n); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) } @@ -359,11 +504,15 
@@ fn maxplus_matmul_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -388,11 +537,15 @@ fn minplus_matmul_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -417,15 +570,122 @@ fn maxmul_matmul_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } +// ============================================================================ +// 2D output variants (f64) +// ============================================================================ + +/// Tropical MaxPlus matrix multiplication returning 2D array (f64): C[i,j] = max_k(A[i,k] + B[k,j]) +#[pyfunction] +fn maxplus_matmul_2d_f64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f64>, + b: PyReadonlyArray2<'py, f64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MinPlus matrix multiplication returning 2D array (f64): C[i,j] = min_k(A[i,k] + B[k,j]) +#[pyfunction] +fn minplus_matmul_2d_f64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f64>, + b: PyReadonlyArray2<'py, f64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let 
c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MaxMul matrix multiplication returning 2D array (f64): C[i,j] = max_k(A[i,k] * B[k,j]) +#[pyfunction] +fn maxmul_matmul_2d_f64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, f64>, + b: PyReadonlyArray2<'py, f64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + /// Tropical MaxPlus matrix multiplication with argmax tracking (f64). #[pyfunction] fn maxplus_matmul_with_argmax_f64<'py>( @@ -446,14 +706,18 @@ fn maxplus_matmul_with_argmax_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, k, b_data, n); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) } @@ -478,14 +742,18 @@ fn minplus_matmul_with_argmax_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; - - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, k, b_data, n); + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) } @@ -510,14 +778,18 @@ fn maxmul_matmul_with_argmax_f64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_data, m, 
k, b_data, n); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Release GIL during heavy compute + let (c_scalars, argmax_i32) = py.allow_threads(|| { + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(&a_data, m, k, &b_data, n); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + }); Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) } @@ -534,20 +806,24 @@ fn backward_a_f64<'py>( let m = shape[0]; let n = shape[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - - let mut grad_a = vec![0.0f64; m * k]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_a[i * k + k_idx] += grad_c_data[idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_a = py.allow_threads(|| { + let mut grad_a = vec![0.0f64; m * k]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_a[i * k + k_idx] += grad_c_data[idx]; + } } } - } + grad_a + }); Ok(grad_a.into_pyarray(py)) } @@ -564,20 +840,24 @@ fn backward_b_f64<'py>( let m = shape[0]; let n = shape[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - - let mut grad_b = vec![0.0f64; k * n]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_b[k_idx * n + j] += grad_c_data[idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_b = py.allow_threads(|| { + let mut grad_b = vec![0.0f64; k * n]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_b[k_idx * n + j] += grad_c_data[idx]; + } } } - } + grad_b + }); Ok(grad_b.into_pyarray(py)) } @@ -605,22 +885,26 @@ fn maxmul_backward_a<'py>( let n = shape[1]; let k = b.shape()[0]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - let b_data = b.as_slice()?; - - let mut grad_a = vec![0.0f32; m * k]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - // grad_A[i,k] += grad_C[i,j] * B[k,j] - grad_a[i * k + k_idx] += grad_c_data[idx] * b_data[k_idx * n + j]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_a = py.allow_threads(|| { + let mut grad_a = vec![0.0f32; m * k]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + // grad_A[i,k] += grad_C[i,j] * B[k,j] + grad_a[i * k + k_idx] += grad_c_data[idx] * b_data[k_idx * n + j]; + } } } - } + grad_a + }); Ok(grad_a.into_pyarray(py)) } @@ -641,22 +925,26 @@ fn maxmul_backward_b<'py>( let n = shape[1]; let k = a.shape()[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - let a_data = a.as_slice()?; - - let mut grad_b = vec![0.0f32; k * n]; - - for i 
in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - // grad_B[k,j] += grad_C[i,j] * A[i,k] - grad_b[k_idx * n + j] += grad_c_data[idx] * a_data[i * k + k_idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + let a_data = a.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_b = py.allow_threads(|| { + let mut grad_b = vec![0.0f32; k * n]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + // grad_B[k,j] += grad_C[i,j] * A[i,k] + grad_b[k_idx * n + j] += grad_c_data[idx] * a_data[i * k + k_idx]; + } } } - } + grad_b + }); Ok(grad_b.into_pyarray(py)) } @@ -674,21 +962,25 @@ fn maxmul_backward_a_f64<'py>( let n = shape[1]; let k = b.shape()[0]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - let b_data = b.as_slice()?; - - let mut grad_a = vec![0.0f64; m * k]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_a[i * k + k_idx] += grad_c_data[idx] * b_data[k_idx * n + j]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_a = py.allow_threads(|| { + let mut grad_a = vec![0.0f64; m * k]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_a[i * k + k_idx] += grad_c_data[idx] * b_data[k_idx * n + j]; + } } } - } + grad_a + }); Ok(grad_a.into_pyarray(py)) } @@ -706,21 +998,25 @@ fn maxmul_backward_b_f64<'py>( let n = shape[1]; let k = a.shape()[1]; - let grad_c_data = grad_c.as_slice()?; - let argmax_data = argmax.as_slice()?; - let a_data = a.as_slice()?; - - let mut grad_b = vec![0.0f64; k * n]; - - for i in 0..m { - for j in 0..n { - let idx = i * n + j; - let k_idx = argmax_data[idx] as usize; - if k_idx < k { - grad_b[k_idx * n + j] += grad_c_data[idx] * a_data[i * k + k_idx]; + // Clone to owned data before releasing GIL + let grad_c_data = grad_c.as_slice()?.to_vec(); + let argmax_data = argmax.as_slice()?.to_vec(); + let a_data = a.as_slice()?.to_vec(); + + // Release GIL during compute + let grad_b = py.allow_threads(|| { + let mut grad_b = vec![0.0f64; k * n]; + for i in 0..m { + for j in 0..n { + let idx = i * n + j; + let k_idx = argmax_data[idx] as usize; + if k_idx < k { + grad_b[k_idx * n + j] += grad_c_data[idx] * a_data[i * k + k_idx]; + } } } - } + grad_b + }); Ok(grad_b.into_pyarray(py)) } @@ -749,11 +1045,15 @@ fn maxplus_matmul_i32<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -778,11 +1078,15 @@ fn minplus_matmul_i32<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = 
b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -807,15 +1111,122 @@ fn maxmul_matmul_i32<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } +// ============================================================================ +// 2D output variants (i32) +// ============================================================================ + +/// Tropical MaxPlus matrix multiplication returning 2D array (i32): C[i,j] = max_k(A[i,k] + B[k,j]) +#[pyfunction] +fn maxplus_matmul_2d_i32<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i32>, + b: PyReadonlyArray2<'py, i32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MinPlus matrix multiplication returning 2D array (i32): C[i,j] = min_k(A[i,k] + B[k,j]) +#[pyfunction] +fn minplus_matmul_2d_i32<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i32>, + b: PyReadonlyArray2<'py, i32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MaxMul matrix multiplication returning 2D array (i32): C[i,j] = max_k(A[i,k] * B[k,j]) +#[pyfunction] +fn maxmul_matmul_2d_i32<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i32>, + b: PyReadonlyArray2<'py, i32>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( 
+ "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + // ============================================================================ // i64 operations // ============================================================================ @@ -840,11 +1251,15 @@ fn maxplus_matmul_i64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -869,11 +1284,15 @@ fn minplus_matmul_i64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } @@ -898,27 +1317,134 @@ fn maxmul_matmul_i64<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let c_data = tropical_matmul::>(a_data, m, k, b_data, n); - let c_scalars: Vec = c_data.iter().map(|x| x.value()).collect(); + // Release GIL during heavy compute + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); Ok(c_scalars.into_pyarray(py)) } // ============================================================================ -// Batched operations (3D arrays: batch × rows × cols) +// 2D output variants (i64) // ============================================================================ -/// Batched MaxPlus tropical matrix multiplication with argmax tracking. 
-/// -/// Args: -/// a: Input tensor of shape (batch, M, K) -/// b: Input tensor of shape (batch, K, N) -/// -/// Returns: -/// Tuple of (C, argmax) where: +/// Tropical MaxPlus matrix multiplication returning 2D array (i64): C[i,j] = max_k(A[i,k] + B[k,j]) +#[pyfunction] +fn maxplus_matmul_2d_i64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i64>, + b: PyReadonlyArray2<'py, i64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MinPlus matrix multiplication returning 2D array (i64): C[i,j] = min_k(A[i,k] + B[k,j]) +#[pyfunction] +fn minplus_matmul_2d_i64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i64>, + b: PyReadonlyArray2<'py, i64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +/// Tropical MaxMul matrix multiplication returning 2D array (i64): C[i,j] = max_k(A[i,k] * B[k,j]) +#[pyfunction] +fn maxmul_matmul_2d_i64<'py>( + py: Python<'py>, + a: PyReadonlyArray2<'py, i64>, + b: PyReadonlyArray2<'py, i64>, +) -> PyResult>> { + let a_shape = a.shape(); + let b_shape = b.shape(); + let m = a_shape[0]; + let k = a_shape[1]; + let n = b_shape[1]; + + if k != b_shape[0] { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Dimension mismatch: A is {}x{}, B is {}x{}", + m, k, b_shape[0], n + ))); + } + + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); + + let c_scalars = py.allow_threads(|| { + let c_data = tropical_matmul::>(&a_data, m, k, &b_data, n); + c_data.iter().map(|x| x.value()).collect::>() + }); + + let arr = Array2::from_shape_vec((m, n), c_scalars) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Shape error: {}", e)))?; + Ok(arr.to_pyarray(py).to_owned()) +} + +// ============================================================================ +// Batched operations (3D arrays: batch × rows × cols) +// ============================================================================ + +/// Batched MaxPlus tropical matrix multiplication with argmax tracking. 
+/// +/// Args: +/// a: Input tensor of shape (batch, M, K) +/// b: Input tensor of shape (batch, K, N) +/// +/// Returns: +/// Tuple of (C, argmax) where: /// - C: Result tensor of shape (batch × M × N) as flattened array /// - argmax: Indices of shape (batch × M × N) as flattened array #[pyfunction] @@ -948,30 +1474,36 @@ fn maxplus_matmul_batched_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let stride_a = m * k; - let stride_b = k * n; - let stride_c = m * n; + // Release GIL during heavy compute + let (c_result, argmax_result) = py.allow_threads(|| { + let stride_a = m * k; + let stride_b = k * n; + let stride_c = m * n; - let mut c_result = vec![0.0f32; batch * stride_c]; - let mut argmax_result = vec![0i32; batch * stride_c]; + let mut c_result = vec![0.0f32; batch * stride_c]; + let mut argmax_result = vec![0i32; batch * stride_c]; - for i in 0..batch { - let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; - let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; + for i in 0..batch { + let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; + let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); - for (j, val) in result.values.iter().enumerate() { - c_result[i * stride_c + j] = val.value(); - } - for (j, &idx) in result.argmax.iter().enumerate() { - argmax_result[i * stride_c + j] = idx as i32; + for (j, val) in result.values.iter().enumerate() { + c_result[i * stride_c + j] = val.value(); + } + for (j, &idx) in result.argmax.iter().enumerate() { + argmax_result[i * stride_c + j] = idx as i32; + } } - } + + (c_result, argmax_result) + }); Ok((c_result.into_pyarray(py), argmax_result.into_pyarray(py))) } @@ -1013,30 +1545,36 @@ fn minplus_matmul_batched_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let stride_a = m * k; - let stride_b = k * n; - let stride_c = m * n; + // Release GIL during heavy compute + let (c_result, argmax_result) = py.allow_threads(|| { + let stride_a = m * k; + let stride_b = k * n; + let stride_c = m * n; - let mut c_result = vec![0.0f32; batch * stride_c]; - let mut argmax_result = vec![0i32; batch * stride_c]; + let mut c_result = vec![0.0f32; batch * stride_c]; + let mut argmax_result = vec![0i32; batch * stride_c]; - for i in 0..batch { - let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; - let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; + for i in 0..batch { + let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; + let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); - for (j, val) in result.values.iter().enumerate() { - c_result[i * stride_c + j] = val.value(); - } - for (j, &idx) in result.argmax.iter().enumerate() { - argmax_result[i * stride_c + j] = idx as i32; + for (j, val) in result.values.iter().enumerate() { + c_result[i * stride_c + j] = val.value(); + } + for (j, &idx) in result.argmax.iter().enumerate() { + 
argmax_result[i * stride_c + j] = idx as i32; + } } - } + + (c_result, argmax_result) + }); Ok((c_result.into_pyarray(py), argmax_result.into_pyarray(py))) } @@ -1078,30 +1616,36 @@ fn maxmul_matmul_batched_with_argmax<'py>( ))); } - let a_data = a.as_slice()?; - let b_data = b.as_slice()?; + // Clone to owned data before releasing GIL + let a_data = a.as_slice()?.to_vec(); + let b_data = b.as_slice()?.to_vec(); - let stride_a = m * k; - let stride_b = k * n; - let stride_c = m * n; + // Release GIL during heavy compute + let (c_result, argmax_result) = py.allow_threads(|| { + let stride_a = m * k; + let stride_b = k * n; + let stride_c = m * n; - let mut c_result = vec![0.0f32; batch * stride_c]; - let mut argmax_result = vec![0i32; batch * stride_c]; + let mut c_result = vec![0.0f32; batch * stride_c]; + let mut argmax_result = vec![0i32; batch * stride_c]; - for i in 0..batch { - let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; - let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; + for i in 0..batch { + let a_slice = &a_data[i * stride_a..(i + 1) * stride_a]; + let b_slice = &b_data[i * stride_b..(i + 1) * stride_b]; - let result: GemmWithArgmax> = - tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); + let result: GemmWithArgmax> = + tropical_matmul_with_argmax::>(a_slice, m, k, b_slice, n); - for (j, val) in result.values.iter().enumerate() { - c_result[i * stride_c + j] = val.value(); - } - for (j, &idx) in result.argmax.iter().enumerate() { - argmax_result[i * stride_c + j] = idx as i32; + for (j, val) in result.values.iter().enumerate() { + c_result[i * stride_c + j] = val.value(); + } + for (j, &idx) in result.argmax.iter().enumerate() { + argmax_result[i * stride_c + j] = idx as i32; + } } - } + + (c_result, argmax_result) + }); Ok((c_result.into_pyarray(py), argmax_result.into_pyarray(py))) } @@ -1113,14 +1657,207 @@ fn maxmul_matmul_batched_with_argmax<'py>( #[cfg(feature = "cuda")] mod gpu { use super::*; - use dlpark::ffi::{DataTypeCode, DeviceType}; - use dlpark::ManagedTensor; - use dlpark::TensorView; + use dlpark::ffi::{DataType, DataTypeCode, Device, DeviceType}; + use dlpark::{ManagerCtx, ManagedTensor, ShapeAndStrides, ToTensor, TensorView}; + #[allow(deprecated)] + use pyo3::IntoPy; + use std::ffi::c_void; use tropical_gemm_cuda::{ - get_global_context, launch_gemm_external_with_argmax_f32, tropical_matmul_gpu, - tropical_matmul_gpu_with_argmax, ExternalGpuMatrix, + get_context_for_device, launch_gemm_external_batched_with_argmax_f32, + launch_gemm_external_with_argmax_f32, tropical_matmul_gpu, + tropical_matmul_gpu_with_argmax, ExternalGpuMatrix, ExternalGpuTensor3, + GpuMatrix, GpuTensor3, }; + // ======================================================================== + // DLPack wrapper types for GPU tensor export + // ======================================================================== + + /// Wrapper for GpuTensor3 that implements ToTensor for DLPack export. + /// + /// This wrapper owns the GPU tensor and provides the necessary metadata + /// for DLPack tensor exchange. The tensor data remains on GPU. 
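A consumer-side sketch of what these wrapper types advertise. It is illustrative only: the import path tropical_gemm._core is an assumption not shown in this diff, and it uses the batched entry point added further below. torch.from_dlpack reads exactly the shape, strides, dtype and device reported through ToTensor and wraps the Rust-owned buffer without a device-to-host copy.

    import torch
    from tropical_gemm import _core  # hypothetical import path

    a = torch.randn(2, 8, 4, device="cuda", dtype=torch.float32).contiguous()
    b = torch.randn(2, 4, 6, device="cuda", dtype=torch.float32).contiguous()
    c_cap, arg_cap = _core.maxplus_matmul_batched_dlpack(a, b)

    c = torch.from_dlpack(c_cap)
    arg = torch.from_dlpack(arg_cap)
    assert c.shape == (2, 8, 6) and c.stride() == (48, 6, 1)  # C-contiguous, as declared
    assert c.dtype == torch.float32 and arg.dtype == torch.int32
    assert c.device == a.device
    # The GPU buffers are released through the DLPack deleter once c/arg are dropped.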
+ struct DLPackGpuTensor3F32 { + tensor: GpuTensor3, + shape: [i64; 3], + device_id: i32, + } + + impl DLPackGpuTensor3F32 { + fn new(tensor: GpuTensor3, device_id: i32) -> Self { + let shape = [ + tensor.batch() as i64, + tensor.rows() as i64, + tensor.cols() as i64, + ]; + Self { tensor, shape, device_id } + } + } + + impl ToTensor for DLPackGpuTensor3F32 { + fn data_ptr(&self) -> *mut c_void { + self.tensor.device_ptr() as *mut c_void + } + + fn shape_and_strides(&self) -> ShapeAndStrides { + // Row-major (C-contiguous) strides for 3D tensor + ShapeAndStrides::new_contiguous(&self.shape) + } + + fn device(&self) -> Device { + Device { + device_type: DeviceType::Cuda, + device_id: self.device_id, + } + } + + fn dtype(&self) -> DataType { + DataType { + code: DataTypeCode::Float, + bits: 32, + lanes: 1, + } + } + + fn byte_offset(&self) -> u64 { + 0 + } + } + + /// Wrapper for GpuTensor3 (argmax indices) that implements ToTensor. + struct DLPackGpuTensor3I32 { + tensor: GpuTensor3, + shape: [i64; 3], + device_id: i32, + } + + impl DLPackGpuTensor3I32 { + fn new(tensor: GpuTensor3, device_id: i32) -> Self { + let shape = [ + tensor.batch() as i64, + tensor.rows() as i64, + tensor.cols() as i64, + ]; + Self { tensor, shape, device_id } + } + } + + impl ToTensor for DLPackGpuTensor3I32 { + fn data_ptr(&self) -> *mut c_void { + self.tensor.device_ptr() as *mut c_void + } + + fn shape_and_strides(&self) -> ShapeAndStrides { + ShapeAndStrides::new_contiguous(&self.shape) + } + + fn device(&self) -> Device { + Device { + device_type: DeviceType::Cuda, + device_id: self.device_id, + } + } + + fn dtype(&self) -> DataType { + DataType { + code: DataTypeCode::Int, + bits: 32, + lanes: 1, + } + } + + fn byte_offset(&self) -> u64 { + 0 + } + } + + /// Wrapper for GpuMatrix (2D) that implements ToTensor for DLPack export. + struct DLPackGpuMatrixF32 { + matrix: GpuMatrix, + shape: [i64; 2], + device_id: i32, + } + + impl DLPackGpuMatrixF32 { + fn new(matrix: GpuMatrix, device_id: i32) -> Self { + let shape = [matrix.rows() as i64, matrix.cols() as i64]; + Self { matrix, shape, device_id } + } + } + + impl ToTensor for DLPackGpuMatrixF32 { + fn data_ptr(&self) -> *mut c_void { + self.matrix.device_ptr() as *mut c_void + } + + fn shape_and_strides(&self) -> ShapeAndStrides { + // Row-major (C-contiguous) strides for 2D matrix + ShapeAndStrides::new_contiguous(&self.shape) + } + + fn device(&self) -> Device { + Device { + device_type: DeviceType::Cuda, + device_id: self.device_id, + } + } + + fn dtype(&self) -> DataType { + DataType { + code: DataTypeCode::Float, + bits: 32, + lanes: 1, + } + } + + fn byte_offset(&self) -> u64 { + 0 + } + } + + /// Wrapper for GpuMatrix (2D argmax) that implements ToTensor for DLPack export. 
+ struct DLPackGpuMatrixI32 { + matrix: GpuMatrix, + shape: [i64; 2], + device_id: i32, + } + + impl DLPackGpuMatrixI32 { + fn new(matrix: GpuMatrix, device_id: i32) -> Self { + let shape = [matrix.rows() as i64, matrix.cols() as i64]; + Self { matrix, shape, device_id } + } + } + + impl ToTensor for DLPackGpuMatrixI32 { + fn data_ptr(&self) -> *mut c_void { + self.matrix.device_ptr() as *mut c_void + } + + fn shape_and_strides(&self) -> ShapeAndStrides { + ShapeAndStrides::new_contiguous(&self.shape) + } + + fn device(&self) -> Device { + Device { + device_type: DeviceType::Cuda, + device_id: self.device_id, + } + } + + fn dtype(&self) -> DataType { + DataType { + code: DataTypeCode::Int, + bits: 32, + lanes: 1, + } + } + + fn byte_offset(&self) -> u64 { + 0 + } + } + /// Helper function to extract ManagedTensor from a Python object. /// Calls __dlpack__() if available, otherwise tries direct extraction. fn extract_dlpack_tensor(_py: Python, obj: &Bound<'_, pyo3::PyAny>) -> PyResult { @@ -1343,37 +2080,84 @@ mod gpu { } // ======================================================================== - // DLPack zero-copy functions + // DLPack zero-copy functions (2D tensors) // ======================================================================== /// MaxPlus matrix multiplication using DLPack for zero-copy GPU tensor exchange. /// /// This function accepts PyTorch tensors (or any DLPack-compatible tensor) directly - /// and performs the computation without copying input data for GPU tensors. + /// and performs the computation without copying data for GPU tensors. /// /// Args: - /// a: Input tensor A of shape (M, K) - must support __dlpack__() - /// b: Input tensor B of shape (K, N) - must support __dlpack__() + /// a: Input tensor A of shape (M, K) - must support __dlpack__(), f32 + /// b: Input tensor B of shape (K, N) - must support __dlpack__(), f32 /// /// Returns: - /// Tuple of (C, argmax) as numpy arrays where: - /// - C: Result matrix of shape (M, N) as flattened f32 array - /// - argmax: Indices of shape (M, N) as flattened i32 array + /// Tuple of (C, argmax) where the type depends on input device: /// - /// Note: - /// - For GPU tensors: Uses zero-copy DLPack interface with Rust CUDA backend - /// - For CPU tensors: Falls back to optimized Rust CPU backend + /// For CUDA tensors: Returns DLPack capsules (data stays on GPU) + /// - C: Result of shape (M*N,) as f32 - use `torch.from_dlpack(c).reshape(m, n)` + /// - argmax: Indices of shape (M*N,) as i32 - use `torch.from_dlpack(argmax).reshape(m, n)` + /// + /// For CPU tensors: Returns numpy arrays (flattened) + /// - C: Result of shape (M*N,) as f32 - use `torch.from_numpy(c).reshape(m, n)` + /// - argmax: Indices of shape (M*N,) as i32 - use `torch.from_numpy(argmax).reshape(m, n)` #[pyfunction] - pub fn maxplus_matmul_dlpack<'py>( - py: Python<'py>, - a: Bound<'py, pyo3::PyAny>, - b: Bound<'py, pyo3::PyAny>, - ) -> PyResult<(Bound<'py, PyArray1>, Bound<'py, PyArray1>)> { + pub fn maxplus_matmul_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + dlpack_2d_impl(py, a, b, "tropical_maxplus_f32_nn_with_argmax", Algebra::MaxPlus) + } + + /// MinPlus matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// + /// Returns DLPack capsules - use `torch.from_dlpack(capsule)` to convert. 
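A usage sketch for these 2D DLPack entry points, assuming the compiled extension is importable as tropical_gemm._core (the package path is not shown in this diff). The reshape follows the guidance in the MaxPlus docstring above; CPU inputs take the fallback path and come back as flattened numpy arrays.

    import torch
    from tropical_gemm import _core  # hypothetical import path

    m, k, n = 128, 64, 96
    a = torch.randn(m, k, device="cuda", dtype=torch.float32).contiguous()
    b = torch.randn(k, n, device="cuda", dtype=torch.float32).contiguous()

    # CUDA inputs -> DLPack capsules; the results stay on the GPU.
    c_cap, arg_cap = _core.maxplus_matmul_dlpack(a, b)
    c = torch.from_dlpack(c_cap).reshape(m, n)         # float32, cuda
    argmax = torch.from_dlpack(arg_cap).reshape(m, n)  # int32, cuda

    # CPU inputs fall back to the Rust CPU backend and return numpy arrays.
    c_np, arg_np = _core.maxplus_matmul_dlpack(a.cpu(), b.cpu())
    c_host = torch.from_numpy(c_np).reshape(m, n)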
+ #[pyfunction] + pub fn minplus_matmul_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + dlpack_2d_impl(py, a, b, "tropical_minplus_f32_nn_with_argmax", Algebra::MinPlus) + } + + /// MaxMul matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// + /// Returns DLPack capsules - use `torch.from_dlpack(capsule)` to convert. + #[pyfunction] + pub fn maxmul_matmul_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + dlpack_2d_impl(py, a, b, "tropical_maxmul_f32_nn_with_argmax", Algebra::MaxMul) + } + + /// Algebra type for CPU dispatch in 2D DLPack functions. + enum Algebra { + MaxPlus, + MinPlus, + MaxMul, + } + + /// Implementation for 2D DLPack operations. + /// + /// Returns DLPack capsules that keep data on GPU - use `torch.from_dlpack()` + /// to convert to PyTorch tensors. + fn dlpack_2d_impl( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + kernel_name: &'static str, + algebra: Algebra, + ) -> PyResult<(PyObject, PyObject)> { // Extract tensor info from DLPack let a_tensor = extract_dlpack_tensor(py, &a)?; let b_tensor = extract_dlpack_tensor(py, &b)?; - // Get device info (ManagedTensor implements TensorView trait) + // Get device info let a_device = TensorView::device(&a_tensor); let b_device = TensorView::device(&b_tensor); @@ -1393,6 +2177,8 @@ mod gpu { ))); } + let device_id = a_device.device_id; + // Get dtype and validate let a_dtype = TensorView::dtype(&a_tensor); let b_dtype = TensorView::dtype(&b_tensor); @@ -1434,6 +2220,14 @@ mod gpu { ))); } + // Guard against zero-sized dimensions + if m == 0 || k == 0 || n == 0 { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "Zero-sized dimensions not supported: m={}, k={}, n={}", + m, k, n + ))); + } + // Check strides for contiguity let a_strides = TensorView::strides(&a_tensor); let b_strides = TensorView::strides(&b_tensor); @@ -1451,7 +2245,7 @@ mod gpu { } match a_device.device_type { - DeviceType::Cuda | DeviceType::CudaHost => { + DeviceType::Cuda => { // GPU path: zero-copy using DLPack let a_ptr = TensorView::data_ptr(&a_tensor) as u64; let b_ptr = TensorView::data_ptr(&b_tensor) as u64; @@ -1460,54 +2254,81 @@ mod gpu { let a_ext = unsafe { ExternalGpuMatrix::from_raw(a_ptr, m, k) }; let b_ext = unsafe { ExternalGpuMatrix::from_raw(b_ptr, k, n) }; - // Get the global CUDA context - let ctx = get_global_context().map_err(|e| { + // Get CUDA context for the input device + let ctx = get_context_for_device(device_id as usize).map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA error: {}", e)) })?; // Launch kernel let result = unsafe { - launch_gemm_external_with_argmax_f32( - ctx, - "tropical_maxplus_f32_nn_with_argmax", - &a_ext, - &b_ext, - m, - k, - n, - ) + launch_gemm_external_with_argmax_f32(ctx, kernel_name, &a_ext, &b_ext, m, k, n) } .map_err(|e| { pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA kernel error: {}", e)) })?; - // Download results to host (column-major layout) - let c_data = result.matrix_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; - let argmax_data = result.argmax_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; + // Split result into matrix and argmax, then wrap for DLPack export + let (c_matrix, argmax_matrix) = result.into_parts(); + + 
// Wrap in DLPack-compatible wrappers with correct device_id + let c_dlpack = DLPackGpuMatrixF32::new(c_matrix, device_id); + let argmax_dlpack = DLPackGpuMatrixI32::new(argmax_matrix, device_id); + + // Convert to DLPack capsules using ManagerCtx + let c_capsule = ManagerCtx::new(c_dlpack).into_py(py); + let argmax_capsule = ManagerCtx::new(argmax_dlpack).into_py(py); - Ok((c_data.into_pyarray(py), argmax_data.into_pyarray(py))) + Ok((c_capsule, argmax_capsule)) + } + DeviceType::CudaHost => { + // CudaHost (pinned memory) is not supported - the pointer is a host pointer + // that cannot be used directly by CUDA kernels without explicit handling + Err(pyo3::exceptions::PyValueError::new_err( + "CudaHost (pinned memory) tensors are not supported. Use regular CUDA tensors.", + )) } DeviceType::Cpu => { - // CPU path: use existing CPU backend + // CPU path: use existing CPU backend, return numpy arrays as PyObject let a_ptr = TensorView::data_ptr(&a_tensor) as *const f32; let b_ptr = TensorView::data_ptr(&b_tensor) as *const f32; let a_data = unsafe { std::slice::from_raw_parts(a_ptr, m * k) }; let b_data = unsafe { std::slice::from_raw_parts(b_ptr, k * n) }; - let result: ::tropical_gemm::GemmWithArgmax> = - ::tropical_gemm::tropical_matmul_with_argmax::>( - a_data, m, k, b_data, n, - ); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); - - Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) + let (c_scalars, argmax_i32) = match algebra { + Algebra::MaxPlus => { + let result: ::tropical_gemm::GemmWithArgmax> = + ::tropical_gemm::tropical_matmul_with_argmax::>( + a_data, m, k, b_data, n, + ); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + } + Algebra::MinPlus => { + let result: ::tropical_gemm::GemmWithArgmax> = + ::tropical_gemm::tropical_matmul_with_argmax::>( + a_data, m, k, b_data, n, + ); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + } + Algebra::MaxMul => { + let result: ::tropical_gemm::GemmWithArgmax> = + ::tropical_gemm::tropical_matmul_with_argmax::>( + a_data, m, k, b_data, n, + ); + let c: Vec = result.values.iter().map(|x| x.value()).collect(); + let argmax: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + (c, argmax) + } + }; + + Ok(( + c_scalars.into_pyarray(py).into_any().unbind(), + argmax_i32.into_pyarray(py).into_any().unbind(), + )) } _ => Err(pyo3::exceptions::PyValueError::new_err(format!( "Unsupported device type: {:?}", @@ -1516,13 +2337,86 @@ mod gpu { } } - /// MinPlus matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// Check if CUDA is available. #[pyfunction] - pub fn minplus_matmul_dlpack<'py>( - py: Python<'py>, - a: Bound<'py, pyo3::PyAny>, - b: Bound<'py, pyo3::PyAny>, - ) -> PyResult<(Bound<'py, PyArray1>, Bound<'py, PyArray1>)> { + pub fn cuda_available() -> bool { + true + } + + // ======================================================================== + // Batched DLPack functions (3D tensors) + // ======================================================================== + // + // Note: Only f32 with argmax is supported for batched GPU operations. + // f64/i32/i64 and non-argmax variants are not implemented. + // For other dtypes, use the CPU batched API. 
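A usage sketch for the batched entry points defined just below. It assumes the compiled extension is importable as tropical_gemm._core (the package path is not shown in this diff) and the row-major (batch, M, N) layout declared by the DLPack wrappers above. The MinPlus variant is shown; MaxPlus and MaxMul follow the same pattern.

    import torch
    from tropical_gemm import _core  # hypothetical import path

    batch, m, k, n = 4, 64, 32, 48
    a = torch.randn(batch, m, k, device="cuda", dtype=torch.float32).contiguous()
    b = torch.randn(batch, k, n, device="cuda", dtype=torch.float32).contiguous()

    c_cap, arg_cap = _core.minplus_matmul_batched_dlpack(a, b)
    c = torch.from_dlpack(c_cap)      # (batch, m, n) float32, stays on the GPU
    arg = torch.from_dlpack(arg_cap)  # (batch, m, n) int32, the minimizing k per entry

    # Pure-torch sanity check: C[b,i,j] = min_k (A[b,i,k] + B[b,k,j])
    ref = (a.unsqueeze(-1) + b.unsqueeze(1)).amin(dim=2)
    assert torch.allclose(c, ref, atol=1e-5)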
+ + /// Batched MaxPlus matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// + /// Computes C[b,i,j] = max_k(A[b,i,k] + B[b,k,j]) for each batch b. + /// + /// # Limitations + /// + /// - Only f32 dtype is supported (GPU batched) + /// - Always returns argmax (no non-argmax variant) + /// + /// Args: + /// a: Input tensor A of shape (batch, M, K) - must support __dlpack__(), f32, CUDA + /// b: Input tensor B of shape (batch, K, N) - must support __dlpack__(), f32, CUDA + /// + /// Returns: + /// Tuple of (C, argmax) as DLPack capsules where: + /// - C: Result tensor of shape (batch, M, N) as f32 CUDA tensor + /// - argmax: Indices of shape (batch, M, N) as i32 CUDA tensor + /// + /// Use `torch.from_dlpack(capsule)` to convert to PyTorch tensors. + /// + /// Raises: + /// RuntimeError: If tensors are not on CUDA or DLPack extraction fails + /// ValueError: If tensors are not f32 or not 3D + #[pyfunction] + pub fn maxplus_matmul_batched_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + batched_dlpack_impl(py, a, b, "tropical_maxplus_f32_nn_batched_with_argmax") + } + + /// Batched MinPlus matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// + /// Returns DLPack capsules - use `torch.from_dlpack(capsule)` to convert. + #[pyfunction] + pub fn minplus_matmul_batched_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + batched_dlpack_impl(py, a, b, "tropical_minplus_f32_nn_batched_with_argmax") + } + + /// Batched MaxMul matrix multiplication using DLPack for zero-copy GPU tensor exchange. + /// + /// Returns DLPack capsules - use `torch.from_dlpack(capsule)` to convert. + #[pyfunction] + pub fn maxmul_matmul_batched_dlpack( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + ) -> PyResult<(PyObject, PyObject)> { + batched_dlpack_impl(py, a, b, "tropical_maxmul_f32_nn_batched_with_argmax") + } + + /// Implementation for batched DLPack operations. + /// + /// Returns DLPack capsules that keep data on GPU - use `torch.from_dlpack()` + /// to convert to PyTorch tensors. + fn batched_dlpack_impl( + py: Python<'_>, + a: Bound<'_, pyo3::PyAny>, + b: Bound<'_, pyo3::PyAny>, + kernel_name: &'static str, + ) -> PyResult<(PyObject, PyObject)> { // Extract tensor info from DLPack let a_tensor = extract_dlpack_tensor(py, &a)?; let b_tensor = extract_dlpack_tensor(py, &b)?; @@ -1531,21 +2425,31 @@ mod gpu { let a_device = TensorView::device(&a_tensor); let b_device = TensorView::device(&b_tensor); - if a_device.device_type != b_device.device_type { - return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Tensors must be on the same device type: A is on {:?}, B is on {:?}", - a_device.device_type, b_device.device_type + // Validate: must be on CUDA device (not CudaHost/pinned memory) + if a_device.device_type != DeviceType::Cuda { + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Tensor A must be on CUDA device, got {:?}. Use CPU batched functions for CPU tensors.", + a_device.device_type + ))); + } + + if b_device.device_type != DeviceType::Cuda { + return Err(pyo3::exceptions::PyRuntimeError::new_err(format!( + "Tensor B must be on CUDA device, got {:?}. 
Use CPU batched functions for CPU tensors.", + b_device.device_type ))); } - // Validate: both tensors must be on the same device ID (for multi-GPU) if a_device.device_id != b_device.device_id { return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Tensors must be on the same device: A is on device {}, B is on device {}", + "Tensors must be on the same CUDA device: A is on cuda:{}, B is on cuda:{}", a_device.device_id, b_device.device_id ))); } + let device_id = a_device.device_id; + + // Get dtype and validate let a_dtype = TensorView::dtype(&a_tensor); let b_dtype = TensorView::dtype(&b_tensor); if a_dtype != b_dtype { @@ -1557,180 +2461,64 @@ mod gpu { if a_dtype.code != DataTypeCode::Float || a_dtype.bits != 32 { return Err(pyo3::exceptions::PyValueError::new_err( - "Only f32 tensors are supported for DLPack interface", + "Only f32 tensors are supported for batched DLPack interface", )); } + // Get shapes - must be 3D let a_shape = TensorView::shape(&a_tensor); let b_shape = TensorView::shape(&b_tensor); - if a_shape.len() != 2 || b_shape.len() != 2 { + if a_shape.len() != 3 || b_shape.len() != 3 { return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Expected 2D tensors, got A with {} dims, B with {} dims", + "Expected 3D tensors, got A with {} dims, B with {} dims", a_shape.len(), b_shape.len() ))); } - let m = a_shape[0] as usize; - let k = a_shape[1] as usize; - let k2 = b_shape[0] as usize; - let n = b_shape[1] as usize; - - if k != k2 { - return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Dimension mismatch: A is {}x{}, B is {}x{}", - m, k, k2, n - ))); - } - - let a_strides = TensorView::strides(&a_tensor); - let b_strides = TensorView::strides(&b_tensor); - - let a_contiguous = a_strides.is_none() - || a_strides.map_or(false, |s| s.len() == 2 && s[1] == 1 && s[0] == k as i64); - let b_contiguous = b_strides.is_none() - || b_strides.map_or(false, |s| s.len() == 2 && s[1] == 1 && s[0] == n as i64); - - if !a_contiguous || !b_contiguous { - return Err(pyo3::exceptions::PyValueError::new_err( - "Tensors must be contiguous (call .contiguous() on PyTorch tensors)", - )); - } - - match a_device.device_type { - DeviceType::Cuda | DeviceType::CudaHost => { - let a_ptr = TensorView::data_ptr(&a_tensor) as u64; - let b_ptr = TensorView::data_ptr(&b_tensor) as u64; - - let a_ext = unsafe { ExternalGpuMatrix::from_raw(a_ptr, m, k) }; - let b_ext = unsafe { ExternalGpuMatrix::from_raw(b_ptr, k, n) }; - - let ctx = get_global_context().map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA error: {}", e)) - })?; - - let result = unsafe { - launch_gemm_external_with_argmax_f32( - ctx, - "tropical_minplus_f32_nn_with_argmax", - &a_ext, - &b_ext, - m, - k, - n, - ) - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA kernel error: {}", e)) - })?; - - let c_data = result.matrix_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; - let argmax_data = result.argmax_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; - - Ok((c_data.into_pyarray(py), argmax_data.into_pyarray(py))) - } - DeviceType::Cpu => { - let a_ptr = TensorView::data_ptr(&a_tensor) as *const f32; - let b_ptr = TensorView::data_ptr(&b_tensor) as *const f32; - - let a_data = unsafe { std::slice::from_raw_parts(a_ptr, m * k) }; - let b_data = unsafe { std::slice::from_raw_parts(b_ptr, k * n) }; - - let result: ::tropical_gemm::GemmWithArgmax> = 
- ::tropical_gemm::tropical_matmul_with_argmax::>( - a_data, m, k, b_data, n, - ); - - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); - - Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) - } - _ => Err(pyo3::exceptions::PyValueError::new_err(format!( - "Unsupported device type: {:?}", - a_device.device_type - ))), - } - } - - /// MaxMul matrix multiplication using DLPack for zero-copy GPU tensor exchange. - #[pyfunction] - pub fn maxmul_matmul_dlpack<'py>( - py: Python<'py>, - a: Bound<'py, pyo3::PyAny>, - b: Bound<'py, pyo3::PyAny>, - ) -> PyResult<(Bound<'py, PyArray1>, Bound<'py, PyArray1>)> { - let a_tensor = extract_dlpack_tensor(py, &a)?; - let b_tensor = extract_dlpack_tensor(py, &b)?; - - let a_device = TensorView::device(&a_tensor); - let b_device = TensorView::device(&b_tensor); - - if a_device.device_type != b_device.device_type { - return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Tensors must be on the same device type: A is on {:?}, B is on {:?}", - a_device.device_type, b_device.device_type - ))); - } - - // Validate: both tensors must be on the same device ID (for multi-GPU) - if a_device.device_id != b_device.device_id { - return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Tensors must be on the same device: A is on device {}, B is on device {}", - a_device.device_id, b_device.device_id - ))); - } + let batch = a_shape[0] as usize; + let m = a_shape[1] as usize; + let k = a_shape[2] as usize; + let batch_b = b_shape[0] as usize; + let k2 = b_shape[1] as usize; + let n = b_shape[2] as usize; - let a_dtype = TensorView::dtype(&a_tensor); - let b_dtype = TensorView::dtype(&b_tensor); - if a_dtype != b_dtype { + if batch != batch_b { return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Tensors must have the same dtype: A is {:?}, B is {:?}", - a_dtype, b_dtype + "Batch size mismatch: A has batch {}, B has batch {}", + batch, batch_b ))); } - if a_dtype.code != DataTypeCode::Float || a_dtype.bits != 32 { - return Err(pyo3::exceptions::PyValueError::new_err( - "Only f32 tensors are supported for DLPack interface", - )); - } - - let a_shape = TensorView::shape(&a_tensor); - let b_shape = TensorView::shape(&b_tensor); - - if a_shape.len() != 2 || b_shape.len() != 2 { + if k != k2 { return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Expected 2D tensors, got A with {} dims, B with {} dims", - a_shape.len(), - b_shape.len() + "Dimension mismatch: A is (batch, {}, {}), B is (batch, {}, {})", + m, k, k2, n ))); } - let m = a_shape[0] as usize; - let k = a_shape[1] as usize; - let k2 = b_shape[0] as usize; - let n = b_shape[1] as usize; - - if k != k2 { + // Guard against zero-sized dimensions (would cause invalid CUDA launch) + if batch == 0 || m == 0 || k == 0 || n == 0 { return Err(pyo3::exceptions::PyValueError::new_err(format!( - "Dimension mismatch: A is {}x{}, B is {}x{}", - m, k, k2, n + "Zero-sized dimensions not supported: batch={}, m={}, k={}, n={}", + batch, m, k, n ))); } + // Check strides for contiguity (row-major per batch) let a_strides = TensorView::strides(&a_tensor); let b_strides = TensorView::strides(&b_tensor); + // For 3D row-major (C-contiguous): strides should be [m*k, k, 1] let a_contiguous = a_strides.is_none() - || a_strides.map_or(false, |s| s.len() == 2 && s[1] == 1 && s[0] == k as i64); + || a_strides.map_or(false, |s| { + s.len() == 3 && s[2] == 1 && s[1] == k as i64 && s[0] == (m * k) as i64 + 
}); let b_contiguous = b_strides.is_none() - || b_strides.map_or(false, |s| s.len() == 2 && s[1] == 1 && s[0] == n as i64); + || b_strides.map_or(false, |s| { + s.len() == 3 && s[2] == 1 && s[1] == n as i64 && s[0] == (k * n) as i64 + }); if !a_contiguous || !b_contiguous { return Err(pyo3::exceptions::PyValueError::new_err( @@ -1738,70 +2526,40 @@ mod gpu { )); } - match a_device.device_type { - DeviceType::Cuda | DeviceType::CudaHost => { - let a_ptr = TensorView::data_ptr(&a_tensor) as u64; - let b_ptr = TensorView::data_ptr(&b_tensor) as u64; - - let a_ext = unsafe { ExternalGpuMatrix::from_raw(a_ptr, m, k) }; - let b_ext = unsafe { ExternalGpuMatrix::from_raw(b_ptr, k, n) }; - - let ctx = get_global_context().map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA error: {}", e)) - })?; - - let result = unsafe { - launch_gemm_external_with_argmax_f32( - ctx, - "tropical_maxmul_f32_nn_with_argmax", - &a_ext, - &b_ext, - m, - k, - n, - ) - } - .map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA kernel error: {}", e)) - })?; + // GPU path: zero-copy using DLPack + let a_ptr = TensorView::data_ptr(&a_tensor) as u64; + let b_ptr = TensorView::data_ptr(&b_tensor) as u64; - let c_data = result.matrix_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; - let argmax_data = result.argmax_to_host(ctx).map_err(|e| { - pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA D2H error: {}", e)) - })?; + // Create external 3D tensor views + let a_ext = unsafe { ExternalGpuTensor3::from_raw_contiguous(a_ptr, batch, m, k) }; + let b_ext = unsafe { ExternalGpuTensor3::from_raw_contiguous(b_ptr, batch, k, n) }; - Ok((c_data.into_pyarray(py), argmax_data.into_pyarray(py))) - } - DeviceType::Cpu => { - let a_ptr = TensorView::data_ptr(&a_tensor) as *const f32; - let b_ptr = TensorView::data_ptr(&b_tensor) as *const f32; + // Get CUDA context for the input device + let ctx = get_context_for_device(device_id as usize).map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA error: {}", e)) + })?; - let a_data = unsafe { std::slice::from_raw_parts(a_ptr, m * k) }; - let b_data = unsafe { std::slice::from_raw_parts(b_ptr, k * n) }; + // Launch batched kernel + let result = unsafe { + launch_gemm_external_batched_with_argmax_f32(ctx, kernel_name, &a_ext, &b_ext, batch, m, k, n) + } + .map_err(|e| { + pyo3::exceptions::PyRuntimeError::new_err(format!("CUDA kernel error: {}", e)) + })?; - let result: ::tropical_gemm::GemmWithArgmax> = - ::tropical_gemm::tropical_matmul_with_argmax::>( - a_data, m, k, b_data, n, - ); + // Split result into tensor and argmax, then wrap for DLPack export + let (c_tensor, argmax_tensor) = result.into_parts(); - let c_scalars: Vec = result.values.iter().map(|x| x.value()).collect(); - let argmax_i32: Vec = result.argmax.iter().map(|&x| x as i32).collect(); + // Wrap in DLPack-compatible wrappers with correct device_id + let c_dlpack = DLPackGpuTensor3F32::new(c_tensor, device_id); + let argmax_dlpack = DLPackGpuTensor3I32::new(argmax_tensor, device_id); - Ok((c_scalars.into_pyarray(py), argmax_i32.into_pyarray(py))) - } - _ => Err(pyo3::exceptions::PyValueError::new_err(format!( - "Unsupported device type: {:?}", - a_device.device_type - ))), - } - } + // Convert to DLPack capsules using ManagerCtx + // ManagerCtx owns the tensor and exports it as a DLPack capsule + let c_capsule = ManagerCtx::new(c_dlpack).into_py(py); + let argmax_capsule = 
            ManagerCtx::new(argmax_dlpack).into_py(py);
 
-    /// Check if CUDA is available.
-    #[pyfunction]
-    pub fn cuda_available() -> bool {
-        true
+        Ok((c_capsule, argmax_capsule))
     }
 
     /// Register GPU functions in the module.
@@ -1816,6 +2574,10 @@ mod gpu {
         m.add_function(wrap_pyfunction!(maxplus_matmul_dlpack, m)?)?;
         m.add_function(wrap_pyfunction!(minplus_matmul_dlpack, m)?)?;
         m.add_function(wrap_pyfunction!(maxmul_matmul_dlpack, m)?)?;
+        // Batched DLPack functions
+        m.add_function(wrap_pyfunction!(maxplus_matmul_batched_dlpack, m)?)?;
+        m.add_function(wrap_pyfunction!(minplus_matmul_batched_dlpack, m)?)?;
+        m.add_function(wrap_pyfunction!(maxmul_matmul_batched_dlpack, m)?)?;
         m.add_function(wrap_pyfunction!(cuda_available, m)?)?;
         Ok(())
     }
@@ -1880,6 +2642,26 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(minplus_matmul_i64, m)?)?;
     m.add_function(wrap_pyfunction!(maxmul_matmul_i64, m)?)?;
 
+    // 2D output variants (f32)
+    m.add_function(wrap_pyfunction!(maxplus_matmul_2d, m)?)?;
+    m.add_function(wrap_pyfunction!(minplus_matmul_2d, m)?)?;
+    m.add_function(wrap_pyfunction!(maxmul_matmul_2d, m)?)?;
+
+    // 2D output variants (f64)
+    m.add_function(wrap_pyfunction!(maxplus_matmul_2d_f64, m)?)?;
+    m.add_function(wrap_pyfunction!(minplus_matmul_2d_f64, m)?)?;
+    m.add_function(wrap_pyfunction!(maxmul_matmul_2d_f64, m)?)?;
+
+    // 2D output variants (i32)
+    m.add_function(wrap_pyfunction!(maxplus_matmul_2d_i32, m)?)?;
+    m.add_function(wrap_pyfunction!(minplus_matmul_2d_i32, m)?)?;
+    m.add_function(wrap_pyfunction!(maxmul_matmul_2d_i32, m)?)?;
+
+    // 2D output variants (i64)
+    m.add_function(wrap_pyfunction!(maxplus_matmul_2d_i64, m)?)?;
+    m.add_function(wrap_pyfunction!(minplus_matmul_2d_i64, m)?)?;
+    m.add_function(wrap_pyfunction!(maxmul_matmul_2d_i64, m)?)?;
+
     // GPU operations (if available)
     gpu::register(m)?;
 
diff --git a/crates/tropical-gemm-python/tests/test_pytorch_gradients.py b/crates/tropical-gemm-python/tests/test_pytorch_gradients.py
index c796467..e9d4669 100644
--- a/crates/tropical-gemm-python/tests/test_pytorch_gradients.py
+++ b/crates/tropical-gemm-python/tests/test_pytorch_gradients.py
@@ -472,13 +472,15 @@ def test_gpu_maxplus_forward():
     """Test GPU MaxPlus forward pass matches CPU."""
     torch.manual_seed(42)
 
-    a = torch.randn(4, 3)
-    b = torch.randn(3, 5)
+    a_cpu = torch.randn(4, 3)
+    b_cpu = torch.randn(3, 5)
+    a_gpu = a_cpu.cuda()
+    b_gpu = b_cpu.cuda()
 
-    c_cpu = tropical_maxplus_matmul(a, b)
-    c_gpu = tropical_maxplus_matmul_gpu(a, b)
+    c_cpu = tropical_maxplus_matmul(a_cpu, b_cpu)
+    c_gpu = tropical_maxplus_matmul_gpu(a_gpu, b_gpu)
 
-    assert torch.allclose(c_cpu, c_gpu, atol=1e-4), f"GPU result differs from CPU"
+    assert torch.allclose(c_cpu, c_gpu.cpu(), atol=1e-4), "GPU result differs from CPU"
 
 
 @pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
@@ -486,13 +488,15 @@ def test_gpu_minplus_forward():
     """Test GPU MinPlus forward pass matches CPU."""
     torch.manual_seed(42)
 
-    a = torch.randn(4, 3)
-    b = torch.randn(3, 5)
+    a_cpu = torch.randn(4, 3)
+    b_cpu = torch.randn(3, 5)
+    a_gpu = a_cpu.cuda()
+    b_gpu = b_cpu.cuda()
 
-    c_cpu = tropical_minplus_matmul(a, b)
-    c_gpu = tropical_minplus_matmul_gpu(a, b)
+    c_cpu = tropical_minplus_matmul(a_cpu, b_cpu)
+    c_gpu = tropical_minplus_matmul_gpu(a_gpu, b_gpu)
 
-    assert torch.allclose(c_cpu, c_gpu, atol=1e-4), f"GPU result differs from CPU"
+    assert torch.allclose(c_cpu, c_gpu.cpu(), atol=1e-4), "GPU result differs from CPU"
 
 
 @pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
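The batched entry points registered above (`maxplus_matmul_batched_dlpack`, `minplus_matmul_batched_dlpack`, `maxmul_matmul_batched_dlpack`) accept contiguous f32 CUDA tensors of shape `(batch, m, k)` and `(batch, k, n)` and return a pair of DLPack capsules, one for the values and one for the argmax, whose data stays on the GPU. A minimal usage sketch from PyTorch, assuming the capsules expose `(batch, m, n)` data; the defensive `.reshape(...)` mirrors the 2D DLPack test below and is not part of the API:

```python
import numpy as np
import torch
import tropical_gemm

batch, m, k, n = 8, 100, 50, 80
a = torch.randn(batch, m, k, dtype=torch.float32, device="cuda").contiguous()
b = torch.randn(batch, k, n, dtype=torch.float32, device="cuda").contiguous()

# Zero-copy: the Rust side reads A and B through DLPack and returns DLPack
# capsules for the result values (f32) and the argmax indices (i32).
c_capsule, argmax_capsule = tropical_gemm.maxplus_matmul_batched_dlpack(a, b)

c = torch.from_dlpack(c_capsule).reshape(batch, m, n)            # stays on GPU
argmax = torch.from_dlpack(argmax_capsule).reshape(batch, m, n)
assert c.is_cuda and argmax.is_cuda

# Sanity-check one batch element against the CPU reference path.
c_ref_flat, _ = tropical_gemm.maxplus_matmul_with_argmax(a[0].cpu().numpy(), b[0].cpu().numpy())
c_ref = torch.from_numpy(np.asarray(c_ref_flat).reshape(m, n))
assert torch.allclose(c[0].cpu(), c_ref, atol=1e-4)
```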
@@ -500,8 +504,8 @@ def test_gpu_maxplus_gradient():
     """Test GPU MaxPlus backward pass."""
     torch.manual_seed(42)
 
-    a = torch.tensor([[1.0, 10.0], [5.0, 2.0]], requires_grad=True)
-    b = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
+    a = torch.tensor([[1.0, 10.0], [5.0, 2.0]], device="cuda", requires_grad=True)
+    b = torch.tensor([[1.0, 2.0], [3.0, 4.0]], device="cuda", requires_grad=True)
 
     c = tropical_maxplus_matmul_gpu(a, b)
     loss = c.sum()
@@ -519,9 +523,9 @@ def test_gpu_optimization():
     """Test GPU optimization convergence."""
     torch.manual_seed(42)
 
-    a = torch.randn(3, 4, requires_grad=True)
-    b = torch.randn(4, 3, requires_grad=True)
-    target = torch.randn(3, 3)
+    a = torch.randn(3, 4, device="cuda", requires_grad=True)
+    b = torch.randn(4, 3, device="cuda", requires_grad=True)
+    target = torch.randn(3, 3, device="cuda")
 
     optimizer = torch.optim.Adam([a, b], lr=0.1)
 
@@ -678,8 +682,12 @@ def test_dlpack_gpu_tensor_zero_copy():
     b = torch.randn(50, 80, dtype=torch.float32, device='cuda')
 
     # This should use the zero-copy DLPack path with Rust CUDA backend
-    c_flat, _ = tropical_gemm.maxplus_matmul_dlpack(a, b)
-    c = torch.from_numpy(np.array(c_flat).reshape(100, 80))
+    # Returns DLPack capsules - data stays on GPU
+    c_capsule, _ = tropical_gemm.maxplus_matmul_dlpack(a, b)
+    c = torch.from_dlpack(c_capsule).reshape(100, 80)
+
+    # Verify result is on GPU
+    assert c.is_cuda, "Result should be on GPU"
 
     # Verify result matches CPU reference
     a_cpu = a.cpu()
@@ -687,7 +695,7 @@
     c_ref_flat, _ = tropical_gemm.maxplus_matmul_with_argmax(a_cpu.numpy(), b_cpu.numpy())
     c_ref = torch.from_numpy(np.array(c_ref_flat).reshape(100, 80))
 
-    assert torch.allclose(c, c_ref, atol=1e-5), "GPU DLPack path should match CPU reference"
+    assert torch.allclose(c.cpu(), c_ref, atol=1e-5), "GPU DLPack path should match CPU reference"
 
 
 @pytest.mark.skipif(not GPU_AVAILABLE, reason="CUDA not available")
diff --git a/crates/tropical-gemm/src/core/gemm.rs b/crates/tropical-gemm/src/core/gemm.rs
index 0d4c69d..51cf8da 100644
--- a/crates/tropical-gemm/src/core/gemm.rs
+++ b/crates/tropical-gemm/src/core/gemm.rs
@@ -71,7 +71,10 @@ pub unsafe fn tropical_gemm_inner>(
         return;
     }
 
-    // Allocate packing buffers
+    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
+    // For repeated GEMM calls, consider adding a workspace-based API:
+    //   pub struct GemmWorkspace { packed_a: Vec<T::Scalar>, packed_b: Vec<T::Scalar> }
+    //   pub fn tropical_gemm_with_workspace(..., workspace: &mut GemmWorkspace)
     let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
     let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
 
@@ -185,7 +188,7 @@ pub unsafe fn tropical_gemm_with_argmax_inner<
     let ldc = result.ld;
     let (c, argmax) = result.as_mut_ptrs();
 
-    // Allocate packing buffers
+    // TODO(#34): Avoid repeated allocation by accepting caller-provided workspace.
     let mut packed_a = vec![T::Scalar::scalar_zero(); packed_a_size(params.mc, params.kc, K::MR)];
     let mut packed_b = vec![T::Scalar::scalar_zero(); packed_b_size(params.kc, params.nc, K::NR)];
 
diff --git a/crates/tropical-gemm/src/mat/owned.rs b/crates/tropical-gemm/src/mat/owned.rs
index 9592c68..0ea057e 100644
--- a/crates/tropical-gemm/src/mat/owned.rs
+++ b/crates/tropical-gemm/src/mat/owned.rs
@@ -93,7 +93,12 @@ impl Mat {
     /// Create a matrix from row-major scalar data.
    ///
    /// This is a convenience method that converts row-major input to column-major storage.
-    #[deprecated(since = "0.4.0", note = "use from_col_major instead for direct column-major input")]
+    ///
+    /// # Performance Warning
+    ///
+    /// This method performs an O(m×n) transpose operation. For performance-critical code,
+    /// provide data in column-major order and use [`from_col_major`] instead.
+    #[deprecated(since = "0.4.0", note = "use from_col_major instead for direct column-major input; this method has O(m×n) transpose overhead")]
     pub fn from_row_major(data: &[S::Scalar], nrows: usize, ncols: usize) -> Self
     where
         S::Scalar: Copy,
diff --git a/docs/src/api-reference.md b/docs/src/api-reference.md
index 03e34ae..24cc30f 100644
--- a/docs/src/api-reference.md
+++ b/docs/src/api-reference.md
@@ -152,15 +152,42 @@
 import numpy as np
 
 a = np.array([[1, 2], [3, 4]], dtype=np.float32)
 b = np.array([[5, 6], [7, 8]], dtype=np.float32)
 
-# Basic operations
-c = tropical_gemm.maxplus_matmul(a, b)
-c = tropical_gemm.minplus_matmul(a, b)
-c = tropical_gemm.maxmul_matmul(a, b)
+# Basic operations (returns flattened 1D array)
+c_flat = tropical_gemm.maxplus_matmul(a, b)
+c = c_flat.reshape(a.shape[0], b.shape[1])
+
+# 2D output (returns proper 2D array directly)
+c = tropical_gemm.maxplus_matmul_2d(a, b)  # shape: (m, n)
+c = tropical_gemm.minplus_matmul_2d(a, b)
+c = tropical_gemm.maxmul_matmul_2d(a, b)
 
 # With argmax
 values, argmax = tropical_gemm.maxplus_matmul_with_argmax(a, b)
 ```
 
+### 2D Output Functions
+
+The `*_matmul_2d` variants return properly shaped 2D NumPy arrays without manual reshaping:
+
+| Type | MaxPlus | MinPlus | MaxMul |
+|------|---------|---------|--------|
+| f32 | `maxplus_matmul_2d` | `minplus_matmul_2d` | `maxmul_matmul_2d` |
+| f64 | `maxplus_matmul_2d_f64` | `minplus_matmul_2d_f64` | `maxmul_matmul_2d_f64` |
+| i32 | `maxplus_matmul_2d_i32` | `minplus_matmul_2d_i32` | `maxmul_matmul_2d_i32` |
+| i64 | `maxplus_matmul_2d_i64` | `minplus_matmul_2d_i64` | `maxmul_matmul_2d_i64` |
+
+```python
+# f64 example
+a = np.array([[1, 2], [3, 4]], dtype=np.float64)
+b = np.array([[5, 6], [7, 8]], dtype=np.float64)
+c = tropical_gemm.maxplus_matmul_2d_f64(a, b)  # shape: (2, 2)
+
+# i32 example
+a = np.array([[1, 2], [3, 4]], dtype=np.int32)
+b = np.array([[5, 6], [7, 8]], dtype=np.int32)
+c = tropical_gemm.maxplus_matmul_2d_i32(a, b)  # shape: (2, 2)
+```
+
 ### Backward Pass
 
 ```python
diff --git a/docs/src/changelog.md b/docs/src/changelog.md
index dd3141b..b460f0d 100644
--- a/docs/src/changelog.md
+++ b/docs/src/changelog.md
@@ -2,14 +2,25 @@
 
 All notable changes to tropical-gemm.
 
-## [Unreleased]
+## [0.2.0]
 
 ### Added
+- **2D output functions**: New `*_matmul_2d` variants that return properly shaped 2D arrays instead of flattened 1D output. Available for all semirings (maxplus, minplus, maxmul) and data types (f32, f64, i32, i64):
+  - `maxplus_matmul_2d`, `minplus_matmul_2d`, `maxmul_matmul_2d` (f32)
+  - `maxplus_matmul_2d_f64`, `minplus_matmul_2d_f64`, `maxmul_matmul_2d_f64`
+  - `maxplus_matmul_2d_i32`, `minplus_matmul_2d_i32`, `maxmul_matmul_2d_i32`
+  - `maxplus_matmul_2d_i64`, `minplus_matmul_2d_i64`, `maxmul_matmul_2d_i64`
 - mdBook documentation
 - Comprehensive architecture documentation
 - Performance tuning guide
 - Troubleshooting guide
 
+### Changed
+- **GIL release during compute**: All CPU functions now release Python's GIL during heavy computation, allowing other Python threads to run concurrently. This improves performance in multi-threaded Python applications.
+
+### Fixed
+- **Batched CPU path copies**: Fixed unnecessary memory copies in batched PyTorch operations by using `np.asarray()` instead of `np.array()` for zero-copy array creation when possible.
+
 ## [0.1.0] - Initial Release
 
 ### Features
diff --git a/docs/src/performance.md b/docs/src/performance.md
index df4a0f8..db2905a 100644
--- a/docs/src/performance.md
+++ b/docs/src/performance.md
@@ -172,6 +172,52 @@ output = tropical_matmul(large_batch_a, large_batch_b)
 outputs = [tropical_matmul(a, b) for a, b in zip(small_as, small_bs)]
 ```
 
+## Python Threading
+
+### GIL Release During Compute
+
+All CPU functions release Python's GIL during heavy computation, allowing other Python threads to run concurrently:
+
+```python
+import threading
+import tropical_gemm
+import numpy as np
+
+def background_task():
+    # This can run while tropical_gemm computes
+    print("Background task running")
+
+a = np.random.randn(1000, 1000).astype(np.float32)
+b = np.random.randn(1000, 1000).astype(np.float32)
+
+# Start background thread
+t = threading.Thread(target=background_task)
+t.start()
+
+# GIL is released during compute - background thread can run
+c = tropical_gemm.maxplus_matmul(a, b)
+
+t.join()
+```
+
+This is particularly useful in:
+- Web servers (Flask, FastAPI) handling concurrent requests
+- GUI applications that need to remain responsive
+- Applications that offload work to thread pools via `concurrent.futures`
+
+### 2D Output Without Manual Reshaping
+
+The `*_matmul_2d` functions return properly shaped 2D arrays directly, so no manual reshape step is needed:
+
+```python
+# Recommended: Use 2D functions for cleaner code
+c = tropical_gemm.maxplus_matmul_2d(a, b)  # shape: (m, n)
+
+# Older pattern requiring reshape
+c_flat = tropical_gemm.maxplus_matmul(a, b)  # shape: (m*n,)
+c = c_flat.reshape(m, n)
+```
+
 ## Memory Considerations
 
 ### Argmax Memory
diff --git a/docs/src/pytorch.md b/docs/src/pytorch.md
index 2f071e7..345f8a6 100644
--- a/docs/src/pytorch.md
+++ b/docs/src/pytorch.md
@@ -225,8 +225,8 @@ class TropicalMaxPlusMatmul(torch.autograd.Function):
         # Forward pass with argmax tracking
         c_flat, argmax_flat = tropical_gemm.maxplus_matmul_with_argmax(a_np, b_np)
-        c_np = np.array(c_flat).reshape(m, n)
-        argmax_np = np.array(argmax_flat).reshape(m, n)
+        c_np = np.asarray(c_flat).reshape(m, n)  # zero-copy when possible
+        argmax_np = np.asarray(argmax_flat).reshape(m, n)
 
         # Save for backward
         ctx.save_for_backward(torch.from_numpy(argmax_np))
 
@@ -249,8 +249,8 @@ class TropicalMaxPlusMatmul(torch.autograd.Function):
         grad_a_flat = tropical_gemm.backward_a(grad_c_np, argmax_np, k)
         grad_b_flat = tropical_gemm.backward_b(grad_c_np, argmax_np, k)
 
-        grad_a = torch.from_numpy(np.array(grad_a_flat).reshape(m, k)).to(grad_c.device)
-        grad_b = torch.from_numpy(np.array(grad_b_flat).reshape(k, n)).to(grad_c.device)
+        grad_a = torch.from_numpy(np.asarray(grad_a_flat).reshape(m, k)).to(grad_c.device)
+        grad_b = torch.from_numpy(np.asarray(grad_b_flat).reshape(k, n)).to(grad_c.device)
 
         return grad_a, grad_b
 ```
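For the max-plus (and min-plus) semiring, the gradients that `backward_a` / `backward_b` return can also be written directly in PyTorch from the saved argmax: each output element C[i, j] depends only on the single winning index `k* = argmax[i, j]`, so the upstream gradient dL/dC[i, j] is routed to A[i, k*] and B[k*, j]. The scatter formulation below is an illustrative sketch of that rule, not the library's implementation, and it does not cover max-mul, whose gradient carries an extra multiplicative factor:

```python
import torch

def maxplus_backward_reference(grad_c: torch.Tensor, argmax: torch.Tensor, k: int):
    """Reference gradients for C = maxplus_matmul(A, B), given dL/dC and argmax.

    grad_c: (m, n) float tensor; argmax: (m, n) integer tensor with values in [0, k).
    """
    m, n = grad_c.shape
    idx = argmax.long()
    grad_a = torch.zeros(m, k, dtype=grad_c.dtype, device=grad_c.device)
    grad_b = torch.zeros(k, n, dtype=grad_c.dtype, device=grad_c.device)
    grad_a.scatter_add_(1, idx, grad_c)  # dL/dA[i, argmax[i, j]] += dL/dC[i, j]
    grad_b.scatter_add_(0, idx, grad_c)  # dL/dB[argmax[i, j], j] += dL/dC[i, j]
    return grad_a, grad_b
```

Under that routing rule, the results should agree with `tropical_gemm.backward_a` and `backward_b` for max-plus, which is why saving only the argmax is enough for the backward pass.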