diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index a7f4d9f5763ff..2c1ca1cddba48 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1048,6 +1048,10 @@ def __init__(self) -> None: ROCmGemmConfig(256, 128, 32, self.default_num_stages, 8, group_m=16), ROCmGemmConfig(256, 128, 64, self.default_num_stages, 8, group_m=4), ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), + ROCmGemmConfig(16, 16, 256, 2, 4, group_m=1, waves_per_eu=2), + ROCmGemmConfig(16, 16, 128, 3, 4, group_m=1, waves_per_eu=4), + ROCmGemmConfig(16, 16, 256, 3, 4, group_m=1, waves_per_eu=4), + ROCmGemmConfig(64, 256, 64, 16, 4, group_m=1, waves_per_eu=2), ] # Exhaustive search for mm configs