From dc44b213510eff6e84d29c999f0e052de3a80b8b Mon Sep 17 00:00:00 2001
From: Jack Taylor <108682042+jataylo@users.noreply.github.com>
Date: Fri, 28 Nov 2025 15:14:52 +0000
Subject: [PATCH] [NO CP] Performance update increase MAX_R0_BLOCK

---
 torch/_inductor/runtime/triton_heuristics.py | 57 +++++++++++++-------
 1 file changed, 38 insertions(+), 19 deletions(-)

diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 2bd7c372f396f..015c8afc8815e 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -2624,6 +2624,42 @@ def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):
         filename=filename,
     )
 
+def _get_reduction_block_limits(
+    size_hints: dict[str, int],
+    inductor_meta: dict[str, Any],
+) -> tuple[int, bool]:
+    """
+    Determine the maximum reduction block size and whether the kernel is register intensive.
+
+    Heuristic to reduce R0_BLOCK if a kernel potentially needs many registers.
+    We consider load and reduction operations since loads move data into registers
+    and reductions need accumulators.
+    """
+    register_intensive = False
+
+    # Get the maximum reduction dimension hint (already a power of 2)
+    # For multi-dimensional reductions (r0_, r1_, etc.), use the largest
+    reduction_hints = [
+        numel for prefix, numel in size_hints.items() if prefix_is_reduction(prefix)
+    ]
+    max_reduction_hint = max(reduction_hints) if reduction_hints else 1
+
+    # Cap based on actual reduction size and analytically derived maximums
+    max_cap = 4096 if torch.version.hip else 2048
+    max_rblock = min(max_reduction_hint, max_cap)
+
+    loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
+        "num_reduction", 0
+    )
+
+    if size_hints["x"] >= 1024 and loads_and_red >= 10:
+        # Reduce block size for kernels that potentially need many registers
+        # Scale down by factor of 2 for better register allocation
+        max_rblock = min(max_reduction_hint, max_cap // 2)
+        register_intensive = True
+
+    return max_rblock, register_intensive
+
 
 def _reduction_configs(
     *, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
@@ -2638,26 +2674,9 @@ def _reduction_configs(
         "max_autotune_pointwise"
     )
 
-    register_intensive = False
-    MAX_R0_BLOCK = 2048
-    loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
-        "num_reduction", 0
+    MAX_R0_BLOCK, register_intensive = _get_reduction_block_limits(
+        size_hints, inductor_meta
     )
-    if size_hints["x"] >= 1024 and loads_and_red >= 10:
-        # A heuristics to reduce R0_BLOCK if a kernel potentially need many registers.
-        # Consider load and reduction since load need move data into registers and
-        # reduction needs an accumulator.
-        #
-        # The magic numbers are a bit arbitrary.
-        #
-        # We cannot rely on dynamically scaling down R0_BLOCK later, since sometimes
-        # triton makes it to use less registers with worse perf. Check:
-        # https://github.com/pytorch/pytorch/issues/126463
-        #
-        # The heuristic is a very simple one since registers can be reused. But
-        # hopefully it can be a good enough indicator.
-        MAX_R0_BLOCK = 1024
-        register_intensive = True
 
     def make_config(
         x,