Add rocMLIR support for gfx1250 scaled wmma instructions #2193
base: develop
Pull request overview
This PR adds support for scaled WMMA (Wave Matrix Multiply-Accumulate) instructions for AMD's gfx1250 architecture, enabling matrix operations with per-block scaling for small float types (FP4, FP6, FP8/BF8).
Key changes:
- Introduces new `scaled_wmma` operation and lowering pipeline for gfx1250
- Adds infrastructure to differentiate between scaled and non-scaled WMMA instruction variants
- Extends type support to include FP4 and FP6 formats for WMMA operations
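As a rough illustration of what "per-block scaling" means numerically, the sketch below computes a block-scaled dot product on the host. It is a reference model only, not rocMLIR code: the function name `blockScaledDot`, the use of `float` for both elements and scales, and the `blockSize` parameter are all assumptions made for the example. The real instructions operate on packed FP4/FP6/FP8 operands with hardware-defined block sizes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model of a block-scaled dot product.
// Each group of `blockSize` consecutive K elements shares one scale per
// operand; the partial sum of a block is multiplied by both scales.
float blockScaledDot(const std::vector<float> &a, const std::vector<float> &b,
                     const std::vector<float> &scaleA,
                     const std::vector<float> &scaleB, std::size_t blockSize) {
  assert(blockSize > 0 && a.size() == b.size());
  assert(a.size() % blockSize == 0);
  const std::size_t numBlocks = a.size() / blockSize;
  assert(scaleA.size() == numBlocks && scaleB.size() == numBlocks);

  float acc = 0.0f;
  for (std::size_t blk = 0; blk < numBlocks; ++blk) {
    // Unscaled partial sum over one block of the K dimension.
    float partial = 0.0f;
    for (std::size_t i = 0; i < blockSize; ++i) {
      const std::size_t k = blk * blockSize + i;
      partial += a[k] * b[k];
    }
    // Apply the per-block scales of both operands.
    acc += scaleA[blk] * scaleB[blk] * partial;
  }
  return acc;
}
```

A scaled WMMA instruction does this for every output element of the tile at once; the sketch only shows where the scales enter the accumulation.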
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | Implements scaled WMMA emission in WmmaEmitter, adds forScaledOp parameter to AccelEmitter::select |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Updates scale handling to preserve vector types for WMMA (vs extracting scalars for MFMA) |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Passes forScaledOp flag to AccelEmitter selection |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp | Passes forScaledOp flag to AccelEmitter selection |
| mlir/lib/Dialect/Rock/IR/WmmaInsnGroup.cpp | Adds SmallFloat_To_F32_TyId type, instruction selection logic for scaled WMMA, and forScaledOp parameter |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Extends validation to support FP8 and FP4 types for scaled GEMMs |
| mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp | Enables FP4 and FP6 types for WMMA when scaled GEMM is available |
| mlir/include/mlir/Dialect/Rock/IR/WmmaInsnGroup.h | Adds SmallFloat_To_F32_TyId enum value and isScaled field to WmmaInsn |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Updates vector type constraints for FP4 in memory operations |
| mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h | Adds forScaledOp parameter to select() method |
| external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Implements ScaledWMMAOp verification for matrix/scale type combinations |
| external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Implements ScaledWMMAOpLowering to convert to ROCDL intrinsics, refactors helper functions |
| external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | Defines ROCDL scaled WMMA intrinsic operations, renames operands to lowercase |
| external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Defines ScaledWMMAOp with comprehensive documentation |
| mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir | Simplifies CHECK patterns to use CHECK-DAG for order-independent matching |
| mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir | Adds comprehensive tests for scaled FP4 WMMA operations |
| external/llvm-project/mlir/test/Target/LLVMIR/rocdl.mlir | Adds extensive tests for scaled WMMA intrinsic lowering and attributes |
| external/llvm-project/mlir/test/Dialect/LLVMIR/rocdl.mlir | Adds basic scaled WMMA operation syntax tests |
| external/llvm-project/mlir/test/Dialect/AMDGPU/ops.mlir | Adds tests for scaled WMMA operation parsing and printing |
| external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | Adds conversion tests and error cases for scaled WMMA |
    isa<Float8E5M2Type, Float8E4M3FNType>(elementType));

    if (hasScaledGemm) {
      isValidWmmaType = isValidWmmaType || isa<Float4E2M1FNType>(elementType) ||
The .td also mentions fp8 right?
Can we use isValidScaledGemmMatrixType?
The FP8 types are already handled by isValidWmmaType. I don't think we want to use isValidScaledGemmMatrixType here, because it would make certain FP8 types valid when hasScaledGemm=true but hasOcpFp8ConversionInsts=false.
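To make the gating argument concrete, here is a minimal standalone sketch of the predicate as I read it from the diff and the reply above. It does not use the MLIR type classes: `ElemKind`, `ArchFeatures`, and `isValidWmmaElemKind` are hypothetical stand-ins, and the set of always-valid base types (F16, BF16) is an assumption. The point is only that FP8 stays tied to hasOcpFp8ConversionInsts while FP4/FP6 are tied to hasScaledGemm.

```cpp
#include <cassert>

// Hypothetical stand-ins for the MLIR element types discussed in the review.
enum class ElemKind { F16, BF16, FP8_E4M3, BF8_E5M2, FP6, FP4_E2M1 };

// Hypothetical subset of the architecture feature flags.
struct ArchFeatures {
  bool hasOcpFp8ConversionInsts;
  bool hasScaledGemm;
};

constexpr bool isValidWmmaElemKind(ElemKind kind, ArchFeatures features) {
  // Assumed baseline: half-precision types are always accepted.
  bool valid = kind == ElemKind::F16 || kind == ElemKind::BF16;

  // FP8/BF8 are gated on the OCP FP8 conversion instructions only.
  if (features.hasOcpFp8ConversionInsts)
    valid = valid || kind == ElemKind::FP8_E4M3 || kind == ElemKind::BF8_E5M2;

  // Scaled GEMM support adds FP4 and FP6, but deliberately not FP8.
  if (features.hasScaledGemm)
    valid = valid || kind == ElemKind::FP4_E2M1 || kind == ElemKind::FP6;

  return valid;
}

// The case raised in the review: scaled GEMM without OCP FP8 conversions
// must not accept FP8, while FP4 becomes valid.
static_assert(!isValidWmmaElemKind(ElemKind::FP8_E4M3, {false, true}),
              "FP8 must stay gated on hasOcpFp8ConversionInsts");
static_assert(isValidWmmaElemKind(ElemKind::FP4_E2M1, {false, true}),
              "FP4 is enabled by hasScaledGemm");
```

Reusing a single "scaled GEMM matrix type" helper in the second branch would collapse the two gates into one, which is the behaviour the reply wants to avoid.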
    argScaleB =
        vector::ExtractOp::create(b, loc, argScaleB, zeroConstantOp);

    // For MFMA, extract scalar from scale vector; for WMMA, keep as vector
I'm confused, how is this relevant for the current PR? Taking the opportunity to refactor scaled MFMA a bit?
Scales are handled differently for MFMA and WMMA. For MFMA they can be either a scalar or a vector, but for WMMA they have to be a vector. This code path chooses the accel instructions, and since we previously only had scales for MFMA, it needed a bit of updating to accommodate both.
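The decision described here can be sketched outside of MLIR as follows. This is an illustration under stated assumptions, not the code in ThreadwiseGemmLowering.cpp: `AccelKind`, `ScaleOperand`, and `prepareScale` are hypothetical names, and `float` stands in for the real scale element type. The MFMA branch mirrors the `vector::ExtractOp` at index zero shown in the diff.

```cpp
#include <cassert>
#include <variant>
#include <vector>

// Hypothetical model of the two accelerator families.
enum class AccelKind { Mfma, Wmma };

using ScaleVector = std::vector<float>;
// A scale operand is either a single scalar or a whole vector.
using ScaleOperand = std::variant<float, ScaleVector>;

// For MFMA, pass element 0 of the scale vector as a scalar;
// for WMMA, forward the vector unchanged.
ScaleOperand prepareScale(AccelKind kind, const ScaleVector &scales) {
  assert(!scales.empty() && "expected at least one scale value");
  if (kind == AccelKind::Mfma)
    return scales.front();
  return scales;
}
```

Keeping the branch in one helper like this is only one way to structure it; the PR makes the same choice inline at the point where the accel instruction is selected.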
    b, loc, vectorType, wmmaInsn.mPerAccel, wmmaInsn.nPerAccel,
    wmmaInsn.kDim, argA, argB, vectorC, scaleA, firstScaleLane, scaleB,
    firstScaleLane);
    vectorD = wmma.getDestD();
nit: We can move this outside the if
If we move vectorD outside the if/else then we have to declare wmma as an Op before the if/else (scaled and non-scaled don't share a common interface), and in doing so we would lose the ability to call getDestD() and would instead have to call getResult(0), which is messier in my opinion.
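The trade-off can be shown with two unrelated op types in plain C++. Everything here is a hypothetical stand-in (`Value`, `WmmaOp`, `ScaledWmmaOp`, `emitAccel`); the real ops are MLIR operations created through a builder. Because the two types share only a method name and no common interface, each branch has to call `getDestD()` on its own concretely typed variable.

```cpp
#include <cassert>

// Hypothetical stand-in for an SSA value.
struct Value {
  int id;
};

// Two op types with no common interface, only a same-named accessor.
struct WmmaOp {
  Value dest;
  Value getDestD() const { return dest; }
};

struct ScaledWmmaOp {
  Value dest;
  Value getDestD() const { return dest; }
};

Value emitAccel(bool useScaled) {
  Value vectorD{};
  if (useScaled) {
    ScaledWmmaOp wmma{Value{2}};
    vectorD = wmma.getDestD(); // typed accessor, per branch
  } else {
    WmmaOp wmma{Value{1}};
    vectorD = wmma.getDestD(); // same call, different static type
  }
  // Hoisting the assignment out of the if/else would need a variable of a
  // common, type-erased op type, which has no getDestD() accessor.
  return vectorD;
}
```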
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.
e2e tests?
Motivation
This PR adds support for the new scaled WMMA instructions available on gfx1250.
This implements: https://github.com/ROCm/rocMLIR-internal/issues/2133
Technical Details
Upstream changes needed for this:
Note: Both of these external commits can be dropped when the December upstream merge goes in
rocMLIR changes:
Test Plan
Test Result
Submission Checklist