feat: batched GPU ops with DLPack zero-copy output #35

Merged
isPANN merged 10 commits into main from feat/batched-gpu-dlpack-zero-copy on Jan 29, 2026
Conversation

@isPANN (Collaborator) commented Jan 28, 2026

Summary

Implements true zero-copy batched GPU operations that keep results on device, addressing the performance issues documented in #34.

Key changes:

  • Batched CUDA kernels: Add tropical_{maxplus,minplus,maxmul}_f32_nn_batched_with_argmax kernels using blockIdx.z for batch dimension
  • DLPack zero-copy output: Return DLPack capsules instead of numpy arrays - no D2H transfer for GPU results
  • Multi-GPU support: Per-device context cache sized by CudaDevice::count(), with get_context_for_device(device_id)
  • Robustness: Reject CudaHost, guard zero-sized dims, validate same-device tensors, cast argmax to int64

Performance impact:

Before: GPU compute → D2H transfer → numpy → H2D transfer → PyTorch tensor
After: GPU compute → DLPack capsule → PyTorch tensor (zero-copy)
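
For illustration, a minimal consumer-side sketch of the new path (the function name maxplus_matmul_dlpack appears later in this thread; the exact signature and argument handling are assumptions, not the final API):

```python
import torch
import tropical_gemm

a = torch.rand(64, 128, device="cuda", dtype=torch.float32).contiguous()
b = torch.rand(128, 32, device="cuda", dtype=torch.float32).contiguous()

# The GPU path now returns a DLPack capsule wrapping device memory;
# torch.from_dlpack adopts that memory directly, with no D2H/H2D round trip.
capsule = tropical_gemm.maxplus_matmul_dlpack(a, b)
c = torch.from_dlpack(capsule)
assert c.is_cuda and c.shape == (64, 32)
```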

Test plan

  • CPU Rust tests pass (24 passed)
  • Python CPU tests pass (158 passed, 51 skipped)
  • GPU tests on CUDA server (17 tests need CUDA)
  • Julia reference dataset tests on GPU

Testing on GPU server

# Build with CUDA
cd crates/tropical-gemm-python
maturin develop --features cuda --release

# Run all tests including GPU
python -m pytest tests/ -v

# Run Julia reference tests specifically
python -m pytest tests/test_julia_reference.py -v

Closes #34

🤖 Generated with Claude Code

codecov bot commented Jan 28, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 96.41%. Comparing base (0c4e0fb) to head (839df3c).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #35   +/-   ##
=======================================
  Coverage   96.41%   96.41%           
=======================================
  Files          18       18           
  Lines         892      892           
=======================================
  Hits          860      860           
  Misses         32       32           


Copilot AI left a comment

Pull request overview

This pull request implements batched GPU operations with true zero-copy DLPack output to address performance issues documented in #34. The implementation adds CUDA batched kernels, DLPack zero-copy export for GPU results, multi-GPU support with per-device context caching, and comprehensive input validation.

Changes:

  • Batched CUDA kernels using blockIdx.z for the batch dimension with proper strided memory access
  • DLPack zero-copy output via DLPackGpuTensor3F32 and DLPackGpuTensor3I32 wrapper types
  • Multi-GPU support with device-specific context caching and get_context_for_device(device_id) function

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

Summary per file:

  • crates/tropical-gemm-cuda/kernels/tropical_gemm.cu: adds batched GEMM kernel macro and f32 kernel instantiations using blockIdx.z for batch indexing
  • crates/tropical-gemm-cuda/src/lib.rs: implements per-device CUDA context caching with thread-safe initialization and device count validation
  • crates/tropical-gemm-cuda/src/context.rs: registers three new batched kernel names in the KERNEL_NAMES array
  • crates/tropical-gemm-cuda/src/memory.rs: adds ExternalGpuTensor3, GpuTensor3, and GpuTensor3WithArgmax types for 3D tensor operations
  • crates/tropical-gemm-cuda/src/kernels.rs: implements launch_gemm_external_batched_with_argmax_f32 with proper grid configuration and parameter swapping
  • crates/tropical-gemm-python/src/lib.rs: adds DLPack wrapper types and batched DLPack functions with comprehensive validation (device type, dtype, dimensions, contiguity, zero-size guards)
  • crates/tropical-gemm-python/python/tropical_gemm/pytorch.py: updates batched autograd functions to route CUDA tensors through the DLPack interface and validate device consistency
  • crates/tropical-gemm/src/mat/owned.rs: enhances the deprecation warning with O(m×n) performance-cost documentation
  • crates/tropical-gemm/src/core/gemm.rs: adds TODO comments referencing #34 for a workspace-based API that avoids repeated allocations
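
To make the validation list above concrete, here is a minimal sketch of the kind of input checking described for the batched GPU path; the helper name and exact error messages are assumptions, not the code in this PR:

```python
import torch

def _validate_batched_inputs(a: torch.Tensor, b: torch.Tensor) -> None:
    """Hypothetical helper mirroring the checks listed above: device type,
    dtype, dimensions, contiguity, zero-size guards, same-device tensors."""
    if a.dtype != torch.float32 or b.dtype != torch.float32:
        raise TypeError("batched GPU path expects float32 inputs")
    if a.dim() != 3 or b.dim() != 3:
        raise ValueError("expected 3D tensors of shape (batch, m, k) and (batch, k, n)")
    if a.shape[0] != b.shape[0] or a.shape[2] != b.shape[1]:
        raise ValueError("batch size or inner dimensions do not match")
    if 0 in a.shape or 0 in b.shape:
        raise ValueError("zero-sized dimensions are not supported")
    if not (a.is_contiguous() and b.is_contiguous()):
        raise ValueError("inputs must be contiguous")
    if a.is_cuda != b.is_cuda or (a.is_cuda and a.device != b.device):
        raise ValueError("inputs must live on the same device")
```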


Comment on lines 815 to 822
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)

Copilot AI Jan 28, 2026

The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics.

Suggested change:
- # Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
- argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+ # Convert argmax DLPack capsule to PyTorch tensor (zero-copy on GPU)
+ argmax = torch.from_dlpack(argmax_capsule)

Comment on lines 24 to 25
import warnings


Copilot AI Jan 28, 2026

The warnings module is imported but never used in the file. Consider removing this unused import.

Suggested change:
- import warnings

Comment on lines 619 to 626
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)

Copilot AI Jan 28, 2026

The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics. The DLPack tensor can be used directly after torch.from_dlpack(argmax_capsule).

Suggested change:
- # Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
- argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+ argmax = torch.from_dlpack(argmax_capsule)

Comment on lines 715 to 722
# Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)

Copilot AI Jan 28, 2026

The argmax tensor cast to int64 creates an extra GPU memory copy, defeating the zero-copy goal. PyTorch's scatter_add_ supports int32 indices for CUDA tensors, so the cast is unnecessary. Consider removing the .to(torch.int64) call to maintain zero-copy semantics.

Suggested change:
- # Cast argmax to int64 for PyTorch indexing ops (scatter_add_, gather)
- argmax = torch.from_dlpack(argmax_capsule).to(torch.int64)
+ # Convert argmax from DLPack without changing dtype to preserve zero-copy semantics
+ argmax = torch.from_dlpack(argmax_capsule)

isPANN force-pushed the feat/batched-gpu-dlpack-zero-copy branch from 65a4005 to 70ffcc2 on January 28, 2026 at 18:43
Implements true zero-copy batched GPU operations that keep results on device:

CUDA kernel changes:
- Add batched GEMM kernels using blockIdx.z for batch dimension
- Register tropical_{maxplus,minplus,maxmul}_f32_nn_batched_with_argmax

DLPack export:
- Implement ToTensor for GpuTensor3<f32> and GpuTensor3<i32>
- Return DLPack capsules instead of numpy arrays (no D2H transfer)
- Use torch.from_dlpack() in Python for zero-copy tensor creation

Multi-GPU support:
- Add per-device context cache using CudaDevice::count()
- get_context_for_device(device_id) for any valid CUDA device
- Pass correct device_id to DLPack output metadata

Robustness:
- Reject CudaHost inputs (require DeviceType::Cuda)
- Guard against zero-sized dimensions
- Validate tensors are on same CUDA device
- Cast argmax to int64 for PyTorch indexing ops

Closes #34

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
isPANN force-pushed the feat/batched-gpu-dlpack-zero-copy branch from 70ffcc2 to b70e4c0 on January 28, 2026 at 18:48
GiggleLiu and others added 6 commits January 29, 2026 09:12
- Remove unused 'warnings' import
- Remove .to(torch.int64) cast on argmax tensors to preserve zero-copy
  semantics (PyTorch scatter_add_ supports int32 indices on CUDA)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add device_ptr(), into_inner() to GpuMatrix for DLPack export
- Add into_parts() to GpuMatrixWithArgmax for splitting result
- Create DLPackGpuMatrixF32/I32 wrapper types implementing ToTensor
- Refactor 2D DLPack functions into dlpack_2d_impl() helper
- Return DLPack capsules for GPU path (data stays on GPU)
- Use get_context_for_device() instead of get_global_context()
- Reject CudaHost inputs with explicit error message
- Add zero-size dimension guard
- Update Python GPU classes to use torch.from_dlpack()

Before: GPU results were downloaded to host, Python copied back to GPU
After: GPU results stay on GPU via DLPack zero-copy capsules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Fix scatter() int64 requirement for batched GPU path
- Remove unused get_global_context import
- Add #[allow(deprecated)] for IntoPy (dlpark dependency)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
The test was using the old API (torch.from_numpy) but now
maxplus_matmul_dlpack returns DLPack capsules for GPU tensors.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- Add _get_dtype_funcs() helper to support both f32 and f64 dtypes
- Remove unused _rust_cpu_*_with_argmax helper functions
- Replace np.array() with np.asarray() to avoid unnecessary copies
- Update GPU docstrings to reflect true zero-copy (no D2H transfer)
- Fix DLPack docstring to clarify CPU returns numpy arrays
- Save dtype in ctx for consistent backward pass handling

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

@isPANN (Collaborator, Author) commented Jan 29, 2026

Summary of Latest Fixes

2D DLPack Zero-Copy Output (commit 971f72f)

Fixed the 2D DLPack GPU path to return zero-copy capsules instead of doing D2H transfer:

  • Added device_ptr(), into_inner() to GpuMatrix for DLPack export
  • Added into_parts() to GpuMatrixWithArgmax for splitting result
  • Created DLPackGpuMatrixF32/DLPackGpuMatrixI32 wrapper types implementing ToTensor
  • Refactored 2D DLPack functions into dlpack_2d_impl() helper
  • Use get_context_for_device() instead of get_global_context() for multi-GPU support
  • Reject CudaHost inputs with explicit error message
  • Updated Python GPU classes to use torch.from_dlpack()

Before: GPU results downloaded to host, Python copied back to GPU
After: GPU results stay on GPU via DLPack zero-copy capsules

CPU Dtype Handling & Code Cleanup (commit 2a39fb3)

  • Added _get_dtype_funcs() helper to support both f32 and f64 dtypes (no more silent precision loss)
  • Removed unused _rust_cpu_*_with_argmax helper functions (dead code)
  • Replaced np.array() with np.asarray() to avoid unnecessary copies
  • Updated GPU docstrings to reflect true zero-copy (no D2H transfer mention)
  • Fixed DLPack docstring to clarify CPU returns numpy arrays
  • Save dtype in ctx for consistent backward pass handling
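
_get_dtype_funcs() itself is not shown in this thread; a rough sketch of such a dispatch helper, with the per-dtype binding names invented purely for illustration:

```python
import numpy as np
import tropical_gemm as tg  # per-dtype function names below are assumed, not the real API

def _get_dtype_funcs(dtype):
    """Return the (matmul, matmul_with_argmax) bindings matching the input
    dtype, so float64 inputs are not silently downcast to float32."""
    if dtype == np.float32:
        return tg.maxplus_matmul_f32, tg.maxplus_matmul_with_argmax_f32
    if dtype == np.float64:
        return tg.maxplus_matmul_f64, tg.maxplus_matmul_with_argmax_f64
    raise TypeError(f"unsupported dtype: {dtype}")
```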

Bug Fixes

  • Cast batched GPU argmax to int64 for PyTorch scatter operations (commit 09a9994)
  • Updated test to use torch.from_dlpack() for GPU path (commit d9180a3)
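
The int64 cast is needed because torch.Tensor.scatter_add_ only accepts int64 (Long) index tensors; a small self-contained check illustrating why the cast was restored:

```python
import torch

values = torch.ones(4)
out = torch.zeros(3)
idx32 = torch.tensor([0, 1, 1, 2], dtype=torch.int32)

# scatter_add_ rejects int32 indices with a RuntimeError, which is why the
# batched GPU argmax must be cast back to int64 before backward uses it.
try:
    out.scatter_add_(0, idx32, values)
except RuntimeError as err:
    print("int32 index rejected:", err)

out.scatter_add_(0, idx32.to(torch.int64), values)
print(out)  # tensor([1., 2., 1.])
```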

isPANN and others added 3 commits January 29, 2026 13:45
…d copies

- Release GIL with py.allow_threads() during heavy Rust compute in all
  29 CPU functions (matmul, with_argmax, backward, batched)
- Add 12 new *_matmul_2d functions returning PyArray2 for all dtypes
  (f32, f64, i32, i64) to avoid manual reshape after matmul
- Fix unnecessary copies in batched CPU path by using np.asarray()
  instead of np.array() in pytorch.py

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
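
Since the GIL is released around the Rust compute, CPU calls from multiple Python threads can now overlap; a rough usage sketch (the binding name maxplus_matmul_2d is assumed from the *_matmul_2d naming above):

```python
import threading
import numpy as np
import tropical_gemm as tg  # exact 2D binding name below is assumed

a = np.random.rand(512, 512).astype(np.float32)
b = np.random.rand(512, 512).astype(np.float32)
results = [None] * 4

def run(i):
    # The heavy Rust compute runs with the GIL released,
    # so these calls can execute in parallel on separate cores.
    results[i] = tg.maxplus_matmul_2d(a, b)

threads = [threading.Thread(target=run, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```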
- Add 2D output functions to API reference with usage examples
- Document GIL release during compute in performance guide
- Add Python threading section with example code
- Update PyTorch example to use np.asarray for zero-copy
- Update changelog with new features and fixes

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
isPANN merged commit cf4dbd1 into main on Jan 29, 2026
7 checks passed

Development

Successfully merging this pull request may close these issues.

Performance: Memory layout and data transfer optimizations

2 participants