Conversation

@justinrosner justinrosner commented Jan 20, 2026

Motivation

This PR fixes a crash that MIGraphX was seeing when compiling an attention kernel with fusion: https://amd-hub.atlassian.net/browse/AIROCMLIR-438

Technical Details

When lowering gridwise_attention_accel ops with preSoftmax fusion, the gemm0 output buffer element type was unconditionally set to elemTypeV (the element type of the values input). This caused a type mismatch whenever the preSoftmax body's linalg.generic expected a different element type for its gemm0-based input (e.g., when the linalg.generic truncated or extended that input).
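
To make the mismatch concrete, here is a minimal sketch of what a preSoftmax fusion body with a type extension can look like (shapes, value names, and the surrounding context are invented for illustration, not taken from the actual kernel). The generic consumes the gemm0 result as f16 and extends it to f32, so the gemm0 output buffer must be typed f16 here even when elemTypeV is f32:

```mlir
// Hypothetical preSoftmax body: the linalg.generic reads the gemm0 result
// as f16 and extends it to f32, so the gemm0 buffer cannot blindly use
// elemTypeV as its element type.
%scaled = linalg.generic
    {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%gemm0_out : tensor<32x32xf16>)
    outs(%init : tensor<32x32xf32>) {
  ^bb0(%in: f16, %out: f32):
    %ext = arith.extf %in : f16 to f32
    linalg.yield %ext : f32
} -> tensor<32x32xf32>
```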

Test Plan

  • PR CI
  • The original kernel from MIGraphX now passes (and has been turned into a LIT test)

Test Result

  • PR CI

Submission Checklist

@justinrosner justinrosner requested a review from causten as a code owner January 20, 2026 13:44
Copilot AI review requested due to automatic review settings January 20, 2026 13:44
Copilot AI left a comment


Pull request overview

This PR fixes a crash in MIGraphX when compiling attention kernels with preSoftmax fusion. The issue occurred when lowering gridwise_attention_accel operations where the gemm0 output buffer element type was incorrectly set to the values input element type (elemTypeV), causing a type mismatch when the preSoftmax body's linalg.generic operation expected a different element type (e.g., when truncating or extending).

Changes:

  • Modified element type determination logic to walk the preSoftmax body and extract the correct type from the first linalg.generic operation's gemm0-based input
  • Added a comprehensive LIT test that reproduces the original MIGraphX failure scenario with type conversions in the preSoftmax fusion body

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

Changed files:

  • mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp — Added logic to walk the preSoftmax body and determine gemmOutElemType from the first generic's gemm0-based input, and fusionOutElemType from the last generic's output, fixing the element type mismatch
  • mlir/test/Dialect/Rock/gridwise-gemm-linalg-failure.mlir — New test file verifying correct handling of attention operations with preSoftmax fusion that performs an f16-to-f32 extension, ensuring the lowering produces correct buffer types

