From dccf67e7b3b70cc3d30a2e071e51a83c31d17e31 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 6 Feb 2026 16:35:09 -0800 Subject: [PATCH] [Common] Bucket batch size with higher granularity for THD (#2653) bucket max_b with more granularity when >512 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- transformer_engine/common/fused_attn/utils.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 727aac447b..a897b09330 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -535,11 +535,13 @@ size_t get_max_batch_size(size_t batch_size) { // batch size is expected to be 10s-100s // b = 1, ..., 32 -> max_b = 32 // b = 33, ..., 512 -> max_b = next power of 2 - // otherwise -> max_b = b + // b = 513, ... -> max_b = increment by 512 if (log2_b <= 5) { max_b = 32; } else if (log2_b <= 9) { max_b = pow(2, log2_b); + } else { + max_b = (batch_size + 511) / 512 * 512; } return max_b; }