diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 727aac447b..a897b09330 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -535,11 +535,13 @@ size_t get_max_batch_size(size_t batch_size) { // batch size is expected to be 10s-100s // b = 1, ..., 32 -> max_b = 32 // b = 33, ..., 512 -> max_b = next power of 2 - // otherwise -> max_b = b + // b = 513, ... -> max_b = increment by 512 if (log2_b <= 5) { max_b = 32; } else if (log2_b <= 9) { max_b = pow(2, log2_b); + } else { + max_b = (batch_size + 511) / 512 * 512; } return max_b; }