Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 38 additions & 19 deletions torch/_inductor/runtime/triton_heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2624,6 +2624,42 @@ def addConfig__(xblock:int, yblock:int, num_warps:int, num_stages:int):
filename=filename,
)

def _get_reduction_block_limits(
    size_hints: dict[str, int],
    inductor_meta: dict[str, Any],
) -> tuple[int, bool]:
    """
    Compute the upper bound for the reduction block size (R0_BLOCK).

    The bound is the smaller of the largest reduction size hint and a fixed
    cap (4096 when ``torch.version.hip`` is set, 2048 otherwise).  When the
    kernel looks register hungry, the fixed cap is halved.  "Register hungry"
    is estimated from the number of loads plus reductions, since loads move
    data into registers and reductions need accumulators.

    Args:
        size_hints: Size hint per dimension prefix.  Must contain ``"x"``;
            prefixes for which ``prefix_is_reduction`` is true are treated as
            reduction dimensions.
        inductor_meta: Kernel metadata.  ``"num_load"`` and
            ``"num_reduction"`` are read, each defaulting to 0 when absent.

    Returns:
        A ``(max_rblock, register_intensive)`` tuple: the block-size limit and
        whether the register-pressure heuristic fired.
    """
    # With several reduction dims (r0_, r1_, ...) the largest hint wins;
    # with none, fall back to 1.
    largest_reduction_hint = max(
        (hint for name, hint in size_hints.items() if prefix_is_reduction(name)),
        default=1,
    )

    # Fixed ceiling, independent of the actual reduction size.
    block_cap = 4096 if torch.version.hip else 2048

    num_loads = inductor_meta.get("num_load", 0)
    num_reductions = inductor_meta.get("num_reduction", 0)
    register_pressure = num_loads + num_reductions

    register_intensive = bool(size_hints["x"] >= 1024 and register_pressure >= 10)
    if not register_intensive:
        return min(largest_reduction_hint, block_cap), False

    # Many loads/reductions on a large x dimension: halve the ceiling so the
    # kernel has a better chance of fitting in registers.
    return min(largest_reduction_hint, block_cap // 2), True


def _reduction_configs(
*, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
Expand All @@ -2638,26 +2674,9 @@ def _reduction_configs(
"max_autotune_pointwise"
)

register_intensive = False
MAX_R0_BLOCK = 2048
loads_and_red = inductor_meta.get("num_load", 0) + inductor_meta.get(
"num_reduction", 0
MAX_R0_BLOCK, register_intensive = _get_reduction_block_limits(
size_hints, inductor_meta
)
if size_hints["x"] >= 1024 and loads_and_red >= 10:
# A heuristic to reduce R0_BLOCK if a kernel potentially needs many registers.
# Consider loads and reductions, since a load needs to move data into registers
# and a reduction needs an accumulator.
#
# The magic numbers are a bit arbitrary.
#
# We cannot rely on dynamically scaling down R0_BLOCK later, since Triton
# sometimes makes the kernel use fewer registers but with worse perf. Check:
# https://github.com/pytorch/pytorch/issues/126463
#
# The heuristic is a very simple one since registers can be reused. But
# hopefully it can be a good enough indicator.
MAX_R0_BLOCK = 1024
register_intensive = True

def make_config(
x,
Expand Down