Mahdieh/gb200 nvloptimized #708
base: main
Conversation
Pull request overview
This PR optimizes MSCCL++ allreduce operations for GB200 (NVIDIA's next-generation GPU with compute capability 10.0) by adjusting block counts, switch channels, and memory alignment parameters specifically for this architecture. The changes enable better performance on GB200 systems by leveraging architecture-specific optimizations in NVLS (NVLink Switch) operations.
Key changes include:
- Increased block count from 8 to 24 for GB200 in NVLS allreduce operations
- Increased switch channel count from 8 to 24 for GB200 during initialization
- Added 16-byte alignment for block size calculations in device code for GB200
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| apps/nccl/src/allreduce.hpp | Adds GB200-specific device-side optimizations with 16-byte block size alignment and adjusted last-block size handling |
| apps/nccl/src/allreduce.cu | Implements runtime detection of GB200 to configure optimal block counts (24) and switch channels (24) versus default values (8) |
```diff
- nSwitchChannels_ = 8;
+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, 0);
+ nSwitchChannels_ = (prop.major == 10) ? 24 : 8;
```
Copilot AI commented on Dec 17, 2025
The magic numbers 24 and 8 for nSwitchChannels lack documentation. Consider adding a comment explaining why 24 channels are optimal for GB200 (compute capability 10.0) versus 8 channels for earlier architectures, or referencing relevant documentation or performance tuning results.
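A sketch of what the suggested documentation could look like; the constant names are hypothetical, and the stated rationale is an assumption rather than a tuning result from the PR:

```cpp
// Hypothetical named constants; 24 vs. 8 reflects the PR's tuning choice,
// and the comments are an assumed rationale.
constexpr int kSwitchChannelsGB200 = 24;   // GB200 (compute capability 10.0)
constexpr int kSwitchChannelsDefault = 8;  // earlier architectures

cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
nSwitchChannels_ = (prop.major == 10) ? kSwitchChannelsGB200 : kSwitchChannelsDefault;
```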
apps/nccl/src/allreduce.hpp (outdated)
```cpp
int nBlocks = gridDim.x;
int bid = blockIdx.x;
size_t sizePerRank = size / nRanksPerNode;
#if __CUDA_ARCH__ >=1000
```
Copilot AI commented on Dec 17, 2025
Inconsistent spacing in the preprocessor directive. A space should be added after `>=` for consistency with other similar directives in this file (e.g., lines 771, 819, 849, 871, 968, and 1100 all use `__CUDA_ARCH__ >= 900` with a space).
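The corrected directive, matching the spacing used elsewhere in the file, would read:

```cpp
#if __CUDA_ARCH__ >= 1000
```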
```cpp
int bid = blockIdx.x;
size_t sizePerRank = size / nRanksPerNode;
#if __CUDA_ARCH__ >=1000
size_t sizePerBlock = (sizePerRank / nBlocks) / 16 * 16;
```
Copilot AI commented on Dec 17, 2025
The magic number 16 is used for alignment without explanation. Consider adding a comment explaining that this is for 16-byte alignment optimization on GB200, or defining it as a named constant to improve code readability and maintainability.
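A sketch of the suggested named constant; the name is hypothetical, and the 128-bit rationale is an assumption, not taken from the PR:

```cpp
// Hypothetical constant: 16 bytes likely corresponds to 128-bit (e.g. int4)
// vectorized loads/stores in the device code.
constexpr size_t kNvlsAlignmentBytes = 16;
size_t sizePerBlock = (sizePerRank / nBlocks) / kNvlsAlignmentBytes * kNvlsAlignmentBytes;
```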
apps/nccl/src/allreduce.cu (outdated)
```diff
  using ChannelType = mscclpp::DeviceHandle<mscclpp::BaseMemoryChannel>;
- int nBlocks = nRanksPerNode;
+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, 0);
```
Copilot AI commented on Dec 17, 2025
Missing error checking for cudaGetDeviceProperties. While this codebase typically uses MSCCLPP_CUDATHROW for error checking CUDA API calls (as seen in src/npkit/npkit.cc and test files), this call doesn't check the return value. Consider adding error checking to handle potential failures gracefully.
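A minimal sketch of the suggested check, assuming the MSCCLPP_CUDATHROW macro used elsewhere in the codebase is available in this translation unit:

```cpp
cudaDeviceProp prop;
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&prop, 0));  // throws on CUDA error
```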
```diff
- int nBlocks = nRanksPerNode;
+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, 0);
+ int nBlocks = (prop.major == 10) ? 24 : nRanksPerNode;
```
Copilot AI commented on Dec 17, 2025
The magic number 24 for nBlocks on GB200 lacks documentation. Consider adding a comment explaining why 24 blocks are optimal for compute capability 10.0, or referencing relevant documentation or performance tuning results.
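As with the channel counts above, a named constant would make the intent explicit (the name and the comment are hypothetical):

```cpp
// Hypothetical constant; 24 is the PR's chosen block count for GB200.
constexpr int kNvlsBlocksGB200 = 24;
int nBlocks = (prop.major == 10) ? kNvlsBlocksGB200 : nRanksPerNode;
```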
```cpp
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int nBlocks = (prop.major == 10) ? 24 : nRanksPerNode;
```
Copilot AI commented on Dec 17, 2025
Repeated device property query with identical logic. The same cudaGetDeviceProperties call and prop.major == 10 check is duplicated at lines 336-338. Consider caching the device properties or extracting this logic into a helper function to avoid code duplication and potential inconsistencies.
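One possible shape for such a helper; the name isGB200 is hypothetical, and a function-local static is just one way to query the device only once:

```cpp
// Query the device once; subsequent calls reuse the cached result.
static bool isGB200() {
  static const bool result = [] {
    cudaDeviceProp prop;
    MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&prop, 0));
    return prop.major == 10;
  }();
  return result;
}
```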
```cpp
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
int nBlocks = (prop.major == 10) ? 24 : nRanksPerNode;
```
Copilot AI commented on Dec 17, 2025
Performance concern: cudaGetDeviceProperties is called on every kernel launch in the hot path. This is an expensive host-side operation that queries device properties. Consider caching the result during initialization (e.g., in the NvlsAdapter struct or as a class member) to avoid this overhead on every allreduce operation.
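A sketch of caching at initialization time instead, with a hypothetical member isGB200_ set once and read in the hot path:

```cpp
// Initialization path (runs once):
cudaDeviceProp prop;
MSCCLPP_CUDATHROW(cudaGetDeviceProperties(&prop, 0));
isGB200_ = (prop.major == 10);

// Per-launch hot path (no device query):
int nBlocks = isGB200_ ? 24 : nRanksPerNode;
```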
apps/nccl/src/allreduce.cu (outdated)
```diff
  std::unordered_map<std::string, std::shared_ptr<void>>&) {
- nSwitchChannels_ = 8;
+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, 0);
```
Copilot AI commented on Dec 17, 2025
Missing error checking for cudaGetDeviceProperties. While this codebase typically uses MSCCLPP_CUDATHROW for error checking CUDA API calls (as seen in src/npkit/npkit.cc and test files), this call doesn't check the return value. Consider adding error checking to handle potential failures gracefully.
This PR improves the performance of MSCCL++ on GB200. We also need to update the quick start guide to note that the following option must be added to the CMake command when compiling for GB200:
-DMSCCLPP_GPU_ARCHS=100
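For example (only the -DMSCCLPP_GPU_ARCHS=100 flag is from this PR; the rest of the invocation is illustrative):

```bash
cmake -DMSCCLPP_GPU_ARCHS=100 ..
make -j
```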