
Conversation

@momchil-velikov
Collaborator

Also add a missing test case for fixed-size `linalg.batch_matmul` vectorization.

@llvmbot
Member

llvmbot commented Dec 15, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Also add a missing test case for fixed-size `linalg.batch_matmul` vectorization.


Full diff: https://github.com/llvm/llvm-project/pull/172333.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+84)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..4d7e45aa8036f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+                 isa<linalg::BatchMatmulOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                  isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..1f8762bd3b1ef 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%C: memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME:  %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK:       %[[c0:.*]] = arith.constant 0 : index
+// CHECK:       %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK:       %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK:       %[[c2:.*]] = arith.constant 2 : index
+// CHECK:       %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK:       %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK:       %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK:       %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK:       %[[P0:.*]] = ub.poison : f32
+// CHECK:       %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK:       %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK:       %[[P1:.*]] = ub.poison : f32
+// CHECK:       %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK:       %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK:       %[[P2:.*]] = ub.poison : f32
+// CHECK:       %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK:       %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK:       %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK:       %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK:       %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK:       %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK:       vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+                      outs(%C: memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME:  (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK:       %[[c0:.*]] = arith.constant 0 : index
+// CHECK:       %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK:       %[[c1:.*]] = arith.constant 1 : index
+// CHECK:       %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK:       %[[c2:.*]] = arith.constant 2 : index
+// CHECK:       %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK:       %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK:       %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK:       %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK:       %[[P0:.*]] = ub.poison : f32
+// CHECK:       %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK:       %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK:       %[[P1:.*]] = ub.poison : f32
+// CHECK:       %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK:       %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK:       %[[P2:.*]] = ub.poison : f32
+// CHECK:       %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK:       %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK:       %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK:       %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK:       %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK:       %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK:       vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+    transform.yield
+  }
+}
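For readers less familiar with the masked vectorization pattern the new tests check, here is a minimal sketch of its semantics in plain Python (not the MLIR implementation). It models the CHECK lines above: the dynamic sizes `batch`, `m`, `n`, `k` may be smaller than the chosen vector sizes `(4, 8, 16, 4)`, the per-operand masks (`vector.create_mask` / `vector.mask`) neutralize out-of-bounds lanes, the broadcast reads are multiplied elementwise (`arith.mulf`), and an `<add>` multi-reduction folds the trailing `k` dimension (dim 3) into the accumulator. Out-of-mask lanes read as `0.0` here; the real lowering reads `ub.poison`, which the mask makes irrelevant.

```python
# Hedged sketch: scalar model of masked linalg.batch_matmul vectorization.
# Vector sizes match the test: [4, 8, 16, 4] for [batch, M, N, K].
VB, VM, VN, VK = 4, 8, 16, 4

def masked_batch_matmul(A, B, C, batch, m, n, k):
    """Return a VB x VM x VN accumulator; lanes past the dynamic
    bounds stay zero, mirroring the masked transfer_read/multi_reduction."""
    def at(T, idx, bounds):
        # Masked read: in-bounds lanes come from memory, the rest are 0.0.
        b, i, j = idx
        return T[b][i][j] if all(x < lim for x, lim in zip(idx, bounds)) else 0.0

    # Masked read of the accumulator C (mask: batch x m x n).
    acc = [[[C[b][i][j] if (b < batch and i < m and j < n) else 0.0
             for j in range(VN)] for i in range(VM)] for b in range(VB)]

    # Elementwise multiply of the broadcast A/B reads, then a masked
    # <add> multi_reduction over the trailing K dimension (dim 3).
    for b in range(VB):
        for i in range(VM):
            for j in range(VN):
                for kk in range(VK):
                    if b < batch and i < m and j < n and kk < k:
                        acc[b][i][j] += (at(A, (b, i, kk), (batch, m, k)) *
                                         at(B, (b, kk, j), (batch, k, n)))
    return acc
```

With `batch=1, m=n=k=2` this reduces to an ordinary 2x2 matmul in the top-left corner of the accumulator, and every lane outside the mask remains zero, which is exactly why the scalable `[16]` dimension in the second test is safe: the mask, not the static vector shape, bounds the computation.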

@rengolin
Member

Looks trivial to me. @arun-thmn any feedback?

What about batch_reduce_matmul? Would that be as trivial, too?

@arun-thmn
Contributor

> Looks trivial to me. @arun-thmn any feedback?
>
> What about batch_reduce_matmul? Would that be as trivial, too?

+1 for `batch_reduce_matmul`

@banach-space
Contributor

> Looks trivial to me. @arun-thmn any feedback?
>
> What about batch_reduce_matmul? Would that be as trivial, too?
>
> +1 for `batch_reduce_matmul`

Is that a blocker for you? I am not against it, but as a rule of thumb, we only enable Ops that we actually run. This way, we have 100% confidence that everything works (i.e. lowering all the way to LLVM).

Perhaps that's too conservative - ultimately, this is just unblocking the Linalg -> Vector lowering - but this way we avoid "scalable" vectors being enabled quite high up (in Linalg) and then not working further down the lowering pipeline.

@banach-space (Contributor) left a comment

Thanks, LGTM % formatting

Please wait for +1 from either @rengolin or @arun-thmn before landing.


// CHECK-LABEL: func.func @batch_matmul(
// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
Let's use CAPS for all LIT variables

Suggested change:
- // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index

6 participants