@justinrosner
Contributor

Motivation

This PR adds support for the new scaled WMMA instructions available on gfx1250.

This implements: https://github.com/ROCm/rocMLIR-internal/issues/2133

Technical Details

Upstream changes needed for this:

Note: Both of these external commits can be dropped once the December upstream merge goes in.

rocMLIR changes:

  • Add extra logic in WmmaInsnGroup/AccelEmitter
  • Updates to AmdArchDb to allow for fp4 wmma types

Test Plan

  • Nightly CI
  • gfx1250 emulation tests

Test Result

  • Nightly CI
  • gfx1250 emulation

Submission Checklist

Contributor

Copilot AI left a comment

Pull request overview

This PR adds support for scaled WMMA (Wave Matrix Multiply-Accumulate) instructions for AMD's gfx1250 architecture, enabling matrix operations with per-block scaling for small float types (FP4, FP6, FP8/BF8).

Key changes:

  • Introduces new scaled_wmma operation and lowering pipeline for gfx1250
  • Adds infrastructure to differentiate between scaled and non-scaled WMMA instruction variants
  • Extends type support to include FP4 and FP6 formats for WMMA operations
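As a rough illustration of what "per-block scaling" means numerically, here is a minimal reference sketch, not the PR's implementation. It assumes each run of `blockSize` consecutive K elements shares one scale per operand; the function name, the use of `float` for both data and scales, and the block size are all hypothetical choices made for readability.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference model of a block-scaled dot product.
// Every `blockSize` consecutive elements of `a` (and of `b`) share one scale.
// Real hardware would use packed small-float data and encoded scales; plain
// floats are used here only to make the arithmetic visible.
float blockScaledDot(const std::vector<float> &a, const std::vector<float> &b,
                     const std::vector<float> &scaleA,
                     const std::vector<float> &scaleB,
                     std::size_t blockSize) {
  float acc = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k) {
    const std::size_t block = k / blockSize; // which scale applies to element k
    acc += (scaleA[block] * a[k]) * (scaleB[block] * b[k]);
  }
  return acc;
}
```

For example, with `a = {1, 2, 3, 4}`, `b = {1, 1, 1, 1}`, `scaleA = {2, 0.5}`, `scaleB = {1, 4}` and `blockSize = 2`, the first block contributes 2 * 1 * (1 + 2) = 6 and the second 0.5 * 4 * (3 + 4) = 14.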

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 14 comments.

| File | Description |
| --- | --- |
| mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp | Implements scaled WMMA emission in WmmaEmitter, adds forScaledOp parameter to AccelEmitter::select |
| mlir/lib/Dialect/Rock/Transforms/ThreadwiseGemmLowering.cpp | Updates scale handling to preserve vector types for WMMA (vs extracting scalars for MFMA) |
| mlir/lib/Dialect/Rock/Transforms/GridwiseGemmToBlockwise.cpp | Passes forScaledOp flag to AccelEmitter selection |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp | Passes forScaledOp flag to AccelEmitter selection |
| mlir/lib/Dialect/Rock/IR/WmmaInsnGroup.cpp | Adds SmallFloat_To_F32_TyId type, instruction selection logic for scaled WMMA, and forScaledOp parameter |
| mlir/lib/Dialect/Rock/IR/RockDialect.cpp | Extends validation to support FP8 and FP4 types for scaled GEMMs |
| mlir/lib/Dialect/Rock/IR/AmdArchDb.cpp | Enables FP4 and FP6 types for WMMA when scaled GEMM is available |
| mlir/include/mlir/Dialect/Rock/IR/WmmaInsnGroup.h | Adds SmallFloat_To_F32_TyId enum value and isScaled field to WmmaInsn |
| mlir/include/mlir/Dialect/Rock/IR/RockOps.td | Updates vector type constraints for FP4 in memory operations |
| mlir/include/mlir/Dialect/Rock/IR/AccelEmitter.h | Adds forScaledOp parameter to select() method |
| external/llvm-project/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | Implements ScaledWMMAOp verification for matrix/scale type combinations |
| external/llvm-project/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | Implements ScaledWMMAOpLowering to convert to ROCDL intrinsics, refactors helper functions |
| external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | Defines ROCDL scaled WMMA intrinsic operations, renames operands to lowercase |
| external/llvm-project/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | Defines ScaledWMMAOp with comprehensive documentation |
| mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir | Simplifies CHECK patterns to use CHECK-DAG for order-independent matching |
| mlir/test/Dialect/Rock/lowering_wmma_gemm.mlir | Adds comprehensive tests for scaled FP4 WMMA operations |
| external/llvm-project/mlir/test/Target/LLVMIR/rocdl.mlir | Adds extensive tests for scaled WMMA intrinsic lowering and attributes |
| external/llvm-project/mlir/test/Dialect/LLVMIR/rocdl.mlir | Adds basic scaled WMMA operation syntax tests |
| external/llvm-project/mlir/test/Dialect/AMDGPU/ops.mlir | Adds tests for scaled WMMA operation parsing and printing |
| external/llvm-project/mlir/test/Conversion/AMDGPUToROCDL/wmma-gfx1250.mlir | Adds conversion tests and error cases for scaled WMMA |

isa<Float8E5M2Type, Float8E4M3FNType>(elementType));

if (hasScaledGemm) {
isValidWmmaType = isValidWmmaType || isa<Float4E2M1FNType>(elementType) ||
Contributor

The .td also mentions fp8, right?

Contributor

Can we use isValidScaledGemmMatrixType?

Contributor Author

The FP8 types are already handled by isValidWmmaType. I don't think we want to use isValidScaledGemmMatrixType here, because it would make certain FP8 types valid when hasScaledGemm=true but hasOcpFp8ConversionInsts=false.
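To make the concern concrete, here is a minimal standalone sketch of the two gating strategies. It is not the code in AmdArchDb.cpp: the enum, the `ArchFeatures` struct, both function names, and the assumption that FP8 is gated on `hasOcpFp8ConversionInsts` while F16/BF16 are always valid are hypothetical stand-ins chosen only to illustrate the argument.

```cpp
// Hypothetical element types; the real code uses MLIR float types.
enum class ElemType { F16, BF16, F8E4M3FN, F8E5M2, F4E2M1FN, F6E2M3FN, F6E3M2FN };

// Hypothetical feature flags mirroring the names used in this thread.
struct ArchFeatures {
  bool hasOcpFp8ConversionInsts;
  bool hasScaledGemm;
};

static bool isFp8(ElemType t) {
  return t == ElemType::F8E4M3FN || t == ElemType::F8E5M2;
}

static bool isFp4OrFp6(ElemType t) {
  return t == ElemType::F4E2M1FN || t == ElemType::F6E2M3FN ||
         t == ElemType::F6E3M2FN;
}

// Narrow gating: scaled GEMM only adds FP4/FP6; FP8 keeps its own gate.
bool isValidWmmaElemType(ElemType t, const ArchFeatures &f) {
  bool valid = t == ElemType::F16 || t == ElemType::BF16; // assumed baseline
  valid = valid || (f.hasOcpFp8ConversionInsts && isFp8(t));
  if (f.hasScaledGemm)
    valid = valid || isFp4OrFp6(t);
  return valid;
}

// Broad gating: a single "any scaled-GEMM matrix type" predicate would also
// admit FP8 whenever hasScaledGemm is set, bypassing the FP8 gate.
bool isValidWmmaElemTypeBroad(ElemType t, const ArchFeatures &f) {
  bool valid = isValidWmmaElemType(t, f);
  if (f.hasScaledGemm)
    valid = valid || isFp8(t) || isFp4OrFp6(t);
  return valid;
}
```

With `hasScaledGemm = true` and `hasOcpFp8ConversionInsts = false`, the narrow version rejects `F8E4M3FN` while the broad version accepts it, which is the behaviour difference described above.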

argScaleB =
vector::ExtractOp::create(b, loc, argScaleB, zeroConstantOp);

// For MFMA, extract scalar from scale vector; for WMMA, keep as vector
Contributor

I'm confused, how is this relevant for the current PR? Taking the opportunity to refactor scaled MFMA a bit?

Contributor Author

Scales are handled differently for MFMA and WMMA. For MFMA they can be either a scalar or a vector, but for WMMA they have to be a vector. This code path chooses the accel instructions, and since we previously only had scales for MFMA, it needed some updating to accommodate both.
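The distinction can be sketched outside of MLIR as follows. This is only an illustrative model of the branch described in the code comment ("for MFMA, extract scalar from scale vector; for WMMA, keep as vector"); the `AccelKind` enum, the type aliases, and `prepareScaleOperand` are hypothetical names, not part of the PR.

```cpp
#include <cstdint>
#include <variant>
#include <vector>

enum class AccelKind { Mfma, Wmma };

using ScaleVector = std::vector<std::uint8_t>;
// A scale operand is either a single scalar or the whole vector.
using ScaleOperand = std::variant<std::uint8_t, ScaleVector>;

// Hypothetical helper: pick the operand form the accel instruction expects.
ScaleOperand prepareScaleOperand(AccelKind kind, const ScaleVector &scales) {
  if (kind == AccelKind::Mfma) {
    // MFMA path in this sketch: take element 0 as a scalar
    // (analogous to the vector::ExtractOp with a zero index).
    return ScaleOperand(scales.at(0));
  }
  // WMMA path: the instruction needs the vector, so pass it through unchanged.
  return ScaleOperand(scales);
}
```

Calling `prepareScaleOperand(AccelKind::Mfma, {7, 9})` yields the scalar alternative holding 7, while the `Wmma` case yields the two-element vector.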

b, loc, vectorType, wmmaInsn.mPerAccel, wmmaInsn.nPerAccel,
wmmaInsn.kDim, argA, argB, vectorC, scaleA, firstScaleLane, scaleB,
firstScaleLane);
vectorD = wmma.getDestD();
Contributor

nit: We can move this outside the if

Contributor Author

If we move vectorD outside the if/else then we have to declare wmma as an Op before the if/else (scaled and non-scaled don't share a common interface), and in doing so we would lose the ability to call getDestD() and would instead have to call getResult(0), which is messier in my opinion.
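The trade-off can be shown with a small standalone sketch. The two structs below are hypothetical stand-ins for the scaled and non-scaled op classes (they share an accessor name but no common base type), and `emitAccel` is an invented function, not the emitter in AccelEmitter.cpp.

```cpp
#include <string>

struct Value {
  std::string name;
};

// Hypothetical stand-ins: two unrelated op types with the same accessor name
// but no shared interface or base class.
struct WmmaOp {
  Value dest;
  Value getDestD() const { return dest; }
};

struct ScaledWmmaOp {
  Value dest;
  Value getDestD() const { return dest; }
};

Value emitAccel(bool isScaled) {
  Value vectorD;
  if (isScaled) {
    ScaledWmmaOp wmma{{"scaled_wmma.destD"}};
    vectorD = wmma.getDestD(); // typed accessor is available inside the branch
  } else {
    WmmaOp wmma{{"wmma.destD"}};
    vectorD = wmma.getDestD();
  }
  // Hoisting the assignment out of the if/else would require one variable
  // able to hold either op type, which loses the typed getDestD() accessor.
  return vectorD;
}
```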

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 5 comments.

Contributor

e2e tests?
