Allow linalg.batch_matmul in the scalable-vector vectorization precondition
(Cond 4 of vectorizeScalableVectorPrecondition) and add masked vectorization
tests for it: one with fixed vector sizes [4, 8, 16, 4] and one where the N
dimension is scalable ([4, 8, [16], 4]).

NOTE(review): the template arguments on the C++ context lines and the
memref<?x?x?xf32> / <add> payloads were reconstructed after an extraction
step stripped all angle-bracket contents -- confirm against the base revision.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index bb3bccdae0e14..4d7e45aa8036f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2640,6 +2640,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
   // Cond 4: Only the following ops are supported in the
   // presence of scalable vectors
   return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
+                 isa<linalg::BatchMatmulOp>(op) ||
                  isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
                  isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
                  isa<linalg::BatchMmt4DOp>(op) ||
diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
index 170bae6141609..1f8762bd3b1ef 100644
--- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops.mlir
@@ -1725,3 +1725,87 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+            outs(%C: memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @batch_matmul(
+// CHECK-SAME: %[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x8x4xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x16xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x16x4xf32> } : vector<4x4x16xi1> -> vector<4x8x16x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x16xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x16x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x16x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x16x4xf32> to vector<4x8x16xf32> } : vector<4x8x16x4xi1> -> vector<4x8x16xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x16xf32>, memref<?x?x?xf32> } : vector<4x8x16xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul vector_sizes [4, 8, 16, 4] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @batch_matmul_scalable(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memref<?x?x?xf32>) {
+  linalg.batch_matmul ins(%A, %B: memref<?x?x?xf32>, memref<?x?x?xf32>)
+            outs(%C: memref<?x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func.func @batch_matmul_scalable
+// CHECK-SAME: (%[[A:.*]]: memref<?x?x?xf32>, %[[B:.*]]: memref<?x?x?xf32>, %[[C:.*]]: memref<?x?x?xf32>) {
+// CHECK: %[[c0:.*]] = arith.constant 0 : index
+// CHECK: %[[BATCH_DIM:.*]] = memref.dim %[[A]], %[[c0]] : memref<?x?x?xf32>
+// CHECK: %[[c1:.*]] = arith.constant 1 : index
+// CHECK: %[[M:.*]] = memref.dim %[[A]], %[[c1]] : memref<?x?x?xf32>
+// CHECK: %[[c2:.*]] = arith.constant 2 : index
+// CHECK: %[[N:.*]] = memref.dim %[[B]], %[[c2]] : memref<?x?x?xf32>
+// CHECK: %[[c2_2:.*]] = arith.constant 2 : index
+// CHECK: %[[K:.*]] = memref.dim %[[A]], %[[c2_2]] : memref<?x?x?xf32>
+// CHECK: %[[c0_4:.*]] = arith.constant 0 : index
+// CHECK: %[[P0:.*]] = ub.poison : f32
+// CHECK: %[[MA:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[K]] : vector<4x8x4xi1>
+// CHECK: %[[VA:.*]] = vector.mask %[[MA]] { vector.transfer_read %[[A]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P0]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x8x4xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P1:.*]] = ub.poison : f32
+// CHECK: %[[MB:.*]] = vector.create_mask %[[BATCH_DIM]], %[[K]], %[[N]] : vector<4x4x[16]xi1>
+// CHECK: %[[VB:.*]] = vector.mask %[[MB]] { vector.transfer_read %[[B]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P1]] {in_bounds = [true, true, true, true], permutation_map = #{{.*}}} : memref<?x?x?xf32>, vector<4x8x[16]x4xf32> } : vector<4x4x[16]xi1> -> vector<4x8x[16]x4xf32>
+// CHECK: %[[P2:.*]] = ub.poison : f32
+// CHECK: %[[MC:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]] : vector<4x8x[16]xi1>
+// CHECK: %[[VC:.*]] = vector.mask %[[MC]] { vector.transfer_read %[[C]][%[[c0_4]], %[[c0_4]], %[[c0_4]]], %[[P2]] {in_bounds = [true, true, true]} : memref<?x?x?xf32>, vector<4x8x[16]xf32> } : vector<4x8x[16]xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[MUL:.*]] = arith.mulf %[[VA]], %[[VB]] : vector<4x8x[16]x4xf32>
+// CHECK: %[[MRED:.*]] = vector.create_mask %[[BATCH_DIM]], %[[M]], %[[N]], %[[K]] : vector<4x8x[16]x4xi1>
+// CHECK: %[[RED:.*]] = vector.mask %[[MRED]] { vector.multi_reduction <add>, %[[MUL]], %[[VC]] [3] : vector<4x8x[16]x4xf32> to vector<4x8x[16]xf32> } : vector<4x8x[16]x4xi1> -> vector<4x8x[16]xf32>
+// CHECK: %[[c0_5:.*]] = arith.constant 0 : index
+// CHECK: vector.mask %[[MC]] { vector.transfer_write %[[RED]], %[[C]][%[[c0_5]], %[[c0_5]], %[[c0_5]]] {in_bounds = [true, true, true]} : vector<4x8x[16]xf32>, memref<?x?x?xf32> } : vector<4x8x[16]xi1>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %matmul vector_sizes [4, 8, [16], 4] : !transform.any_op
+    transform.yield
+  }
+}