[MLIR] Enable scalable vectorization for linalg.batch_matmul #172333
base: main
Conversation
Also add a missing testcase for fixed size `linalg.batch_matmul` vectorization.
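For quick reference, here is the new capability in isolation: a minimal sketch distilled from the scalable testcase added by this patch (the op, memref types, transform sequence, and vector sizes are taken from that test; only the FileCheck annotations are stripped and the SSA name %bmm is an arbitrary choice). The [16] entry marks the N dimension as scalable, so the vectorizer emits masked reads/writes over vector<4x8x[16]xf32> and related shapes.

func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
  linalg.batch_matmul ins(%A, %B : memref<?x?x?xf32>, memref<?x?x?xf32>)
                      outs(%C : memref<?x?x?xf32>)
  return
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %bmm = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    // vector_sizes follow the (batch, M, N, K) iteration order of linalg.batch_matmul;
    // [16] requests a scalable (vscale-multiplied) size for the N dimension.
    transform.structured.vectorize %bmm vector_sizes [4, 8, [16], 4] : !transform.any_op
    transform.yield
  }
}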
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes: Also add a missing testcase for fixed size `linalg.batch_matmul` vectorization.

Full diff: https://github.com/llvm/llvm-project/pull/172333.diff

2 Files Affected:
- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+1)
- mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir (+84)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..4d7e45aa8036f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+ isa<linalg::BatchMatmulOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..1f8762bd3b1ef 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+ linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+ outs(%C: memref<?x?x?xf32>)
+ return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+ transform.yield
+ }
+}
Looks trivial to me. @arun-thmn any feedback? What about
+1 for
Is that a blocker for you? I am not against it, but as a rule of thumb, we only enable Ops that we actually run. This way, we have 100% confidence that everything works (i.e. lowering all the way to LLVM). Perhaps that's too conservative - ultimately, this is just unblocking
banach-space left a comment
Thanks, LGTM % formatting
Please wait for +1 from either @rengolin or @arun-thmn before landing.
// CHECK-LABEL: func.func @batch_matmul(
// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
// CHECK: %[[c0:.*]] = arith.constant 0 : index
Let's use CAPS for all LIT variables
Suggested change:
- // CHECK: %[[c0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index