[large tensor] fix CUDA extensions int64 overflow for large tensor dimensions #561
zrr1999 wants to merge 1 commit into PaddlePaddle:develop from
Conversation
Pull request overview
This PR hardens the CUDA custom operators for the large-tensor / large-dimension case: selected indices and counts are upgraded from int to int64_t, and INT_MAX boundary guards plus 0-size fast returns are added in front of several kernel launchers, to avoid integer overflow and invalid kernel launches.
Changes:
- Switch element counts and loop indices in several kernels and helper functions to int64_t, reducing the risk of overflow at large sizes
- Add INT_MAX upper-bound checks and 0-size early returns for several operators, avoiding illegal configurations and invalid launches
- Slightly adjust some branches to skip unnecessary kernel launches

A minimal sketch of the combined pattern is shown after this list.
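For illustration only, here is a rough sketch of that pattern (0-size fast return, INT_MAX guard before launch, int64_t indices inside the kernel). The kernel and launcher names are invented for this example and are not code from the PR:

```cpp
// Illustrative sketch, not the PR's actual code.
#include <cuda_runtime.h>
#include <cstdint>
#include <limits>
#include <stdexcept>

template <typename T>
__global__ void scale_kernel(const T* in, T* out, int64_t n, T alpha) {
  // int64_t grid-stride loop: safe even when n does not fit in 32 bits.
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    out[i] = in[i] * alpha;
  }
}

template <typename T>
void LaunchScale(const T* in, T* out, int64_t rows, int64_t cols, T alpha,
                 cudaStream_t stream) {
  const int64_t n = rows * cols;
  if (n == 0) return;  // 0-size fast path: skip the launch entirely.

  const int64_t blocks = (n + 255) / 256;
  // Guard anything that must still be a 32-bit quantity (grid.x here).
  if (blocks > std::numeric_limits<int>::max()) {
    throw std::invalid_argument("grid.x exceeds INT_MAX");
  }
  scale_kernel<T><<<static_cast<unsigned int>(blocks), 256, 0, stream>>>(
      in, out, n, alpha);
}
```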
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/paddlefleet/_extensions/utils.h | memcpy helper parameters/indices changed to int64_t; function declarations reformatted |
| src/paddlefleet/_extensions/tokens_zip_unique_add.cu | In-kernel loop index changed to int64_t to avoid overflow when hidden_size is large |
| src/paddlefleet/_extensions/tokens_zip_prob.cu | Added INT_MAX checks for num_expert/topk and switched to int64_t intermediates for grid computation |
| src/paddlefleet/_extensions/tokens_unzip_slice.cu | Loop stride computed in int64_t; early return when there are 0 rows |
| src/paddlefleet/_extensions/tokens_unzip_gather.cu | Cleaned up the scale-shape reading logic, filled in the quanted_hidden_size computation for the no-scale case, and skipped kernel launches when there are no tokens |
| src/paddlefleet/_extensions/swiglu_kernel.cu | Added 0-size early return and a rows * input_dim <= INT_MAX check |
| src/paddlefleet/_extensions/router_metadata.cu | Added a num_tokens * K <= INT_MAX check, with explicit casts in some places |
| src/paddlefleet/_extensions/fuse_swiglu_scale.cu | forward/backward gain 0-size early returns and rows * hidden2 <= INT_MAX checks |
| src/paddlefleet/_extensions/fuse_stack_transpose_fp8_quant.cu | grid.x bounded by INT_MAX via an int64_t intermediate |
| src/paddlefleet/_extensions/filter_scores.cu | Added INT_MAX checks for total_elements/total_valid; grid_size computed in int64 and then safely cast |
| src/paddlefleet/_extensions/count_cumsum.cu | 128-bit load/store support larger index types, local loop indices changed to int64_t, and a new N == 0 early return |
```cpp
int64_t grid_x = (M / 128) * (K / 128);
PADDLE_ENFORCE_LE(
    grid_x,
    static_cast<int64_t>(std::numeric_limits<int>::max()),
    common::errors::InvalidArgument(
        "grid.x exceeds INT_MAX in fuse_stack_transpose_fp8_quant."));
```
The computation grid_x = (M/128)*(K/128) here uses int64_t multiplication; if M/K are extremely large, signed overflow (UB) can occur before the comparison, defeating the subsequent range check. Consider a safer multiplication check (for example, compute and compare using __int128, or first validate upper bounds for M/128 and K/128 separately before multiplying).
Suggested change:
```diff
-int64_t grid_x = (M / 128) * (K / 128);
-PADDLE_ENFORCE_LE(
-    grid_x,
-    static_cast<int64_t>(std::numeric_limits<int>::max()),
-    common::errors::InvalidArgument(
-        "grid.x exceeds INT_MAX in fuse_stack_transpose_fp8_quant."));
+int64_t m_blocks = M / 128;
+int64_t k_blocks = K / 128;
+__int128 grid_x_wide =
+    static_cast<__int128>(m_blocks) * static_cast<__int128>(k_blocks);
+PADDLE_ENFORCE_LE(
+    grid_x_wide,
+    static_cast<__int128>(std::numeric_limits<int>::max()),
+    common::errors::InvalidArgument(
+        "grid.x exceeds INT_MAX in fuse_stack_transpose_fp8_quant."));
+int64_t grid_x = static_cast<int64_t>(grid_x_wide);
```
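For reference, the comment's other alternative (validating each factor before multiplying) could look roughly like the sketch below. It assumes the surrounding fuse_stack_transpose_fp8_quant.cu context (M, K, and Paddle's enforce macros); kFactorLimit is an illustrative constant, not from the PR:

```cpp
// Sketch only: bound each factor so the int64_t product cannot overflow,
// then compare the product against INT_MAX as before.
int64_t m_blocks = M / 128;
int64_t k_blocks = K / 128;
// If each factor is <= 2^31, the product is <= 2^62 and fits in int64_t.
constexpr int64_t kFactorLimit = int64_t{1} << 31;
PADDLE_ENFORCE_LE(m_blocks, kFactorLimit,
                  common::errors::InvalidArgument("M / 128 is too large."));
PADDLE_ENFORCE_LE(k_blocks, kFactorLimit,
                  common::errors::InvalidArgument("K / 128 is too large."));
int64_t grid_x = m_blocks * k_blocks;
PADDLE_ENFORCE_LE(
    grid_x,
    static_cast<int64_t>(std::numeric_limits<int>::max()),
    common::errors::InvalidArgument(
        "grid.x exceeds INT_MAX in fuse_stack_transpose_fp8_quant."));
```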
```diff
 template <typename T>
 __device__ __forceinline__ void unrolled_memcpy(const T* src,
                                                 T* dst,
-                                                const int num_elements) {
+                                                const int64_t num_elements) {
 #pragma unroll
-  for (int idx = threadIdx.x; idx < num_elements; idx += blockDim.x) {
+  for (int64_t idx = threadIdx.x; idx < num_elements; idx += blockDim.x) {
```
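For context, a hedged sketch of how a helper like this is typically invoked, with one block copying one row; the caller name and layout are illustrative, not from utils.h:

```cpp
// Illustrative caller (not from the PR): num_elements is the per-row width,
// which can now safely be a large int64_t value.
template <typename T>
__global__ void copy_rows_kernel(const T* src, T* dst, int64_t row_width) {
  const int64_t row = blockIdx.x;
  unrolled_memcpy(src + row * row_width, dst + row * row_width, row_width);
}
```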
The PR title is currently "update", which does not follow the repository's required [CLASS]Title format (e.g. [BugFix] ...). The PR also does not provide a clear description (it should at least state why the change is made and what problem it solves). Please complete the PR metadata before merging, to ease later traceability and release notes.
```cpp
  cumsum_output = paddle::empty({0}, paddle::DataType::INT32, place);
}

if (N == 0) {
```
count_cumsum returns directly when N == 0, but at that point count_output / cumsum_output were created with paddle::empty and their contents are uninitialized; semantically the function should return all zeros (the kernel previously zeroed the outputs). Consider creating zeros/full(0) tensors in the N == 0 branch instead, or explicitly zeroing the outputs before returning.
Suggested change:
```diff
 if (N == 0) {
+  // For N == 0, semantics require all-zero outputs instead of uninitialized memory.
+  count_output.zero_();
+  if (do_cumsum) {
+    cumsum_output.zero_();
+  }
```
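The comment's first option (constructing zero-filled tensors in the N == 0 branch) could instead look roughly like this, assuming paddle::full is available alongside paddle::empty in the extension API; the output shape (num_bins) and the return statement are placeholders for whatever count_cumsum.cu actually uses:

```cpp
// Sketch only: build all-zero outputs instead of leaving paddle::empty
// tensors uninitialized. num_bins and the return value are placeholders.
if (N == 0) {
  count_output = paddle::full({num_bins}, 0, paddle::DataType::INT32, place);
  if (do_cumsum) {
    cumsum_output = paddle::full({num_bins}, 0, paddle::DataType::INT32, place);
  }
  return {count_output, cumsum_output};
}
```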
```diff
 int64_t zipped_rows = zipped_expertwise_rowmap_shape[0];
-int num_expert = zipped_expertwise_rowmap_shape[1];
-int topk = dispatched_indices_shape[1];
-PD_CHECK(unzipped_probs.size() == num_expert);
+int64_t num_expert = zipped_expertwise_rowmap_shape[1];
+int64_t topk = dispatched_indices_shape[1];
+PD_CHECK(num_expert <= static_cast<int64_t>(std::numeric_limits<int>::max()),
+         "num_expert must be <= INT_MAX for tokens_zip_prob.");
+PD_CHECK(topk <= static_cast<int64_t>(std::numeric_limits<int>::max()),
+         "topk must be <= INT_MAX for tokens_zip_prob.");
+PD_CHECK(unzipped_probs.size() == static_cast<size_t>(num_expert),
+         "unzipped_probs.size() must equal num_expert.");
+int num_expert_int = static_cast<int>(num_expert);
+int topk_int = static_cast<int>(topk);

 auto zipped_probs =
     paddle::empty({zipped_rows, topk}, dtype, unzipped_probs[0].place());
```
After checking unzipped_probs.size() == num_expert, tokens_zip_prob_impl still accesses unzipped_probs[0] directly to obtain the place; when num_expert == 0 (and unzipped_probs is empty) this is out of bounds, and the subsequent rowmap indexing does not hold either. Consider explicitly requiring num_expert > 0 (and topk > 0, if applicable), or returning a zipped_probs with an empty shape early in the 0-expert / 0-topk case.
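A minimal sketch of the first suggestion (rejecting the degenerate sizes before the place lookup); the error messages are illustrative:

```cpp
// Sketch only: guard the degenerate cases before dereferencing unzipped_probs[0].
PD_CHECK(num_expert > 0, "tokens_zip_prob requires num_expert > 0.");
PD_CHECK(topk > 0, "tokens_zip_prob requires topk > 0.");
auto zipped_probs =
    paddle::empty({zipped_rows, topk}, dtype, unzipped_probs[0].place());
```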
```diff
                                                  const int64_t num_elements) {
   constexpr int vector_size_in_bytes = 16;
-  const int elements_per_vector = vector_size_in_bytes / sizeof(T);
+  const int64_t elements_per_vector = vector_size_in_bytes / sizeof(T);

-  int num_vectors = num_elements / elements_per_vector;
-  int remaining_elements = num_elements % elements_per_vector;
+  int64_t num_vectors = num_elements / elements_per_vector;
+  int64_t remaining_elements = num_elements % elements_per_vector;

   using VecType = VectorType<T, elements_per_vector>;
   const VecType* src_vec = reinterpret_cast<const VecType*>(src);
   VecType* dst_vec = reinterpret_cast<VecType*>(dst);
```
In vectorized_memcpy, elements_per_vector is used as a non-type template argument (VectorType<T, elements_per_vector>), but it is currently a local const int64_t variable rather than constexpr, and the template parameter type is int. To avoid compilation failures or deduction issues across compilers/standards, consider changing it to constexpr int elements_per_vector = vector_size_in_bytes / sizeof(T); while keeping the subsequent num_vectors/remaining_elements/offset computations in int64_t.
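Concretely, the shape the comment suggests would be roughly the following (same logic as the diff above, with the template argument kept as a compile-time int; a sketch against the surrounding vectorized_memcpy context, not verified against the file):

```cpp
// Sketch only: keep the template argument a compile-time int,
// keep the runtime element math in int64_t.
constexpr int vector_size_in_bytes = 16;
constexpr int elements_per_vector = vector_size_in_bytes / sizeof(T);

int64_t num_vectors = num_elements / elements_per_vector;
int64_t remaining_elements = num_elements % elements_per_vector;

using VecType = VectorType<T, elements_per_vector>;
const VecType* src_vec = reinterpret_cast<const VecType*>(src);
VecType* dst_vec = reinterpret_cast<VecType*>(dst);
```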
wanghuancoder
left a comment
LGTM. The principle for this PR's changes: 1) where an ENFORCE/CHECK can intercept the case, don't switch to int64; 2) where it can't be intercepted, switch to int64; by inspection, the impact on kernel performance looks limited.
Got it.
This PR mainly fixes int32 overflow issues in the CUDA extensions for large-tensor support. The changes span several .cu files and utils.h:

count_cumsum.cu
- load_128_bits / store_128_bits gain an IdxT template parameter to support int64_t indexing (see the sketch after this list)
- N_vec and i changed to int64_t to avoid overflow for large N
- Early return when N == 0

filter_scores.cu
- gridDim.x * blockDim.x uses static_cast<int64_t> to avoid overflow
- grid_size is computed in int64_t first, then converted to int
- Added PD_CHECK that total_elements / total_valid do not exceed INT_MAX

fuse_stack_transpose_fp8_quant.cu
- grid_x computed with int64_t
- Added PADDLE_ENFORCE_LE to guarantee grid.x <= INT_MAX

fuse_swiglu_scale.cu / swiglu_kernel.cu
- Early return when rows == 0 or hidden_size == 0
- Added checks that rows * hidden2 / rows * input_dim <= INT_MAX

router_metadata.cu
- PADDLE_ENFORCE_LE replaces the TODO, checking num_tokens * K <= INT_MAX

tokens_unzip_gather.cu
- Cleaned up the quanted_hidden_size logic

tokens_unzip_slice.cu
- Stride computed as static_cast<int64_t>(blockDim.x) * gridDim.x
- Early return when total_zipped_rows == 0

tokens_zip_prob.cu
- num_expert and topk changed to int64_t, with INT_MAX checks added
- total_items and the grid computation now use int64_t before converting to int

tokens_zip_unique_add.cu
- In-kernel loop index changed to int64_t to avoid hidden_size overflow

utils.h
- num_elements of unrolled_memcpy, vectorized_memcpy, and try_vectorized_memcpy changed to int64_t
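As an illustration of the count_cumsum.cu item above, an IdxT-parameterized 128-bit load/store could look roughly like the sketch below; the real helpers in the file may have different signatures and alignment handling:

```cpp
// Sketch only (not the PR's actual code): the index type is a template
// parameter so callers can pass int64_t offsets without truncation.
template <typename T, typename IdxT>
__device__ __forceinline__ int4 load_128_bits(const T* ptr, IdxT idx) {
  return *reinterpret_cast<const int4*>(ptr + idx);
}

template <typename T, typename IdxT>
__device__ __forceinline__ void store_128_bits(T* ptr, IdxT idx,
                                               const int4& value) {
  *reinterpret_cast<int4*>(ptr + idx) = value;
}
```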