From 2a278ec2b95d68b11de6e85a37bd9d155779bcf8 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Fri, 16 May 2025 15:57:31 +0530
Subject: [PATCH 1/3] [MLIR][TORCH] Add shape verifier check for index_put op

This commit adds a check to verify whether the shape of the `values`
operand of the index_put op is broadcast compatible with the indexing
result.

Signed-off-by: Vivek Khandelwal
---
 .../TorchToTMTensor/TorchToTMTensor.cpp       | 57 ++++++++++++++++---
 1 file changed, 50 insertions(+), 7 deletions(-)

diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index 01e1bbf0d26e..a2db29203b26 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -803,14 +803,42 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
   return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
 }
 
+// Check whether the shapes of the tensors are broadcastable or not.
+// Two tensors are “broadcastable” if the following rules hold:
+// 1.) Each tensor has at least one dimension.
+// 2.) When iterating over the dimension sizes, starting at the trailing
+// dimension, the dimension sizes must either be equal, one of them is 1, or
+// one of them does not exist.
+static LogicalResult
+areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
+                                 ArrayRef<int64_t> shapeB) {
+  unsigned rankA = shapeA.size();
+  unsigned rankB = shapeB.size();
+  unsigned minRank = std::min(rankA, rankB);
+
+  for (unsigned i = 0; i < minRank; i++) {
+    int64_t dimA = shapeA[rankA - i - 1];
+    int64_t dimB = shapeB[rankB - i - 1];
+    // Here, we only check the static dimensions for compatibility.
+    if (dimA == Torch::kUnknownSize || dimB == Torch::kUnknownSize)
+      continue;
+    if (!(dimA == dimB || dimA == 1 || dimB == 1))
+      return failure();
+  }
+
+  return success();
+}
+
 // Broadcast the `values` tensor to the slice size created by the list of index
 // tensors.
-static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
-                                        llvm::ArrayRef<Value> indices,
-                                        OpBuilder b) {
+static LogicalResult broadcastValuesToSliceSize(Location loc, Value input,
+                                                Value values,
+                                                llvm::ArrayRef<Value> indices,
+                                                OpBuilder b, Value &result) {
   auto inputType = cast<ValueTensorType>(input.getType());
   ArrayRef<int64_t> inputStaticShape = inputType.getSizes();
   auto valuesType = cast<ValueTensorType>(values.getType());
+  ArrayRef<int64_t> valuesStaticShape = valuesType.getSizes();
 
   // In the case where the input rank is greater than the number of index
   // tensors, the remaining dimensions of the input are indexed in their
@@ -823,12 +851,23 @@ static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
     resultStaticShape.push_back(inputStaticShape[i]);
   }
 
+  // Check if the values tensor is broadcast compatible with the indexing
+  // result shape. Here, we only check the static dimensions; the dynamic
+  // ones will be caught by the downstream lowering.
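+  // For example, a values shape [3, 1] is broadcast compatible with a
+  // result shape [3, 4], while [2, 3] is not, since the trailing dims
+  // 3 and 4 differ and neither of them is 1.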
+  if (failed(areStaticallyBroadcastCompatible(valuesStaticShape,
+                                              resultStaticShape)))
+    return failure();
+
   auto resultType = b.getType<ValueTensorType>(
       resultStaticShape, valuesType.getOptionalDtype());
   Value broadcastShapeList = b.create<PrimListConstructOp>(
       loc, Torch::ListType::get(b.getType<Torch::IntType>()), resultShape);
-  return b.create<AtenBroadcastToOp>(loc, resultType, values,
-                                     broadcastShapeList);
+  result =
+      b.create<AtenBroadcastToOp>(loc, resultType, values, broadcastShapeList);
+  return success();
 }
 
 class ConvertAtenIndexPutHackedTwinOp
@@ -878,8 +914,12 @@ class ConvertAtenIndexPutHackedTwinOp
     if (optionalIndicesCount == 0)
       return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
 
-    values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList,
-                                        rewriter);
+    if (failed(broadcastValuesToSliceSize(loc, input, values,
+                                          optionalIndicesList, rewriter,
+                                          /*result=*/values)))
+      return rewriter.notifyMatchFailure(
+          op, "values tensor cannot be broadcast to indexing result shape.");
+
     // Filter to available indices and get the indicesMap:
     SmallVector<Value> indicesList;
     SmallVector<int64_t> indicesMap;

From aa2f25eefdeae04b3a67dfc50e9b064bc19f4233 Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 19 May 2025 14:17:12 +0530
Subject: [PATCH 2/3] Remove check and add op verifier

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     |  1 +
 .../torch-mlir/Dialect/Torch/Utils/Utils.h    |  4 +
 .../TorchToTMTensor/TorchToTMTensor.cpp       | 57 ++----------
 lib/Dialect/Torch/IR/TorchOps.cpp             | 84 +++++++++++++++++++
 lib/Dialect/Torch/Utils/Utils.cpp             | 29 ++++++
 .../build_tools/torch_ods_gen.py              |  3 +-
 test/Dialect/Torch/invalid.mlir               | 10 +++
 7 files changed, 137 insertions(+), 51 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 4c783fd3b495..a4757463d0a2 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -6351,6 +6351,7 @@ def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
     printDefaultTorchOp(printer, *this, 4, 1);
   }
 }];
+  let hasVerifier = 1;
 }
 
 def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
index a000b7ab2f98..53598e451751 100644
--- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h
+++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h
@@ -158,6 +158,10 @@ LogicalResult getPermutedType(BaseTensorType inType,
                               SmallVector<int64_t> permuteDims,
                               Type &permutedType);
 
+// Check whether the given shapes of two tensors are broadcastable.
+LogicalResult areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
+                                               ArrayRef<int64_t> shapeB);
+
 } // namespace Torch
 } // namespace torch
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
index a2db29203b26..01e1bbf0d26e 100644
--- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
+++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
@@ -803,42 +803,14 @@ static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch,
   return b.create<AtenViewOp>(loc, valuesTy, values, outDimsList);
 }
 
-// Check whether the shapes of the tensors are broadcastable or not.
-// Two tensors are “broadcastable” if the following rules hold:
-// 1.) Each tensor has at least one dimension.
-// 2.) When iterating over the dimension sizes, starting at the trailing
-// dimension, the dimension sizes must either be equal, one of them is 1, or
-// one of them does not exist.
-static LogicalResult
-areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
-                                 ArrayRef<int64_t> shapeB) {
-  unsigned rankA = shapeA.size();
-  unsigned rankB = shapeB.size();
-  unsigned minRank = std::min(rankA, rankB);
-
-  for (unsigned i = 0; i < minRank; i++) {
-    int64_t dimA = shapeA[rankA - i - 1];
-    int64_t dimB = shapeB[rankB - i - 1];
-    // Here, we only check the static dimensions for compatibility.
-    if (dimA == Torch::kUnknownSize || dimB == Torch::kUnknownSize)
-      continue;
-    if (!(dimA == dimB || dimA == 1 || dimB == 1))
-      return failure();
-  }
-
-  return success();
-}
-
 // Broadcast the `values` tensor to the slice size created by the list of index
 // tensors.
-static LogicalResult broadcastValuesToSliceSize(Location loc, Value input,
-                                                Value values,
-                                                llvm::ArrayRef<Value> indices,
-                                                OpBuilder b, Value &result) {
+static Value broadcastValuesToSliceSize(Location loc, Value input, Value values,
+                                        llvm::ArrayRef<Value> indices,
+                                        OpBuilder b) {
   auto inputType = cast<ValueTensorType>(input.getType());
   ArrayRef<int64_t> inputStaticShape = inputType.getSizes();
   auto valuesType = cast<ValueTensorType>(values.getType());
-  ArrayRef<int64_t> valuesStaticShape = valuesType.getSizes();
 
   // In the case where the input rank is greater than the number of index
   // tensors, the remaining dimensions of the input are indexed in their
@@ -851,23 +823,12 @@ static LogicalResult broadcastValuesToSliceSize(Location loc, Value input,
     resultStaticShape.push_back(inputStaticShape[i]);
   }
 
-  // Check if the values tensor is broadcast compatible with the indexing
-  // result shape. Here, we only check the static dimensions; the dynamic
-  // ones will be caught by the downstream lowering.
-  // For example, a values shape [3, 1] is broadcast compatible with a
-  // result shape [3, 4], while [2, 3] is not, since the trailing dims
-  // 3 and 4 differ and neither of them is 1.
-  if (failed(areStaticallyBroadcastCompatible(valuesStaticShape,
-                                              resultStaticShape)))
-    return failure();
-
   auto resultType = b.getType<ValueTensorType>(
       resultStaticShape, valuesType.getOptionalDtype());
   Value broadcastShapeList = b.create<PrimListConstructOp>(
       loc, Torch::ListType::get(b.getType<Torch::IntType>()), resultShape);
-  result =
-      b.create<AtenBroadcastToOp>(loc, resultType, values, broadcastShapeList);
-  return success();
+  return b.create<AtenBroadcastToOp>(loc, resultType, values,
+                                     broadcastShapeList);
 }
 
 class ConvertAtenIndexPutHackedTwinOp
@@ -914,12 +878,8 @@ class ConvertAtenIndexPutHackedTwinOp
     if (optionalIndicesCount == 0)
       return rewriter.notifyMatchFailure(op, "Indices list must not be empty.");
 
-    if (failed(broadcastValuesToSliceSize(loc, input, values,
-                                          optionalIndicesList, rewriter,
-                                          /*result=*/values)))
-      return rewriter.notifyMatchFailure(
-          op, "values tensor cannot be broadcast to indexing result shape.");
-
+    values = broadcastValuesToSliceSize(loc, input, values, optionalIndicesList,
+                                        rewriter);
     // Filter to available indices and get the indicesMap:
     SmallVector<Value> indicesList;
     SmallVector<int64_t> indicesMap;
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index eb2f697c2596..f4a12dc5b709 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6086,6 +6086,90 @@ LogicalResult AtenCountNonzeroDimIntListOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// AtenIndexPutOp
+//===----------------------------------------------------------------------===//
+
+// Determine the common broadcast shape of all the index tensors.
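+// For example, index tensors of shapes [4, 1] and [3] broadcast to [4, 3].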
+SmallVector<int64_t>
+getIndexBroadcastShape(SmallVector<BaseTensorType> indicesTypes) {
+  int64_t indicesBroadcastRank = 0;
+  SmallVector<int64_t> indicesRank;
+  SmallVector<ArrayRef<int64_t>> indicesShape;
+  for (auto indexTy : indicesTypes) {
+    indicesShape.push_back(indexTy.getSizes());
+    int64_t rank = indexTy.getSizes().size();
+    indicesRank.push_back(rank);
+    indicesBroadcastRank = std::max(rank, indicesBroadcastRank);
+  }
+
+  auto maxDim = [](int64_t dim0, int64_t dim1) {
+    if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize)
+      return Torch::kUnknownSize;
+    return std::max(dim0, dim1);
+  };
+
+  SmallVector<int64_t> broadcastShape(indicesBroadcastRank, 0);
+  for (unsigned i = 0; i < indicesTypes.size(); i++) {
+    for (int32_t j = 0; j < indicesRank[i]; ++j) {
+      auto size = indicesShape[i][j];
+      int32_t idx = broadcastShape.size() - indicesRank[i] + j;
+      broadcastShape[idx] = maxDim(size, broadcastShape[idx]);
+    }
+  }
+  return broadcastShape;
+}
+
+LogicalResult AtenIndexPutOp::verify() {
+  if (isa<Torch::NoneType>(getIndices().getType()))
+    return success();
+
+  SmallVector<Value> indices;
+  if (!getListConstructElements(getIndices(), indices))
+    return success();
+
+  SmallVector<BaseTensorType> indicesTypes;
+  for (auto index : indices) {
+    auto indexTy = cast<BaseTensorType>(index.getType());
+    if (!indexTy.hasSizes())
+      return success();
+    indicesTypes.push_back(indexTy);
+  }
+
+  auto inputType = cast<BaseTensorType>(getSelf().getType());
+  if (!inputType.hasSizes())
+    return success();
+  SmallVector<int64_t> inputShape(inputType.getSizes());
+
+  auto valuesType = cast<BaseTensorType>(getValues().getType());
+  if (!valuesType.hasSizes())
+    return success();
+  SmallVector<int64_t> valuesShape(valuesType.getSizes());
+
+  SmallVector<int64_t> indicesBroadcastShape(
+      getIndexBroadcastShape(indicesTypes));
+  // In the case where the input rank is greater than the number of index
+  // tensors, the remaining dimensions of the input are indexed in their
+  // entirety. Thus, we need to append the remaining dimensions to get the
+  // shape of the indexed slice.
+  for (size_t i = indices.size(); i < inputShape.size(); i++) {
+    indicesBroadcastShape.push_back(inputShape[i]);
+  }
+
+  // Check if the values tensor is broadcast compatible with the indexing
+  // result shape. Here, we only check the static dimensions; the dynamic
+  // ones will be caught by the downstream lowering through runtime checks.
+  if (failed(
+          areStaticallyBroadcastCompatible(valuesShape, indicesBroadcastShape)))
+    return emitOpError("values tensor shape [")
+           << valuesShape
+           << "] cannot be broadcasted to indexing result shape ["
+           << indicesBroadcastShape << "]\n";
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OnnxVariantRotaryEmbeddingOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp
index 388e31353571..11e3884c1e4c 100644
--- a/lib/Dialect/Torch/Utils/Utils.cpp
+++ b/lib/Dialect/Torch/Utils/Utils.cpp
@@ -709,3 +709,32 @@ Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) {
     return rewriter.getI64Type();
   return inputType;
 }
+
+// Check whether the shapes of the tensors are broadcastable or not.
+// Two tensors are “broadcastable” if the following rules hold:
+// 1.) Each tensor has at least one dimension.
+// 2.) When iterating over the dimension sizes, starting at the trailing
+// dimension, the dimension sizes must either be equal, one of them is 1, or
+// one of them does not exist.
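+// For example, shapes [5, 1, 4] and [3, 4] are broadcastable (giving
+// [5, 3, 4]), while [5, 2, 4] and [3, 4] are not, since 2 and 3 are
+// unequal and neither of them is 1.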
+LogicalResult
+Torch::areStaticallyBroadcastCompatible(ArrayRef<int64_t> shapeA,
+                                        ArrayRef<int64_t> shapeB) {
+  unsigned rankA = shapeA.size();
+  unsigned rankB = shapeB.size();
+  unsigned minRank = std::min(rankA, rankB);
+
+  for (unsigned i = 0; i < minRank; i++) {
+    int64_t dimA = shapeA[rankA - i - 1];
+    int64_t dimB = shapeB[rankB - i - 1];
+    // Here, we only check the static dimensions for compatibility.
+    if (dimA == Torch::kUnknownSize || dimB == Torch::kUnknownSize)
+      continue;
+    if (!(dimA == dimB || dimA == 1 || dimB == 1))
+      return failure();
+  }
+
+  return success();
+}
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 0f6694132aec..234cf99a9807 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -557,7 +557,8 @@ def emit_with_mutating_variants(key, **kwargs):
     emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)")
     emit_with_mutating_variants(
-        "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)"
+        "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)",
+        has_verifier=True,
     )
     emit_with_mutating_variants(
         "aten::index_put.hacked_twin : (Tensor, Tensor[], Tensor, bool) -> (Tensor)"
diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir
index c863e93fa5fa..c383c75f256b 100644
--- a/test/Dialect/Torch/invalid.mlir
+++ b/test/Dialect/Torch/invalid.mlir
@@ -403,3 +403,13 @@ func.func @torch.symbolic_int$no_shape_symbols(%arg0: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?],f32> {
   torch.bind_symbolic_shape %arg0, [%int0], affine_map<()[s0] -> (s0)> : !torch.vtensor<[?],f32>
   return %arg0 : !torch.vtensor<[?],f32>
 }
+
+// -----
+
+func.func @index_put_values_shape_broadcast_incompatible(%arg0: !torch.vtensor<[?,32,16,192],f16>, %arg1: !torch.vtensor<[?],si64>, %arg2: !torch.vtensor<[?,32,128,192],f16>) -> !torch.vtensor<[?,32,16,192],f16> attributes {torch.onnx_meta.opset_version = 10 : si64} {
+  %0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[?],si64>) -> !torch.list<optional<vtensor>>
+  %false = torch.constant.bool false
+  // expected-error @+1 {{'torch.aten.index_put' op values tensor shape [-1, 32, 128, 192] cannot be broadcasted to indexing result shape [-1, 32, 16, 192]}}
+  %1 = torch.aten.index_put %arg0, %0, %arg2, %false : !torch.vtensor<[?,32,16,192],f16>, !torch.list<optional<vtensor>>, !torch.vtensor<[?,32,128,192],f16>, !torch.bool -> !torch.vtensor<[?,32,16,192],f16>
+  return %1 : !torch.vtensor<[?,32,16,192],f16>
+}

From 6690099f58fe9b6f3c7cc25411e85df8d5fb3a4c Mon Sep 17 00:00:00 2001
From: Vivek Khandelwal
Date: Mon, 19 May 2025 18:04:42 +0530
Subject: [PATCH 3/3] Handle none values in indices list

---
 lib/Dialect/Torch/IR/TorchOps.cpp | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index f4a12dc5b709..64ca695b8178 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -6131,10 +6131,14 @@ LogicalResult AtenIndexPutOp::verify() {
 
   SmallVector<BaseTensorType> indicesTypes;
   for (auto index : indices) {
-    auto indexTy = cast<BaseTensorType>(index.getType());
-    if (!indexTy.hasSizes())
-      return success();
-    indicesTypes.push_back(indexTy);
+    // Skip the none values in the indices list.
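+    // A none index means that dimension is indexed in full (as in x[:, idx]),
+    // so there is no tensor type to inspect for it.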
+    if (auto indexTy = dyn_cast<BaseTensorType>(index.getType())) {
+      if (!indexTy.hasSizes())
+        return success();
+      indicesTypes.push_back(indexTy);
+    }
   }
 
   auto inputType = cast<BaseTensorType>(getSelf().getType());
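-- 
Note: a minimal eager-PyTorch sketch of the broadcast rule the verifier
enforces, with shapes mirroring the invalid.mlir test above. The snippet is
illustrative only; the tensor names are not part of the patch:

    import torch

    x = torch.zeros(2, 32, 16, 192, dtype=torch.float16)
    idx = torch.tensor([0, 1])  # indexes dim 0; indexing result: [2, 32, 16, 192]

    # Broadcast-compatible values: the trailing dim 1 broadcasts to 192.
    x.index_put_((idx,), torch.ones(2, 32, 16, 1, dtype=torch.float16))

    # Incompatible values: 128 vs 16 with neither equal to 1; PyTorch raises a
    # RuntimeError here, the same mismatch the verifier now rejects statically.
    # x.index_put_((idx,), torch.ones(2, 32, 128, 192, dtype=torch.float16))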