feat: batched GPU ops with DLPack zero-copy output #35
Conversation
Codecov Report

✅ All modified and coverable lines are covered by tests.

```
@@           Coverage Diff           @@
##             main      #35   +/-  ##
=======================================
  Coverage   96.41%   96.41%
=======================================
  Files          18       18
  Lines         892      892
=======================================
  Hits          860      860
  Misses         32       32
```
Pull request overview
This pull request implements batched GPU operations with true zero-copy DLPack output to address performance issues documented in #34. The implementation adds CUDA batched kernels, DLPack zero-copy export for GPU results, multi-GPU support with per-device context caching, and comprehensive input validation.
Changes:
- Batched CUDA kernels using `blockIdx.z` for the batch dimension, with proper strided memory access (a reference sketch of the batched product follows this list)
- DLPack zero-copy output via `DLPackGpuTensor3F32` and `DLPackGpuTensor3I32` wrapper types
- Multi-GPU support with device-specific context caching and a `get_context_for_device(device_id)` function
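To pin down what the batched kernels compute, here is a minimal NumPy reference of the max-plus product with argmax over a batch dimension (illustrative only, not the crate's code; each loop iteration corresponds to one `blockIdx.z` slice):

```python
import numpy as np

def maxplus_bmm_ref(a: np.ndarray, b: np.ndarray):
    """a: (B, M, K), b: (B, K, N) -> values (B, M, N) f32 and argmax (B, M, N) i32."""
    B, M, K = a.shape
    _, _, N = b.shape
    vals = np.empty((B, M, N), dtype=np.float32)
    arg = np.empty((B, M, N), dtype=np.int32)
    for batch in range(B):  # one batch slice per blockIdx.z in the CUDA kernels
        # cand[k, m, n] = a[batch, m, k] + b[batch, k, n]
        cand = a[batch].T[:, :, None] + b[batch][:, None, :]
        arg[batch] = cand.argmax(axis=0).astype(np.int32)
        vals[batch] = cand.max(axis=0)
    return vals, arg
```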
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| crates/tropical-gemm-cuda/kernels/tropical_gemm.cu | Adds batched GEMM kernel macro and f32 kernel instantiations using blockIdx.z for batch indexing |
| crates/tropical-gemm-cuda/src/lib.rs | Implements per-device CUDA context caching with thread-safe initialization and device count validation |
| crates/tropical-gemm-cuda/src/context.rs | Registers three new batched kernel names in KERNEL_NAMES array |
| crates/tropical-gemm-cuda/src/memory.rs | Adds ExternalGpuTensor3, GpuTensor3, and GpuTensor3WithArgmax types for 3D tensor operations |
| crates/tropical-gemm-cuda/src/kernels.rs | Implements launch_gemm_external_batched_with_argmax_f32 with proper grid configuration and parameter swapping |
| crates/tropical-gemm-python/src/lib.rs | Adds DLPack wrapper types and batched DLPack functions with comprehensive validation (device type, dtype, dimensions, contiguity, zero-size guards) |
| crates/tropical-gemm-python/python/tropical_gemm/pytorch.py | Updates batched autograd functions to route CUDA tensors through DLPack interface and validate device consistency |
| crates/tropical-gemm/src/mat/owned.rs | Enhances deprecation warning with O(m×n) performance cost documentation |
| crates/tropical-gemm/src/core/gemm.rs | Adds TODO comments referencing #34 for workspace-based API to avoid repeated allocations |
```python
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
```
The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics.
Suggested change:

```diff
-# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
-argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+# Convert argmax DLPack capsule to PyTorch tensor (zero-copy on GPU)
+argmax = torch.from_dlpack(argmax_capsule)
```
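To make the copy-vs-zero-copy point concrete, here is a small self-contained check (illustrative; it round-trips a plain CUDA tensor through DLPack instead of the extension's capsule, and assumes a CUDA device is available):

```python
import torch

t32 = torch.zeros(4, 4, dtype=torch.int32, device="cuda")

# A DLPack round-trip is zero-copy: the resulting tensor aliases the same buffer.
alias = torch.from_dlpack(torch.utils.dlpack.to_dlpack(t32))
assert alias.data_ptr() == t32.data_ptr()

# Casting to int64 materializes a second buffer on the GPU.
t64 = t32.to(torch.int64)
assert t64.data_ptr() != t32.data_ptr()
```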
```python
import warnings
```
The warnings module is imported but never used in the file. Consider removing this unused import.
Suggested change:

```diff
-import warnings
```
```python
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
```
The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics. The DLPack tensor can be used directly after torch.from_dlpack(argmax_capsule).
Suggested change:

```diff
-# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
-argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+argmax = torch.from_dlpack(argmax_capsule)
```
```python
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
```
The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics.
Suggested change:

```diff
-# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
-argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+# Convert argmax from DLPack without changing dtype to preserve zero-copy semantics
+argmax = torch.from_dlpack(argmax_capsule)
```
Force-pushed the branch from 65a4005 to 70ffcc2.
Implements true zero-copy batched GPU operations that keep results on device:
CUDA kernel changes:
- Add batched GEMM kernels using blockIdx.z for batch dimension
- Register tropical_{maxplus,minplus,maxmul}_f32_nn_batched_with_argmax
DLPack export:
- Implement ToTensor for GpuTensor3<f32> and GpuTensor3<i32>
- Return DLPack capsules instead of numpy arrays (no D2H transfer)
- Use torch.from_dlpack() in Python for zero-copy tensor creation
Multi-GPU support:
- Add per-device context cache using CudaDevice::count()
- get_context_for_device(device_id) for any valid CUDA device
- Pass correct device_id to DLPack output metadata
Robustness:
- Reject CudaHost inputs (require DeviceType::Cuda)
- Guard against zero-sized dimensions
- Validate tensors are on same CUDA device
- Cast argmax to int64 for PyTorch indexing ops
Closes #34
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
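For orientation, the consumption pattern this enables on the Python side looks roughly like the sketch below. The batched function name `maxplus_matmul_batched_dlpack`, its tensor arguments, and its `(values_capsule, argmax_capsule)` return value are assumptions for illustration only; a CUDA device and a CUDA-enabled build of the extension are required.

```python
import torch
import tropical_gemm  # the built extension module

a = torch.randn(8, 64, 32, device="cuda", dtype=torch.float32)
b = torch.randn(8, 32, 16, device="cuda", dtype=torch.float32)

# Hypothetical batched entry point returning two DLPack capsules.
vals_capsule, argmax_capsule = tropical_gemm.maxplus_matmul_batched_dlpack(a, b)

# Wrapping the capsules involves no device-to-host transfer.
vals = torch.from_dlpack(vals_capsule)      # (8, 64, 16) float32 on the same GPU
argmax = torch.from_dlpack(argmax_capsule)  # (8, 64, 16) argmax indices

assert vals.device == a.device == b.device
```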
Force-pushed the branch from 70ffcc2 to b70e4c0.
- Remove unused 'warnings' import
- Remove .to(torch.int64) cast on argmax tensors to preserve zero-copy semantics (PyTorch scatter_add_ supports int32 indices on CUDA)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add device_ptr(), into_inner() to GpuMatrix for DLPack export
- Add into_parts() to GpuMatrixWithArgmax for splitting result
- Create DLPackGpuMatrixF32/I32 wrapper types implementing ToTensor
- Refactor 2D DLPack functions into dlpack_2d_impl() helper
- Return DLPack capsules for GPU path (data stays on GPU)
- Use get_context_for_device() instead of get_global_context()
- Reject CudaHost inputs with explicit error message
- Add zero-size dimension guard
- Update Python GPU classes to use torch.from_dlpack()

Before: GPU results were downloaded to host, Python copied back to GPU
After: GPU results stay on GPU via DLPack zero-copy capsules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix scatter() int64 requirement for batched GPU path
- Remove unused get_global_context import
- Add #[allow(deprecated)] for IntoPy (dlpark dependency)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The test was using the old API (torch.from_numpy), but maxplus_matmul_dlpack now returns DLPack capsules for GPU tensors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add _get_dtype_funcs() helper to support both f32 and f64 dtypes
- Remove unused _rust_cpu_*_with_argmax helper functions
- Replace np.array() with np.asarray() to avoid unnecessary copies (see the snippet below)
- Update GPU docstrings to reflect true zero-copy (no D2H transfer)
- Fix DLPack docstring to clarify CPU returns numpy arrays
- Save dtype in ctx for consistent backward pass handling

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
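The np.asarray() point is easy to verify in isolation; a tiny snippet, independent of the extension:

```python
import numpy as np

x = np.ones((512, 512), dtype=np.float32)

y = np.array(x)    # copies by default, even for an already-matching ndarray
z = np.asarray(x)  # returns the input unchanged when dtype and layout match

assert not np.shares_memory(y, x)
assert z is x
```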
Summary of Latest Fixes

2D DLPack Zero-Copy Output (commit 971f72f)

Fixed the 2D DLPack GPU path to return zero-copy capsules instead of doing a D2H transfer:

Before: GPU results downloaded to host, Python copied back to GPU
After: GPU results stay on GPU via DLPack zero-copy capsules

CPU Dtype Handling & Code Cleanup (commit 2a39fb3)

Bug Fixes
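At the usage level, the fixed 2D GPU path behaves roughly as sketched below. The exact signature of `maxplus_matmul_dlpack` (tensor arguments in, a single DLPack capsule out) is assumed here for illustration, and a CUDA device plus a CUDA-enabled build are required.

```python
import torch
import tropical_gemm

a = torch.randn(128, 64, device="cuda", dtype=torch.float32)
b = torch.randn(64, 32, device="cuda", dtype=torch.float32)

# GPU inputs -> DLPack capsule out; the result buffer never leaves the device.
capsule = tropical_gemm.maxplus_matmul_dlpack(a, b)
c = torch.from_dlpack(capsule)  # (128, 32) float32 on the same GPU as a and b

# CPU inputs take a separate path and return plain numpy arrays (per the docstring fix).
```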
…d copies

- Release GIL with py.allow_threads() during heavy Rust compute in all 29 CPU functions (matmul, with_argmax, backward, batched)
- Add 12 new *_matmul_2d functions returning PyArray2 for all dtypes (f32, f64, i32, i64) to avoid manual reshape after matmul
- Fix unnecessary copies in batched CPU path by using np.asarray() instead of np.array() in pytorch.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add 2D output functions to API reference with usage examples
- Document GIL release during compute in performance guide
- Add Python threading section with example code (illustrated below)
- Update PyTorch example to use np.asarray for zero-copy
- Update changelog with new features and fixes

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
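In the spirit of the documented threading section, releasing the GIL during compute lets calls overlap across Python threads, roughly as in the sketch below; the concrete 2D function name `maxplus_matmul_2d` is assumed for illustration.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import tropical_gemm

a = np.random.rand(1024, 512).astype(np.float32)
b = np.random.rand(512, 256).astype(np.float32)

# The extension releases the GIL while the Rust kernel runs, so these
# four calls can execute concurrently on separate OS threads.
with ThreadPoolExecutor(max_workers=4) as pool:
    results = list(pool.map(lambda _: tropical_gemm.maxplus_matmul_2d(a, b), range(4)))
```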
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Summary
Implements true zero-copy batched GPU operations that keep results on device, addressing the performance issues documented in #34.
Key changes:
- Batched `tropical_{maxplus,minplus,maxmul}_f32_nn_batched_with_argmax` kernels using `blockIdx.z` for the batch dimension
- Per-device context caching via `CudaDevice::count()`, with `get_context_for_device(device_id)` for multi-GPU support

Performance impact:
Before: GPU compute → D2H transfer → numpy → H2D transfer → PyTorch tensor
After: GPU compute → DLPack capsule → PyTorch tensor (zero-copy)
Test plan
Testing on GPU server
Closes #34
🤖 Generated with Claude Code