diff --git a/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake b/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake
index 883f8e94a..e8c05f3a2 100644
--- a/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake
+++ b/mlir-tensorrt/build_tools/cmake/MTRTDependencies.cmake
@@ -57,8 +57,8 @@ macro(configure_tensorrt_python_plugin_header)
   find_file(
     trt_python_plugin_header
     NAMES NvInferPythonPlugin.h plugin.h
-    HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
-    PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl
+    HINTS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
+    PATHS ${ARG_INSTALL_DIR} ${ARG_INSTALL_DIR}/python/include/impl ${ARG_INSTALL_DIR}/include/impl
     REQUIRED
     NO_CMAKE_PATH
     NO_DEFAULT_PATH
     NO_CACHE
diff --git a/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake b/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake
index 8e0ec4942..88fa2026d 100644
--- a/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake
+++ b/mlir-tensorrt/build_tools/cmake/TensorRTDownloadURL.cmake
@@ -138,6 +138,10 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
     set(ARG_VERSION "10.13.0.35")
   endif()
 
+  if(ARG_VERSION VERSION_EQUAL "10.14")
+    set(ARG_VERSION "10.14.1.48")
+  endif()
+
   set(downloadable_versions
     "8.6.1.6" "9.0.1.4" "9.1.0.4" "9.2.0.5"
@@ -156,6 +160,7 @@ function(mtrt_get_tensorrt_download_url ARG_VERSION OS_NAME TARGET_ARCH ARG_OUT_
     "10.9.0.34"
     "10.12.0.36"
     "10.13.0.35"
+    "10.14.1.48"
   )
 
   if(NOT ARG_VERSION IN_LIST downloadable_versions)
diff --git a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py
index d57237885..b55f5d334 100644
--- a/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py
+++ b/mlir-tensorrt/compiler/test/python/mlir_tensorrt_compiler/dialects/test_tensorrt.py
@@ -49,6 +49,8 @@ def test_attributes():
         tensorrt.TripLimitAttr.get("kWHILE"),
         tensorrt.FillOperationAttr.get("kRANDOM_UNIFORM"),
         tensorrt.ScatterModeAttr.get("kELEMENT"),
+        tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX"),
+        tensorrt.DataTypeAttr.get("kFLOAT"),
     ]:
         print(attr)
 
@@ -74,3 +76,5 @@
 # CHECK-NEXT: #tensorrt.trip_limit
 # CHECK-NEXT: #tensorrt.fill_operation
 # CHECK-NEXT: #tensorrt.scatter_mode
+# CHECK-NEXT: #tensorrt.attention_normalization_op
+# CHECK-NEXT: #tensorrt.data_type
diff --git a/mlir-tensorrt/integrations/python/bindings/Compiler/DialectTensorRT.cpp b/mlir-tensorrt/integrations/python/bindings/Compiler/DialectTensorRT.cpp
index 0e134a405..2dc6d3167 100644
--- a/mlir-tensorrt/integrations/python/bindings/Compiler/DialectTensorRT.cpp
+++ b/mlir-tensorrt/integrations/python/bindings/Compiler/DialectTensorRT.cpp
@@ -77,4 +77,6 @@ PYBIND11_MODULE(_tensorrt, m) {
   ADD_PYTHON_ATTRIBUTE_ADAPTOR(TripLimit)
   ADD_PYTHON_ATTRIBUTE_ADAPTOR(FillOperation)
   ADD_PYTHON_ATTRIBUTE_ADAPTOR(ScatterMode)
+  ADD_PYTHON_ATTRIBUTE_ADAPTOR(AttentionNormalizationOp)
+  ADD_PYTHON_ATTRIBUTE_ADAPTOR(DataType)
 }
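For reference, a minimal sketch of how the two new attribute adaptors registered above are expected to be exercised from Python, mirroring the `test_tensorrt.py` additions. The import paths and context/dialect setup are assumptions based on the existing test layout; they are not part of this patch:

```python
# Sketch only: build the new enum attributes the same way test_tensorrt.py does.
# Assumes the mlir_tensorrt.compiler package layout used by the existing test and
# that the tensorrt dialect is available in the default context setup.
from mlir_tensorrt.compiler import ir
from mlir_tensorrt.compiler.dialects import tensorrt

with ir.Context() as ctx:
    norm = tensorrt.AttentionNormalizationOpAttr.get("kSOFTMAX")
    dtype = tensorrt.DataTypeAttr.get("kFLOAT")
    # Printing should yield attributes starting with
    # #tensorrt.attention_normalization_op and #tensorrt.data_type,
    # matching the new CHECK-NEXT lines above.
    print(norm)
    print(dtype)
```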
diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h
index ed5e9d336..a5fefab71 100644
--- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h
+++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect-c/TensorRTAttributes.h
@@ -188,6 +188,22 @@
 DECLARE_ATTR_GETTER_FROM_STRING(ScatterMode)
 DECLARE_IS_ATTR(ScatterMode)
 DECLARE_STRING_GETTER_FROM_ATTR(ScatterMode)
 
+//===----------------------------------------------------------------------===//
+// AttentionNormalizationOp
+//===----------------------------------------------------------------------===//
+
+DECLARE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
+DECLARE_IS_ATTR(AttentionNormalizationOp)
+DECLARE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
+
+//===----------------------------------------------------------------------===//
+// DataType
+//===----------------------------------------------------------------------===//
+
+DECLARE_ATTR_GETTER_FROM_STRING(DataType)
+DECLARE_IS_ATTR(DataType)
+DECLARE_STRING_GETTER_FROM_ATTR(DataType)
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td
index 0bb4e91fd..4d4fd144e 100644
--- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td
+++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTEnums.td
@@ -378,4 +378,42 @@ def TensorRT_ScatterMode : TensorRT_I32EnumAttr<
 def TensorRT_ScatterModeAttr : TensorRT_EnumAttr<TensorRT_ScatterMode, "scatter_mode">{
 }
 
+def TensorRT_AttentionNormalizationOp : TensorRT_I32EnumAttr<
+  "AttentionNormalizationOp", "",
+  [
+    I32EnumAttrCase<"kNONE", 0>,
+    I32EnumAttrCase<"kSOFTMAX", 1>
+  ]>
+{
+  let cppNamespace = "::mlir::tensorrt";
+  let genSpecializedAttr = 0;
+}
+
+def TensorRT_AttentionNormalizationOpAttr : TensorRT_EnumAttr<TensorRT_AttentionNormalizationOp, "attention_normalization_op">{
+}
+
+def TensorRT_DataType : TensorRT_I32EnumAttr<
+  "DataType", "",
+  [
+    I32EnumAttrCase<"kFLOAT", 0>,
+    I32EnumAttrCase<"kHALF", 1>,
+    I32EnumAttrCase<"kINT8", 2>,
+    I32EnumAttrCase<"kINT32", 3>,
+    I32EnumAttrCase<"kBOOL", 4>,
+    I32EnumAttrCase<"kUINT8", 5>,
+    I32EnumAttrCase<"kFP8", 6>,
+    I32EnumAttrCase<"kBF16", 7>,
+    I32EnumAttrCase<"kINT64", 8>,
+    I32EnumAttrCase<"kINT4", 9>,
+    I32EnumAttrCase<"kFP4", 10>,
+    I32EnumAttrCase<"kE8M0", 11>
+  ]>
+{
+  let cppNamespace = "::mlir::tensorrt";
+  let genSpecializedAttr = 0;
+}
+
+def TensorRT_DataTypeAttr : TensorRT_EnumAttr<TensorRT_DataType, "data_type">{
+}
+
 #endif // MLIR_TENSORRT_DIALECT_TENSORRT_IR_TENSORRTENUMS
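The case names and numbering of `TensorRT_DataType` are written to line up with `nvinfer1::DataType`, which is what `convertDataTypeToNvInferEnum` in the lowering below relies on. A hedged sanity-check sketch against the TensorRT Python wheel; it assumes a wheel recent enough to expose the newer members such as FP4 and E8M0, and it is not part of this patch:

```python
# Sanity-check sketch: the I32EnumAttrCase numbering above should mirror
# nvinfer1::DataType. Members missing from older wheels are skipped.
import tensorrt as trt

expected = {
    "FLOAT": 0, "HALF": 1, "INT8": 2, "INT32": 3, "BOOL": 4, "UINT8": 5,
    "FP8": 6, "BF16": 7, "INT64": 8, "INT4": 9, "FP4": 10, "E8M0": 11,
}

for name, value in expected.items():
    member = getattr(trt.DataType, name, None)
    if member is None:
        print(f"DataType.{name} is not exposed by this TensorRT build; skipped")
        continue
    assert int(member) == value, (name, int(member), value)
print("tensorrt.DataType numbering matches the TableGen enum cases")
```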
diff --git a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
index 0da85abad..7c71106ee 100644
--- a/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
+++ b/mlir-tensorrt/tensorrt/include/mlir-tensorrt-dialect/TensorRT/IR/TensorRTOps.td
@@ -3507,6 +3507,171 @@ def TensorRT_DequantizeOp : TensorRT_Op<"dequantize",
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+def TensorRT_AttentionOp : TensorRT_Op<"attention",
+    [Pure, AttrSizedOperandSegments, TensorRTPartiallyInferTensorResultTypes,
+     AllElementTypesMatch<["query", "key", "value"]>,
+     AllRanksMatch<["query", "key", "value"]>]>{
+  let summary = "TensorRT attention (IAttention) operation";
+  let description = [{
+    The `tensorrt.attention` operation implements a fused attention mechanism
+    that consumes query, key, and value tensors. The operation implicitly includes
+    two matrix multiplication layers (BMM1 and BMM2) and a normalization operation
+    (typically softmax).
+
+    By default, TensorRT will try to use a single fused kernel for better efficiency.
+    The operation can optionally be decomposed into multiple kernels if no fused
+    kernel is available by setting `decomposable` to true.
+
+    #### Architecture:
+
+    ```
+    Query   Key   Value   Mask (optional)   NormalizationQuantizeScale (optional)
+      |      |      |           |                         |
+      |  Transpose  |           |                         |
+      |      |      |           |                         |
+      ----BMM1----  |           |                         |
+           |        |           |                         |
+           *--------------------|                         |
+           |        |                                     |
+      Normalization |                                     |
+           |        |                                     |
+           *----------------------------------------------|
+           |        |
+      -------BMM2------
+           |
+         Output
+    ```
+
+    #### Inputs:
+
+    - Query: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]
+    - Key: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Value: tensor of type f32, f16, or bf16 with shape
+      [batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]
+    - Mask (optional): tensor of type i1 or same type as BMM1 output with shape
+      [batchSize, numHeadsQuery, sequenceLengthQuery, sequenceLengthKeyValue]
+      where batchSize and numHeadsQuery are broadcastable. For i1 mask, true
+      indicates the position is allowed to attend. For other types, mask values
+      are added to BMM1 output.
+    - NormalizationQuantizeScale (optional): tensor of type f32, f16, or bf16
+      with rank 0 (scalar) or 1 (1D tensor), used for quantizing the normalization
+      output. Required when normalization_quantize_to_type is specified.
+
+    #### Attributes:
+
+    - normalization_operation: The normalization operation to use (default: kSOFTMAX)
+    - causal: Whether to use causal masking (default: false). Cannot be used with
+      mask input.
+    - decomposable: Whether the operation can be decomposed (default: false)
+    - normalization_quantize_to_type: Optional output type for quantized normalization.
+      When specified, must be one of kFP8 or kINT8. Requires normalization_quantize_scale
+      input to be provided.
+
+    #### Constraints:
+
+    - All query, key, and value tensors must be rank 4 with shape
+      [batchSize, numHeads, sequenceLength, dimHead]
+    - Query, key, and value must have the same element type (f32, f16, or bf16)
+    - If normalization_quantize_to_type is specified:
+      * It must be kFP8 or kINT8
+      * normalization_quantize_scale input must be provided
+    - If normalization_quantize_scale is provided:
+      * normalization_quantize_to_type must be specified
+      * Element type must be f32, f16, or bf16
+      * Rank must be 0 (scalar) or 1 (1D tensor)
+    - Cannot use both mask input and causal=true simultaneously
+
+    #### Examples:
+
+    Basic attention:
+    ```mlir
+    %output = tensorrt.attention ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
+
+    Causal attention:
+    ```mlir
+    %output_causal = tensorrt.attention {causal = true} ins(%query, %key, %value :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+      -> tensor<2x8x128x64xf16>
+    ```
+
+    Attention with quantization:
+    ```mlir
+    %scale = tensorrt.constant dense<1.0> : tensor<f32>
+    %output_quant = tensorrt.attention {
+      normalization_quantize_to_type = #tensorrt.data_type<kFP8>
+    } ins(%query, %key, %value,
+      normalization_quantize_scale = %scale :
+      tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
+      tensor<2x8x128x64xf16>, tensor<f32>)
+      -> tensor<2x8x128x64xf16>
+    ```
+  }];
+
+  let arguments = (ins
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$query,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$key,
+    TensorRT_RankedTensorOf<[F16, BF16, F32]>:$value,
+    Optional<TensorRT_RankedTensorOf<[I1, F16, BF16, F32]>>:$mask,
+    Optional<TensorRT_RankedTensorOf<[F16, BF16, F32]>>:$normalization_quantize_scale,
+    DefaultValuedAttr<TensorRT_AttentionNormalizationOpAttr,
+      "AttentionNormalizationOp::kSOFTMAX">:$normalization_operation,
+    DefaultValuedAttr<BoolAttr, "false">:$causal,
+    DefaultValuedAttr<BoolAttr, "false">:$decomposable,
+    OptionalAttr<TensorRT_DataTypeAttr>:$normalization_quantize_to_type
+  );
+
+  let results = (outs TensorRT_RankedTensorOf<[F16, BF16, F32]>:$result);
+
+  let assemblyFormat = [{
+    attr-dict `ins` `(` $query `,` $key `,` $value
+    (`,` `mask` `=` $mask^)?
+    (`,` `normalization_quantize_scale` `=` $normalization_quantize_scale^)?
+    `:` type($query) `,` type($key) `,` type($value)
+    (`,` type($mask)^)?
+    (`,` type($normalization_quantize_scale)^)?
+    `)` `->` type($result)
+  }];
+
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    /// Returns true if created op is valid for TensorRT major version.
+    bool isValidForTensorRTVersion(int64_t trtMajorVersion);
+  }] # baseClassDeclaration;
+
+  let trtLayerAdd = [{
+    nvinfer1::IAttention *layer = $net->addAttention(*$query, *$key, *$value, *$normalization_operation, $causal);
+    if (!layer)
+      return failure();
+
+    if ($mask)
+      layer->setMask(*$mask);
+
+    layer->setDecomposable($decomposable);
+
+    if ($normalization_quantize_scale) {
+      layer->setNormalizationQuantizeScale(*$normalization_quantize_scale);
+    }
+
+    if ($normalization_quantize_to_type) {
+      auto convertedDataType = ::mlir::tensorrt::convertDataTypeToNvInferEnum(*$normalization_quantize_to_type);
+      if (!convertedDataType)
+        return emitError($op->getLoc()) << "failed to convert DataType to nvinfer enum";
+      layer->setNormalizationQuantizeToType(*convertedDataType);
+    }
+
+    $results.push_back(layer->getOutput(0));
+#if MLIR_TRT_COMPILE_TIME_TENSORRT_VERSION_GTE(10, 15, 0)
+    layer->setMetadata($op);
+#endif
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // TensorRT Dialect Extension Operations
 //
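For readers who want the semantics spelled out imperatively, here is a NumPy reference sketch of the dataflow the op description lays out (BMM1, optional mask, softmax normalization, BMM2). It is illustrative only: it is not TensorRT code, and any 1/sqrt(dimHead) scaling is deliberately omitted because the op description does not specify one.

```python
# Reference sketch of the dataflow in the op description; not TensorRT code.
# q: [B, Hq, Sq, D], k/v: [B, Hkv, Skv, D]; Hq is assumed to be a multiple of Hkv.
import numpy as np

def attention_reference(q, k, v, mask=None, causal=False):
    assert q.ndim == k.ndim == v.ndim == 4, "rank-4 Q/K/V, as the op requires"
    assert not (mask is not None and causal), "mask and causal=true are exclusive"
    b, hq, sq, d = q.shape
    _, hkv, skv, _ = k.shape
    if hq != hkv:  # broadcast K/V heads (numHeadsKeyValue) up to numHeadsQuery
        k = np.repeat(k, hq // hkv, axis=1)
        v = np.repeat(v, hq // hkv, axis=1)
    scores = q @ k.transpose(0, 1, 3, 2)              # BMM1 -> [B, Hq, Sq, Skv]
    if causal:
        allowed = np.tril(np.ones((sq, skv), dtype=bool))
        scores = np.where(allowed, scores, -np.inf)
    elif mask is not None:
        if mask.dtype == bool:                        # i1 mask: True may attend
            scores = np.where(mask, scores, -np.inf)
        else:                                         # float mask: added to BMM1 output
            scores = scores + mask
    scores = scores - scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)        # normalization (kSOFTMAX)
    return probs @ v                                  # BMM2 -> [B, Hq, Sq, D]

q = k = v = np.random.rand(2, 8, 128, 64).astype(np.float32)
print(attention_reference(q, k, v, causal=True).shape)  # (2, 8, 128, 64): the query shape
```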
diff --git a/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp b/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp
index 50e87551c..456f3b11c 100644
--- a/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp
+++ b/mlir-tensorrt/tensorrt/lib/CAPI/TensorRTAttributes.cpp
@@ -121,3 +121,11 @@ DEFINE_STRING_GETTER_FROM_ATTR(FillOperation)
 DEFINE_ATTR_GETTER_FROM_STRING(ScatterMode)
 DEFINE_IS_ATTR(ScatterMode)
 DEFINE_STRING_GETTER_FROM_ATTR(ScatterMode)
+
+DEFINE_ATTR_GETTER_FROM_STRING(AttentionNormalizationOp)
+DEFINE_IS_ATTR(AttentionNormalizationOp)
+DEFINE_STRING_GETTER_FROM_ATTR(AttentionNormalizationOp)
+
+DEFINE_ATTR_GETTER_FROM_STRING(DataType)
+DEFINE_IS_ATTR(DataType)
+DEFINE_STRING_GETTER_FROM_ATTR(DataType)
diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp
index 19d4b3c79..06522df3c 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TensorRTVersionCompatibility.cpp
@@ -915,3 +915,16 @@ bool tensorrt::ScatterElementsOp::isValidForTensorRTVersion(
   return isValidForTensorRTVersionScatterOpImpl(
       trtMajorVersion, dataElementType, indicesElementType);
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+bool tensorrt::AttentionOp::isValidForTensorRTVersion(
+    int64_t trtMajorVersion) {
+  // IAttention layer is only supported in TensorRT >= 10.14.0
+  if (trtMajorVersion < 10)
+    return false;
+
+  return true;
+}
diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp
index e0ad4a1fc..96d107305 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/TypeInferenceInterfaceImpls.cpp
@@ -1633,3 +1633,21 @@ LogicalResult tensorrt::DequantizeOp::inferReturnTypeComponents(
       /*elementType=*/nullptr);
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult tensorrt::AttentionOp::inferReturnTypeComponents(
+    MLIRContext *ctx, std::optional<Location> loc, ValueShapeRange operands,
+    DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  AttentionOp::Adaptor adaptor(operands, attributes, properties, regions);
+  auto queryType = dyn_cast<RankedTensorType>(adaptor.getQuery().getType());
+  if (!queryType)
+    return emitOptionalError(loc, "expected query to be a ranked tensor");
+  inferredReturnShapes.emplace_back(
+      /*vec=*/queryType.getShape(),
+      /*elementType=*/queryType.getElementType());
+  return success();
+}
diff --git a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp
index f9ea90fad..7761a84ff 100644
--- a/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp
+++ b/mlir-tensorrt/tensorrt/lib/TensorRT/IR/Verification.cpp
@@ -1466,3 +1466,59 @@ static LogicalResult verifyAllowedDataTypes(UnaryOp op) {
 LogicalResult tensorrt::UnaryOp::verify() {
   return verifyAllowedDataTypes(*this);
 }
+
+//===----------------------------------------------------------------------===//
+// AttentionOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult tensorrt::AttentionOp::verify() {
+  // Check 1: Cannot use both mask input and causal=true simultaneously
+  if (getMask() && getCausal())
+    return emitOpError(
+        "cannot use both mask input and causal=true simultaneously");
+
+  // Check 2: If normalization_quantize_to_type is specified, it must be kFP8
+  // or kINT8 and normalization_quantize_scale must be provided
+  std::optional<DataType> quantizeType = getNormalizationQuantizeToType();
+  if (quantizeType.has_value()) {
+    if (*quantizeType != DataType::kFP8 && *quantizeType != DataType::kINT8)
+      return emitOpError("normalization_quantize_to_type must be kFP8 or "
+                         "kINT8, but got ")
+             << stringifyDataType(*quantizeType);
+
+    if (!getNormalizationQuantizeScale())
+      return emitOpError(
+          "normalization_quantize_scale input must be provided when "
+          "normalization_quantize_to_type is specified");
+  }
+
+  // Check 3: If normalization_quantize_scale is provided,
+  // normalization_quantize_to_type must be specified
+  if (getNormalizationQuantizeScale() && !quantizeType.has_value())
+    return emitOpError(
+        "normalization_quantize_to_type must be specified when "
+        "normalization_quantize_scale input is provided");
+
+  // Check 4: If normalization_quantize_scale is provided, validate its type
+  if (getNormalizationQuantizeScale()) {
+    RankedTensorType scaleType = getNormalizationQuantizeScale().getType();
+    Type scaleElemType = scaleType.getElementType();
+
+    // Check that element type is f32, f16, or bf16
+    if (!scaleElemType.isF32() && !scaleElemType.isF16() &&
+        !scaleElemType.isBF16())
+      return emitOpError(
+                 "normalization_quantize_scale element type must be f32, f16, "
+                 "or bf16, but got ")
+             << scaleElemType;
+
+    // Check that scale is rank 0 or 1
+    if (scaleType.getRank() != 0 && scaleType.getRank() != 1)
+      return emitOpError(
+                 "normalization_quantize_scale must be rank 0 or 1, but got "
+                 "rank ")
+             << scaleType.getRank();
+  }
+
+  return success();
+}
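The quantization-related checks above form a small rule matrix (quantize type implies a scale, a scale implies a quantize type, plus element-type and rank limits on the scale). A plain-Python restatement for readability only; it is not generated from the C++ and the verifier above remains authoritative:

```python
# Readability sketch of the AttentionOp verifier rules above; not part of this patch.
def check_attention(mask=None, causal=False,
                    quantize_to_type=None, quantize_scale=None):
    """quantize_scale is an (element_type, rank) pair, e.g. ("f16", 0)."""
    errors = []
    if mask is not None and causal:
        errors.append("cannot use both mask input and causal=true simultaneously")
    if quantize_to_type is not None:
        if quantize_to_type not in ("kFP8", "kINT8"):
            errors.append("normalization_quantize_to_type must be kFP8 or kINT8")
        if quantize_scale is None:
            errors.append("normalization_quantize_scale input must be provided")
    if quantize_scale is not None:
        if quantize_to_type is None:
            errors.append("normalization_quantize_to_type must be specified")
        elem_type, rank = quantize_scale
        if elem_type not in ("f32", "f16", "bf16"):
            errors.append("scale element type must be f32, f16, or bf16")
        if rank not in (0, 1):
            errors.append("scale must be rank 0 or 1")
    return errors

# A scale without a target quantize type trips exactly one of the checks:
print(check_attention(quantize_scale=("f16", 0)))
```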
diff --git a/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir
new file mode 100644
index 000000000..c4efe1bf5
--- /dev/null
+++ b/mlir-tensorrt/tensorrt/test/Target/TensorRT/TRT10/attention.mlir
@@ -0,0 +1,72 @@
+// REQUIRES: tensorrt-version-ge-10.14
+// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
+// RUN:   -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 -tensorrt-strongly-typed %s | FileCheck %s
+// RUN: %pick-one-gpu tensorrt-opt -split-input-file -pass-pipeline="builtin.module(translate-tensorrt-to-engine)" \
+// RUN:   -mlir-elide-elementsattrs-if-larger=32 -tensorrt-builder-opt-level=0 %s | FileCheck %s
+
+// CHECK-LABEL: @trt_attention_f16
+// CHECK-SAME: tensorrt.engine
+func.func @trt_attention_f16(%arg0: tensor<2x8x128x64xf16>,
+                             %arg1: tensor<2x8x128x64xf16>,
+                             %arg2: tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16> {
+  %0 = tensorrt.attention ins(%arg0, %arg1, %arg2 :
+    tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16>
+  return %0 : tensor<2x8x128x64xf16>
+}
+
+// CHECK-LABEL: @trt_attention_causal_f16
+// CHECK-SAME: tensorrt.engine
+func.func @trt_attention_causal_f16(%arg0: tensor<2x8x128x64xf16>,
+                                    %arg1: tensor<2x8x128x64xf16>,
+                                    %arg2: tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16> {
+  %0 = tensorrt.attention {causal = true} ins(%arg0, %arg1, %arg2 :
+    tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16>
+  return %0 : tensor<2x8x128x64xf16>
+}
+
+// CHECK-LABEL: @trt_attention_with_mask_f16
+// CHECK-SAME: tensorrt.engine
+func.func @trt_attention_with_mask_f16(%arg0: tensor<2x8x128x64xf16>,
+                                       %arg1: tensor<2x8x128x64xf16>,
+                                       %arg2: tensor<2x8x128x64xf16>,
+                                       %mask: tensor<2x8x128x128xf16>)
+    -> tensor<2x8x128x64xf16> {
+  %0 = tensorrt.attention ins(%arg0, %arg1, %arg2, mask = %mask :
+    tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x128xf16>)
+    -> tensor<2x8x128x64xf16>
+  return %0 : tensor<2x8x128x64xf16>
+}
+
+// CHECK-LABEL: @trt_attention_with_quantization_f16
+// CHECK-SAME: tensorrt.engine
+func.func @trt_attention_with_quantization_f16(%arg0: tensor<2x8x128x64xf16>,
+                                               %arg1: tensor<2x8x128x64xf16>,
+                                               %arg2: tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16> {
+  %scale = tensorrt.constant dense<1.0> : tensor<f32>
+  %0 = tensorrt.attention {
+    normalization_quantize_to_type = #tensorrt.data_type<kFP8>
+  } ins(%arg0, %arg1, %arg2,
+    normalization_quantize_scale = %scale :
+    tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>,
+    tensor<2x8x128x64xf16>, tensor<f32>)
+    -> tensor<2x8x128x64xf16>
+  return %0 : tensor<2x8x128x64xf16>
+}
+
+// CHECK-LABEL: @trt_attention_decomposable_f16
+// CHECK-SAME: tensorrt.engine
+func.func @trt_attention_decomposable_f16(%arg0: tensor<2x8x128x64xf16>,
+                                          %arg1: tensor<2x8x128x64xf16>,
+                                          %arg2: tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16> {
+  %0 = tensorrt.attention {decomposable = true} ins(%arg0, %arg1, %arg2 :
+    tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>, tensor<2x8x128x64xf16>)
+    -> tensor<2x8x128x64xf16>
+  return %0 : tensor<2x8x128x64xf16>
+}