diff --git a/presto-common/src/main/java/com/facebook/presto/common/function/OperatorType.java b/presto-common/src/main/java/com/facebook/presto/common/function/OperatorType.java index 63b4ba595194a..ed5c9dba9d506 100644 --- a/presto-common/src/main/java/com/facebook/presto/common/function/OperatorType.java +++ b/presto-common/src/main/java/com/facebook/presto/common/function/OperatorType.java @@ -91,6 +91,12 @@ public boolean isArithmeticOperator() return this.equals(ADD) || this.equals(SUBTRACT) || this.equals(MULTIPLY) || this.equals(DIVIDE) || this.equals(MODULUS); } + public boolean isHashOperator() + { + return this.equals(HASH_CODE) || + this.equals(XX_HASH_64); + } + public static Optional tryGetOperatorType(QualifiedObjectName operatorName) { return Optional.ofNullable(OPERATOR_TYPES.get(operatorName)); diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 0b4a7ccc8393a..dd6c2efab9c58 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -884,6 +884,17 @@ This can also be specified on a per-query basis using the ``confidence_based_bro Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. This can also be specified on a per-query basis using the ``treat-low-confidence-zero-estimation-as-unknown`` session property. +``optimizer.scalar-function-stats-propagation-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enable propagation of statistics through scalar functions using annotations. Annotations describe the statistics of a scalar +function's result, for example how they are derived from the statistics of its arguments. When set to ``true``, statistics are propagated as the annotations specify. +This can also be specified on a per-query basis using the ``scalar_function_stats_propagation_enabled`` session property.
+ ``optimizer.retry-query-with-history-based-optimization`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index d039bd5595744..61f7174a9aaab 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -371,6 +371,7 @@ public final class SystemSessionProperties public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms"; public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns"; public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values"; + public static final String SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED = "scalar_function_stats_propagation_enabled"; private final List> sessionProperties; @@ -2077,6 +2078,10 @@ public SystemSessionProperties( booleanProperty(INLINE_PROJECTIONS_ON_VALUES, "Whether to evaluate project node on values node", featuresConfig.getInlineProjectionsOnValues(), + false), + booleanProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, + "whether or not to respect stats propagation annotation for scalar functions (or UDF)", + featuresConfig.isScalarFunctionStatsPropagationEnabled(), false)); } @@ -3414,4 +3419,9 @@ public static boolean isInlineProjectionsOnValues(Session session) { return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class); } + + public static boolean shouldEnableScalarFunctionStatsPropagation(Session session) + { + return session.getSystemProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java new file mode 100644 index 
0000000000000..9469c2f551f60 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java @@ -0,0 +1,181 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.StatsPropagationBehavior; +import com.facebook.presto.spi.relation.CallExpression; + +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.firstFiniteValue; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.getReturnTypeWidth; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.getTypeWidth; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.NON_NULL_ROW_COUNT_CONST; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.ROW_COUNT_CONST; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; +import static com.facebook.presto.util.MoreMath.max; +import static com.facebook.presto.util.MoreMath.min; +import static com.facebook.presto.util.MoreMath.minExcludingNaNs; +import static com.facebook.presto.util.MoreMath.nearlyEqual; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; +import static java.lang.Double.isNaN; + +public final class ScalarStatsAnnotationProcessor +{ + private ScalarStatsAnnotationProcessor() + { + } + + public static VariableStatsEstimate computeStatsFromAnnotations( + CallExpression callExpression, + List sourceStats, + ScalarStatsHeader scalarStatsHeader, + double outputRowCount) + { + double nullFraction = scalarStatsHeader.getNullFraction(); + double distinctValuesCount = NaN; + double averageRowSize = NaN; + double maxValue = scalarStatsHeader.getMax(); + double minValue = scalarStatsHeader.getMin(); + for (Map.Entry paramIndexToStatsMap : scalarStatsHeader.getArgumentStats().entrySet()) { + ScalarPropagateSourceStats scalarPropagateSourceStats = paramIndexToStatsMap.getValue(); + boolean propagateAllStats = scalarPropagateSourceStats.propagateAllStats(); + nullFraction = min(firstFiniteValue(nullFraction, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.nullFraction()))), 1.0); + distinctValuesCount = firstFiniteValue(distinctValuesCount, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.distinctValuesCount()))); + StatsPropagationBehavior averageRowSizeStatsBehaviour = applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.avgRowSize()); + averageRowSize = minExcludingNaNs(firstFiniteValue(averageRowSize, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + 
sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + averageRowSizeStatsBehaviour)), getReturnTypeWidth(callExpression, averageRowSizeStatsBehaviour)); + maxValue = firstFiniteValue(maxValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.maxValue()))); + minValue = firstFiniteValue(minValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.minValue()))); + } + if (isNaN(maxValue) || isNaN(minValue)) { + minValue = NaN; + maxValue = NaN; + } + return VariableStatsEstimate.builder() + .setLowValue(minValue) + .setHighValue(maxValue) + .setNullsFraction(nullFraction) + .setAverageRowSize(firstFiniteValue(scalarStatsHeader.getAvgRowSize(), averageRowSize, getReturnTypeWidth(callExpression, UNKNOWN))) + .setDistinctValuesCount(processDistinctValuesCount(outputRowCount, nullFraction, scalarStatsHeader.getDistinctValuesCount(), distinctValuesCount)).build(); + } + + private static double processDistinctValuesCount(double outputRowCount, double nullFraction, double distinctValuesCountFromConstant, double distinctValuesCount) + { + if (isFinite(distinctValuesCountFromConstant)) { + if (nearlyEqual(distinctValuesCountFromConstant, NON_NULL_ROW_COUNT_CONST, 0.1)) { + distinctValuesCountFromConstant = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0)); + } + else if (nearlyEqual(distinctValuesCountFromConstant, ROW_COUNT_CONST, 0.1)) { + distinctValuesCountFromConstant = outputRowCount; + } + } + double distinctValuesCountFinal = 
firstFiniteValue(distinctValuesCountFromConstant, distinctValuesCount); + if (distinctValuesCountFinal > outputRowCount) { + distinctValuesCountFinal = NaN; + } + return distinctValuesCountFinal; + } + + private static double processSingleArgumentStatistic( + double outputRowCount, + double nullFraction, + CallExpression callExpression, + List sourceStats, + int sourceStatsArgumentIndex, + StatsPropagationBehavior operation) + { + // sourceStatsArgumentIndex is index of the argument on which + // ScalarPropagateSourceStats annotation was applied. + double statValue = NaN; + if (operation.isMultiArgumentStat()) { + for (int i = 0; i < sourceStats.size(); i++) { + if (i == 0 && operation.isSourceStatsDependentStats() && isFinite(sourceStats.get(i))) { + statValue = sourceStats.get(i); + } + else { + switch (operation) { + case MAX_TYPE_WIDTH_VARCHAR: + statValue = getTypeWidth(callExpression.getArguments().get(i).getType()); + break; + case USE_MIN_ARGUMENT: + statValue = min(statValue, sourceStats.get(i)); + break; + case USE_MAX_ARGUMENT: + statValue = max(statValue, sourceStats.get(i)); + break; + case SUM_ARGUMENTS: + statValue = statValue + sourceStats.get(i); + break; + } + } + } + } + else { + switch (operation) { + case USE_SOURCE_STATS: + statValue = sourceStats.get(sourceStatsArgumentIndex); + break; + case ROW_COUNT: + statValue = outputRowCount; + break; + case NON_NULL_ROW_COUNT: + statValue = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0)); + break; + case USE_TYPE_WIDTH_VARCHAR: + statValue = getTypeWidth(callExpression.getArguments().get(sourceStatsArgumentIndex).getType()); + break; + case LOG10_SOURCE_STATS: + statValue = Math.log10(sourceStats.get(sourceStatsArgumentIndex)); + break; + case LOG2_SOURCE_STATS: + statValue = Math.log(sourceStats.get(sourceStatsArgumentIndex)) / Math.log(2); + break; + case LOG_NATURAL_SOURCE_STATS: + statValue = Math.log(sourceStats.get(sourceStatsArgumentIndex)); + } + } + return statValue; + } + + 
private static StatsPropagationBehavior applyPropagateAllStats( + boolean propagateAllStats, StatsPropagationBehavior operation) + { + if (operation == UNKNOWN && propagateAllStats) { + return USE_SOURCE_STATS; + } + return operation; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java index 044785ca0a440..a18e0a2396ce8 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -13,14 +13,20 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.FullConnectorSession; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.common.QualifiedObjectName; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.metadata.BuiltInFunctionHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; @@ -53,11 +59,23 @@ import javax.inject.Inject; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.OptionalDouble; +import java.util.stream.IntStream; import static com.facebook.presto.common.function.OperatorType.DIVIDE; import static com.facebook.presto.common.function.OperatorType.MODULUS; +import static 
com.facebook.presto.cost.ScalarStatsAnnotationProcessor.computeStatsFromAnnotations; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeArithmeticBinaryStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeCastStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeComparisonOperatorStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeConcatStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeHashCodeOperatorStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeNegationStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.computeYearFunctionStatistics; +import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.getTypeWidth; import static com.facebook.presto.cost.StatsUtil.toStatsRepresentation; import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; @@ -66,7 +84,10 @@ import static com.facebook.presto.sql.relational.Expressions.isNull; import static com.facebook.presto.util.MoreMath.max; import static com.facebook.presto.util.MoreMath.min; +import static com.facebook.presto.util.MoreMath.nearlyEqual; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Double.NaN; import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; @@ -107,23 +128,30 @@ private class RowExpressionStatsVisitor private final PlanNodeStatsEstimate input; private final ConnectorSession session; private final FunctionResolution resolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + private final boolean isStatsPropagationEnabled; 
public RowExpressionStatsVisitor(PlanNodeStatsEstimate input, ConnectorSession session) { this.input = requireNonNull(input, "input is null"); this.session = requireNonNull(session, "session is null"); + // casting session to FullConnectorSession is not ideal. + this.isStatsPropagationEnabled = + SystemSessionProperties.shouldEnableScalarFunctionStatsPropagation(((FullConnectorSession) session).getSession()); } @Override public VariableStatsEstimate visitCall(CallExpression call, Void context) { + List sourceStatsList = + IntStream.range(0, call.getArguments().size()).mapToObj(argumentIndex -> getSourceStats(call, context, argumentIndex)) + .collect(toImmutableList()); if (resolution.isNegateFunction(call.getFunctionHandle())) { - return computeNegationStatistics(call, context); + return computeNegationStatistics(sourceStatsList); } FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle()); if (functionMetadata.getOperatorType().map(OperatorType::isArithmeticOperator).orElse(false)) { - return computeArithmeticBinaryStatistics(call, context); + return computeArithmeticBinaryStatistics(functionMetadata, sourceStatsList, input.getOutputRowCount()); } RowExpression value = new RowExpressionOptimizer(metadata).optimize(call, OPTIMIZED, session); @@ -136,11 +164,12 @@ public VariableStatsEstimate visitCall(CallExpression call, Void context) return value.accept(this, context); } - // value is not a constant but we can still propagate estimation through cast + // value is not a constant, but we can still propagate estimation through cast if (resolution.isCastFunction(call.getFunctionHandle())) { - return computeCastStatistics(call, context); + return computeCastStatistics(call, metadata, sourceStatsList); } - return VariableStatsEstimate.unknown(); + + return computeStatsViaAnnotations(call, sourceStatsList, functionMetadata); } @Override @@ -159,7 +188,8 @@ public VariableStatsEstimate 
visitConstant(ConstantExpression literal, Void cont OptionalDouble doubleValue = toStatsRepresentation(metadata.getFunctionAndTypeManager(), session, literal.getType(), literal.getValue()); VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder() .setNullsFraction(0) - .setDistinctValuesCount(1); + .setDistinctValuesCount(1) + .setAverageRowSize(getTypeWidth(literal.getType())); if (doubleValue.isPresent()) { estimate.setLowValue(doubleValue.getAsDouble()); @@ -199,124 +229,45 @@ public VariableStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, return VariableStatsEstimate.unknown(); } - private VariableStatsEstimate computeCastStatistics(CallExpression call, Void context) + private VariableStatsEstimate computeStatsViaAnnotations( + CallExpression call, + List sourceStatsList, + FunctionMetadata functionMetadata) { - requireNonNull(call, "call is null"); - VariableStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context); - - // todo - make this general postprocessing rule. - double distinctValuesCount = sourceStats.getDistinctValuesCount(); - double lowValue = sourceStats.getLowValue(); - double highValue = sourceStats.getHighValue(); + if (isStatsPropagationEnabled) { - if (TypeUtils.isIntegralType(call.getType().getTypeSignature(), metadata.getFunctionAndTypeManager())) { - // todo handle low/high value changes if range gets narrower due to cast (e.g. 
BIGINT -> SMALLINT) - if (isFinite(lowValue)) { - lowValue = Math.round(lowValue); + if (functionMetadata.getOperatorType().map(OperatorType::isHashOperator).orElse(false)) { + return computeHashCodeOperatorStatistics(call, sourceStatsList, input.getOutputRowCount()); } - if (isFinite(highValue)) { - highValue = Math.round(highValue); - } - if (isFinite(lowValue) && isFinite(highValue)) { - double integersInRange = highValue - lowValue + 1; - if (!isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) { - distinctValuesCount = integersInRange; - } - } - } - return VariableStatsEstimate.builder() - .setNullsFraction(sourceStats.getNullsFraction()) - .setLowValue(lowValue) - .setHighValue(highValue) - .setDistinctValuesCount(distinctValuesCount) - .build(); - } - - private VariableStatsEstimate computeNegationStatistics(CallExpression call, Void context) - { - requireNonNull(call, "call is null"); - VariableStatsEstimate stats = call.getArguments().get(0).accept(this, context); - if (resolution.isNegateFunction(call.getFunctionHandle())) { - return VariableStatsEstimate.buildFrom(stats) - .setLowValue(-stats.getHighValue()) - .setHighValue(-stats.getLowValue()) - .build(); - } - throw new IllegalStateException(format("Unexpected sign: %s(%s)", call.getDisplayName(), call.getFunctionHandle())); - } - - private VariableStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) - { - requireNonNull(call, "call is null"); - VariableStatsEstimate left = call.getArguments().get(0).accept(this, context); - VariableStatsEstimate right = call.getArguments().get(1).accept(this, context); - - VariableStatsEstimate.Builder result = VariableStatsEstimate.builder() - .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) - .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) - .setDistinctValuesCount(min(left.getDistinctValuesCount() * 
right.getDistinctValuesCount(), input.getOutputRowCount())); - - FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle()); - checkState(functionMetadata.getOperatorType().isPresent()); - OperatorType operatorType = functionMetadata.getOperatorType().get(); - double leftLow = left.getLowValue(); - double leftHigh = left.getHighValue(); - double rightLow = right.getLowValue(); - double rightHigh = right.getHighValue(); - if (isNaN(leftLow) || isNaN(leftHigh) || isNaN(rightLow) || isNaN(rightHigh)) { - result.setLowValue(NaN).setHighValue(NaN); - } - else if (operatorType.equals(DIVIDE) && rightLow < 0 && rightHigh > 0) { - result.setLowValue(Double.NEGATIVE_INFINITY) - .setHighValue(Double.POSITIVE_INFINITY); - } - else if (operatorType.equals(MODULUS)) { - double maxDivisor = max(abs(rightLow), abs(rightHigh)); - if (leftHigh <= 0) { - result.setLowValue(max(-maxDivisor, leftLow)) - .setHighValue(0); + if (functionMetadata.getOperatorType().map(OperatorType::isComparisonOperator).orElse(false)) { + return computeComparisonOperatorStatistics(call, sourceStatsList); } - else if (leftLow >= 0) { - result.setLowValue(0) - .setHighValue(min(maxDivisor, leftHigh)); + + if (functionMetadata.getName().equals(QualifiedObjectName.valueOf("presto.default.concat"))) { + return computeConcatStatistics(call, sourceStatsList, input.getOutputRowCount()); } - else { - result.setLowValue(max(-maxDivisor, leftLow)) - .setHighValue(min(maxDivisor, leftHigh)); + + if (functionMetadata.getName().equals(QualifiedObjectName.valueOf("presto.default.year"))) { + return computeYearFunctionStatistics(call, sourceStatsList); } - } - else { - double v1 = operate(operatorType, leftLow, rightLow); - double v2 = operate(operatorType, leftLow, rightHigh); - double v3 = operate(operatorType, leftHigh, rightLow); - double v4 = operate(operatorType, leftHigh, rightHigh); - double lowValue = min(v1, v2, v3, v4); - double highValue = max(v1, 
v2, v3, v4); - result.setLowValue(lowValue) - .setHighValue(highValue); + if (functionMetadata.hasStatsHeader() && call.getFunctionHandle() instanceof BuiltInFunctionHandle) { + Signature signature = ((BuiltInFunctionHandle) call.getFunctionHandle()).getSignature().canonicalization(); + Optional statsHeader = functionMetadata.getScalarStatsHeader(signature); + if (statsHeader.isPresent()) { + return computeStatsFromAnnotations(call, sourceStatsList, statsHeader.get(), input.getOutputRowCount()); + } + } } - - return result.build(); + return VariableStatsEstimate.unknown(); } - private double operate(OperatorType operator, double left, double right) + private VariableStatsEstimate getSourceStats(CallExpression call, Void context, int argumentIndex) { - switch (operator) { - case ADD: - return left + right; - case SUBTRACT: - return left - right; - case MULTIPLY: - return left * right; - case DIVIDE: - return left / right; - case MODULUS: - return left % right; - default: - throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator); - } + checkArgument(argumentIndex < call.getArguments().size(), + format("function argument index: %d >= %d (call argument size) for %s", argumentIndex, call.getArguments().size(), call)); + return call.getArguments().get(argumentIndex).accept(this, context); } } @@ -360,7 +311,8 @@ protected VariableStatsEstimate visitLiteral(Literal node, Void context) OptionalDouble doubleValue = toStatsRepresentation(metadata, session, type, value); VariableStatsEstimate.Builder estimate = VariableStatsEstimate.builder() .setNullsFraction(0) - .setDistinctValuesCount(1); + .setDistinctValuesCount(1) + .setAverageRowSize(getTypeWidth(type)); if (doubleValue.isPresent()) { estimate.setLowValue(doubleValue.getAsDouble()); @@ -551,10 +503,10 @@ protected VariableStatsEstimate visitCoalesceExpression(CoalesceExpression node, private static VariableStatsEstimate estimateCoalesce(PlanNodeStatsEstimate input, 
VariableStatsEstimate left, VariableStatsEstimate right) { // Question to reviewer: do you have a method to check if fraction is empty or saturated? - if (left.getNullsFraction() == 0) { + if (nearlyEqual(left.getNullsFraction(), 0, 0.00001)) { return left; } - else if (left.getNullsFraction() == 1.0) { + else if (nearlyEqual(left.getNullsFraction(), 1.0, 0.00001)) { return right; } else { diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculatorUtils.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculatorUtils.java new file mode 100644 index 0000000000000..9b3c288604063 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculatorUtils.java @@ -0,0 +1,302 @@ +package com.facebook.presto.cost; + +import com.facebook.presto.common.function.OperatorType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.FixedWidthType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.StatsPropagationBehavior; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.type.TypeUtils; +import org.joda.time.DateTimeField; +import org.joda.time.chrono.ISOChronology; + +import java.util.List; + +import static com.facebook.presto.common.function.OperatorType.DIVIDE; +import static com.facebook.presto.common.function.OperatorType.MODULUS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; +import static com.facebook.presto.util.MoreMath.max; +import static com.facebook.presto.util.MoreMath.min; +import static 
com.facebook.presto.util.MoreMath.minExcludingNaNs; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; +import static java.lang.Double.isNaN; +import static java.lang.Math.abs; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.DAYS; + +public final class ScalarStatsCalculatorUtils +{ + private ScalarStatsCalculatorUtils() + { + } + + public static double getTypeWidth(Type argumentType) + { + if (argumentType instanceof FixedWidthType) { + return ((FixedWidthType) argumentType).getFixedSize(); + } + if (argumentType instanceof VarcharType) { + if (!((VarcharType) argumentType).isUnbounded()) { + return ((VarcharType) argumentType).getLengthSafe(); + } + } + if (argumentType instanceof CharType) { + return ((CharType) argumentType).getLength(); + } + return NaN; + } + + public static double getReturnTypeWidth(CallExpression callExpression, StatsPropagationBehavior operation) + { + if (callExpression.getType() instanceof FixedWidthType) { + return ((FixedWidthType) callExpression.getType()).getFixedSize(); + } + if (callExpression.getType() instanceof CharType) { + return ((CharType) callExpression.getType()).getLength(); + } + if (callExpression.getType() instanceof VarcharType) { + VarcharType returnType = (VarcharType) callExpression.getType(); + if (!returnType.isUnbounded()) { + return returnType.getLengthSafe(); + } + if (operation == SUM_ARGUMENTS) { + // since return type is an unbounded varchar and operation is SUM_ARGUMENTS, + // calculating the type width by doing a SUM of each argument's varchar type bounds - if available. 
+ double sum = 0; + for (RowExpression r : callExpression.getArguments()) { + double typeWidth; + if (r instanceof CallExpression) { // argument is another function call + typeWidth = getReturnTypeWidth((CallExpression) r, UNKNOWN); + } + else { + typeWidth = getTypeWidth(r.getType()); + } + if (typeWidth < 0) { + return NaN; + } + sum += typeWidth; + } + return sum; + } + } + return NaN; + } + + // Return first 'finite' value from values, else return values[0] + public static double firstFiniteValue(double... values) + { + checkArgument(values.length > 1); + for (double v : values) { + if (isFinite(v)) { + return v; + } + } + return values[0]; + } + + public static VariableStatsEstimate computeCastStatistics(CallExpression callExpression, Metadata metadata, List sourceStats) + { + requireNonNull(callExpression, "call is null"); + checkArgument(!sourceStats.isEmpty()); + // todo - make this general postprocessing rule. + double distinctValuesCount = sourceStats.get(0).getDistinctValuesCount(); + double lowValue = sourceStats.get(0).getLowValue(); + double highValue = sourceStats.get(0).getHighValue(); + + if (TypeUtils.isIntegralType(callExpression.getType().getTypeSignature(), metadata.getFunctionAndTypeManager())) { + // todo handle low/high value changes if range gets narrower due to cast (e.g. 
BIGINT -> SMALLINT) + if (isFinite(lowValue)) { + lowValue = Math.round(lowValue); + } + if (isFinite(highValue)) { + highValue = Math.round(highValue); + } + if (isFinite(lowValue) && isFinite(highValue)) { + double integersInRange = highValue - lowValue + 1; + if (!isNaN(distinctValuesCount) && distinctValuesCount > integersInRange) { + distinctValuesCount = integersInRange; + } + } + } + + return VariableStatsEstimate.builder() + .setNullsFraction(sourceStats.get(0).getNullsFraction()) + .setLowValue(lowValue) + .setHighValue(highValue) + .setDistinctValuesCount(distinctValuesCount) + .build(); + } + + public static VariableStatsEstimate computeArithmeticBinaryStatistics(FunctionMetadata functionMetadata, List sourceStats, double outputRowCount) + { + checkArgument(sourceStats.size() > 1); + VariableStatsEstimate left = sourceStats.get(0); + VariableStatsEstimate right = sourceStats.get(1); + + VariableStatsEstimate.Builder result = VariableStatsEstimate.builder() + .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) + .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) + .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), outputRowCount)); + checkState(functionMetadata.getOperatorType().isPresent()); + OperatorType operatorType = functionMetadata.getOperatorType().get(); + double leftLow = left.getLowValue(); + double leftHigh = left.getHighValue(); + double rightLow = right.getLowValue(); + double rightHigh = right.getHighValue(); + if (isNaN(leftLow) || isNaN(leftHigh) || isNaN(rightLow) || isNaN(rightHigh)) { + result.setLowValue(NaN).setHighValue(NaN); + } + else if (operatorType.equals(DIVIDE) && rightLow < 0 && rightHigh > 0) { + result.setLowValue(Double.NEGATIVE_INFINITY) + .setHighValue(Double.POSITIVE_INFINITY); + } + else if (operatorType.equals(MODULUS)) { + double maxDivisor = max(abs(rightLow), 
abs(rightHigh)); + if (leftHigh <= 0) { + result.setLowValue(max(-maxDivisor, leftLow)) + .setHighValue(0); + } + else if (leftLow >= 0) { + result.setLowValue(0) + .setHighValue(min(maxDivisor, leftHigh)); + } + else { + result.setLowValue(max(-maxDivisor, leftLow)) + .setHighValue(min(maxDivisor, leftHigh)); + } + } + else { + double v1 = operate(operatorType, leftLow, rightLow); + double v2 = operate(operatorType, leftLow, rightHigh); + double v3 = operate(operatorType, leftHigh, rightLow); + double v4 = operate(operatorType, leftHigh, rightHigh); + double lowValue = min(v1, v2, v3, v4); + double highValue = max(v1, v2, v3, v4); + + result.setLowValue(lowValue) + .setHighValue(highValue); + } + + return result.build(); + } + + private static double operate(OperatorType operator, double left, double right) + { + switch (operator) { + case ADD: + return left + right; + case SUBTRACT: + return left - right; + case MULTIPLY: + return left * right; + case DIVIDE: + return left / right; + case MODULUS: + return left % right; + default: + throw new IllegalStateException("Unsupported ArithmeticBinaryExpression.Operator: " + operator); + } + } + + public static VariableStatsEstimate computeNegationStatistics(List sourceStats) + { + VariableStatsEstimate stats = sourceStats.get(0); + return VariableStatsEstimate.buildFrom(stats) + .setLowValue(-stats.getHighValue()) + .setHighValue(-stats.getLowValue()) + .build(); + } + + public static VariableStatsEstimate computeConcatStatistics(CallExpression call, List sourceStats, double outputRowCount) + { // Concat function is specially handled since it is a generated function for all arity. 
+ double nullFraction = NaN; + double ndv = NaN; + double avgRowSize = 0.0; + for (VariableStatsEstimate stat : sourceStats) { + if (isFinite(stat.getNullsFraction())) { + nullFraction = firstFiniteValue(nullFraction, 0.0); + nullFraction = max(nullFraction, stat.getNullsFraction()); + } + if (isFinite(stat.getDistinctValuesCount())) { + ndv = firstFiniteValue(ndv, 0.0); + ndv = max(ndv, stat.getDistinctValuesCount()); + } + if (isFinite(stat.getAverageRowSize())) { + avgRowSize += stat.getAverageRowSize(); + } + } + if (avgRowSize == 0.0) { + avgRowSize = NaN; + } + return VariableStatsEstimate.builder() + .setNullsFraction(nullFraction) + .setDistinctValuesCount(minExcludingNaNs(ndv, outputRowCount)) + .setAverageRowSize(minExcludingNaNs(getReturnTypeWidth(call, SUM_ARGUMENTS), avgRowSize)) + .build(); + } + + public static VariableStatsEstimate computeHashCodeOperatorStatistics(CallExpression call, List sourceStats, double outputRowCount) + { + requireNonNull(call, "call is null"); + checkArgument(sourceStats.size() == 1, + "exactly one argument expected for hash code operator scalar function"); + VariableStatsEstimate argStats = sourceStats.get(0); + if (argStats.isUnknown()) { + return VariableStatsEstimate.unknown(); + } + VariableStatsEstimate.Builder result = + VariableStatsEstimate.builder() + .setAverageRowSize(minExcludingNaNs(argStats.getAverageRowSize(), getReturnTypeWidth(call, UNKNOWN))) + .setNullsFraction(argStats.getNullsFraction()) + .setDistinctValuesCount(minExcludingNaNs(argStats.getDistinctValuesCount(), outputRowCount)); + return result.build(); + } + + public static VariableStatsEstimate computeComparisonOperatorStatistics(CallExpression call, List sourceStats) + { + requireNonNull(call, "call is null"); + if (sourceStats.size() != 2) { + return VariableStatsEstimate.unknown(); + } + VariableStatsEstimate left = sourceStats.get(0); + VariableStatsEstimate right = sourceStats.get(1); + VariableStatsEstimate.Builder result = + 
VariableStatsEstimate.builder() + .setAverageRowSize(getReturnTypeWidth(call, UNKNOWN)) + .setNullsFraction(left.getNullsFraction() + right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) + .setDistinctValuesCount(2.0); + return result.build(); + } + + public static VariableStatsEstimate computeYearFunctionStatistics(CallExpression call, List sourceStats) + { + ISOChronology utcChronology = ISOChronology.getInstanceUTC(); + DateTimeField year = utcChronology.year(); + + if (sourceStats.size() != 1 || call.getArguments().size() != 1) { + return VariableStatsEstimate.unknown(); + } + VariableStatsEstimate date = sourceStats.get(0); + VariableStatsEstimate.Builder result = VariableStatsEstimate.builder(); + if (isFinite(date.getLowValue()) && isFinite(date.getHighValue()) && call.getArguments().get(0).getType() instanceof DateType) { + int minYear = year.get(DAYS.toMillis(Double.valueOf(date.getLowValue()).longValue())); + int maxYear = year.get(DAYS.toMillis(Double.valueOf(date.getHighValue()).longValue())); + int ndv = maxYear - minYear + 1; + result.setDistinctValuesCount(minExcludingNaNs(ndv, date.getDistinctValuesCount())); + result.setLowValue(minYear); + result.setHighValue(maxYear); + } + result.setAverageRowSize(getReturnTypeWidth(call, UNKNOWN)) + .setNullsFraction(date.getNullsFraction()); + return result.build(); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index ae76fd532db64..8cbadf0284026 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -184,6 +184,7 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.operator.scalar.MathFunctions.LegacyLogFunction; 
import com.facebook.presto.operator.scalar.MultimapFromEntriesFunction; +import com.facebook.presto.operator.scalar.ParametricScalar; import com.facebook.presto.operator.scalar.QuantileDigestFunctions; import com.facebook.presto.operator.scalar.Re2JRegexpFunctions; import com.facebook.presto.operator.scalar.Re2JRegexpReplaceLambdaFunction; @@ -1181,6 +1182,19 @@ else if (function instanceof SqlInvokedFunction) { sqlFunction.getVersion(), sqlFunction.getComplexTypeFunctionDescriptor()); } + else if (function instanceof ParametricScalar) { + ParametricScalar sqlFunction = (ParametricScalar) function; + return new FunctionMetadata( + signature.getName(), + signature.getArgumentTypes(), + signature.getReturnType(), + signature.getKind(), + JAVA, + function.isDeterministic(), + function.isCalledOnNullInput(), + sqlFunction.getComplexTypeFunctionDescriptor(), + sqlFunction.getScalarHeader().getSignatureToScalarStatsHeadersMap()); + } else { return new FunctionMetadata( signature.getName(), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java index 58baf6c0d475a..9fc93f45856a8 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ArrayCardinalityFunction.java @@ -17,6 +17,7 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.spi.function.TypeParameter; @@ -28,6 +29,7 @@ private ArrayCardinalityFunction() {} @TypeParameter("E") @SqlType(StandardTypes.BIGINT) + @ScalarFunctionConstantStats(minValue = 0) public static long arrayCardinality(@SqlType("array(E)") 
Block block) { return block.getPositionCount(); diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/CombineHashFunction.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/CombineHashFunction.java index 1c8d09b5648f1..33fb53c3639d9 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/CombineHashFunction.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/CombineHashFunction.java @@ -15,9 +15,11 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; import com.facebook.presto.spi.function.SqlType; import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; public final class CombineHashFunction { @@ -25,7 +27,11 @@ private CombineHashFunction() {} @ScalarFunction(value = "combine_hash", visibility = HIDDEN) @SqlType(StandardTypes.BIGINT) - public static long getHash(@SqlType(StandardTypes.BIGINT) long previousHashValue, @SqlType(StandardTypes.BIGINT) long value) + public static long getHash( + @SqlType(StandardTypes.BIGINT) long previousHashValue, + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long value) { return (31 * previousHashValue + value); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/HmacFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/HmacFunctions.java index dffefa72a00e5..a56b21fff9839 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/HmacFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/HmacFunctions.java @@ -16,12 +16,15 @@ import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.spi.function.Description; 
import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; import com.facebook.presto.spi.function.SqlType; import com.google.common.hash.HashCode; import com.google.common.hash.HashFunction; import com.google.common.hash.Hashing; import io.airlift.slice.Slice; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; import static io.airlift.slice.Slices.wrappedBuffer; public final class HmacFunctions @@ -31,7 +34,11 @@ private HmacFunctions() {} @Description("Compute HMAC with MD5") @ScalarFunction @SqlType(StandardTypes.VARBINARY) - public static Slice hmacMd5(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) + @ScalarFunctionConstantStats(avgRowSize = 16) + public static Slice hmacMd5( + @ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) { return computeHash(Hashing.hmacMd5(key.getBytes()), slice); } @@ -39,7 +46,10 @@ public static Slice hmacMd5(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlT @Description("Compute HMAC with SHA1") @ScalarFunction @SqlType(StandardTypes.VARBINARY) - public static Slice hmacSha1(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) + @ScalarFunctionConstantStats(avgRowSize = 20) + public static Slice hmacSha1(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) { return computeHash(Hashing.hmacSha1(key.getBytes()), slice); } @@ -47,7 +57,10 @@ public static Slice hmacSha1(@SqlType(StandardTypes.VARBINARY) Slice slice, @Sql @Description("Compute HMAC with SHA256") @ScalarFunction 
@SqlType(StandardTypes.VARBINARY) - public static Slice hmacSha256(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) + @ScalarFunctionConstantStats(avgRowSize = 32) + public static Slice hmacSha256(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) { return computeHash(Hashing.hmacSha256(key.getBytes()), slice); } @@ -55,7 +68,10 @@ public static Slice hmacSha256(@SqlType(StandardTypes.VARBINARY) Slice slice, @S @Description("Compute HMAC with SHA512") @ScalarFunction @SqlType(StandardTypes.VARBINARY) - public static Slice hmacSha512(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) + @ScalarFunctionConstantStats(avgRowSize = 64) + public static Slice hmacSha512(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.VARBINARY) Slice key) { return computeHash(Hashing.hmacSha512(key.getBytes()), slice); } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java index 231e8db1b0aa8..6dfcf5812ad92 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/MathFunctions.java @@ -23,6 +23,8 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlNullable; import 
com.facebook.presto.spi.function.SqlType; @@ -63,6 +65,13 @@ import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE; import static com.facebook.presto.spi.function.FunctionKind.SCALAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.NON_NULL_ROW_COUNT_CONST; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.ROW_COUNT_CONST; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.LOG10_SOURCE_STATS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.LOG2_SOURCE_STATS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.LOG_NATURAL_SOURCE_STATS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MAX_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; import static com.facebook.presto.type.DecimalOperators.modulusScalarFunction; import static com.facebook.presto.type.DecimalOperators.modulusSignatureBuilder; import static com.facebook.presto.util.Failures.checkCondition; @@ -198,7 +207,8 @@ public static long absFloat(@SqlType(StandardTypes.REAL) long num) @Description("arc cosine") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double acos(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = 0, maxValue = Math.PI) + public static double acos(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.DOUBLE) double num) { return Math.acos(num); } @@ -206,7 +216,8 @@ public static double acos(@SqlType(StandardTypes.DOUBLE) double num) @Description("arc sine") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double asin(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = -Math.PI / 2, maxValue = Math.PI / 2) + public static double 
asin(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.DOUBLE) double num) { return Math.asin(num); } @@ -214,7 +225,8 @@ public static double asin(@SqlType(StandardTypes.DOUBLE) double num) @Description("arc tangent") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double atan(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = -Math.PI / 2, maxValue = Math.PI / 2) + public static double atan(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.DOUBLE) double num) { return Math.atan(num); } @@ -244,7 +256,7 @@ public static double binomialCdf( @Description("cube root") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double cbrt(@SqlType(StandardTypes.DOUBLE) double num) + public static double cbrt(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.DOUBLE) double num) { return Math.cbrt(num); } @@ -252,7 +264,7 @@ public static double cbrt(@SqlType(StandardTypes.DOUBLE) double num) @Description("round up to nearest integer") @ScalarFunction(value = "ceiling", alias = "ceil") @SqlType(StandardTypes.TINYINT) - public static long ceilingTinyint(@SqlType(StandardTypes.TINYINT) long num) + public static long ceilingTinyint(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.TINYINT) long num) { return num; } @@ -260,7 +272,7 @@ public static long ceilingTinyint(@SqlType(StandardTypes.TINYINT) long num) @Description("round up to nearest integer") @ScalarFunction(value = "ceiling", alias = "ceil") @SqlType(StandardTypes.SMALLINT) - public static long ceilingSmallint(@SqlType(StandardTypes.SMALLINT) long num) + public static long ceilingSmallint(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.SMALLINT) long num) { return num; } @@ -268,7 +280,7 @@ public static long ceilingSmallint(@SqlType(StandardTypes.SMALLINT) long num) @Description("round up to nearest integer") 
@ScalarFunction(value = "ceiling", alias = "ceil") @SqlType(StandardTypes.INTEGER) - public static long ceilingInteger(@SqlType(StandardTypes.INTEGER) long num) + public static long ceilingInteger(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.INTEGER) long num) { return num; } @@ -276,7 +288,7 @@ public static long ceilingInteger(@SqlType(StandardTypes.INTEGER) long num) @Description("round up to nearest integer") @ScalarFunction(alias = "ceil") @SqlType(StandardTypes.BIGINT) - public static long ceiling(@SqlType(StandardTypes.BIGINT) long num) + public static long ceiling(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.BIGINT) long num) { return num; } @@ -306,7 +318,8 @@ private Ceiling() {} @LiteralParameters({"p", "s", "rp"}) @SqlType("decimal(rp,0)") @Constraint(variable = "rp", expression = "p - s + min(s, 1)") - public static long ceilingShort(@LiteralParameter("s") long numScale, @SqlType("decimal(p, s)") long num) + public static long ceilingShort(@LiteralParameter("s") long numScale, + @ScalarPropagateSourceStats(nullFraction = USE_MAX_ARGUMENT) @SqlType("decimal(p, s)") long num) { long rescaleFactor = Decimals.longTenToNth((int) numScale); long increment = (num % rescaleFactor > 0) ? 
1 : 0; @@ -400,7 +413,11 @@ public static long truncate(@SqlType(StandardTypes.REAL) long num, @SqlType(Stan @Description("cosine") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double cos(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1) + public static double cos( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.cos(num); } @@ -408,7 +425,12 @@ public static double cos(@SqlType(StandardTypes.DOUBLE) double num) @Description("hyperbolic cosine") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double cosh(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = 1) + public static double cosh( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) + @SqlType(StandardTypes.DOUBLE) double num) { return Math.cosh(num); } @@ -416,7 +438,8 @@ public static double cosh(@SqlType(StandardTypes.DOUBLE) double num) @Description("converts an angle in radians to degrees") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double degrees(@SqlType(StandardTypes.DOUBLE) double radians) + public static double degrees( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.DOUBLE) double radians) { return Math.toDegrees(radians); } @@ -424,6 +447,7 @@ public static double degrees(@SqlType(StandardTypes.DOUBLE) double radians) @Description("Euler's number") @ScalarFunction @SqlType(StandardTypes.DOUBLE) + @ScalarFunctionConstantStats(minValue = Math.E, maxValue = Math.E, nullFraction = 0, distinctValuesCount = 1) public static double e() { return Math.E; @@ -432,7 +456,10 @@ public static double e() @Description("Euler's number raised to the given power") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double exp(@SqlType(StandardTypes.DOUBLE) double num) + 
public static double exp( + @ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.exp(num); } @@ -546,7 +573,11 @@ public static long inverseBinomialCdf( @Description("natural logarithm") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double ln(@SqlType(StandardTypes.DOUBLE) double num) + public static double ln(@ScalarPropagateSourceStats( + minValue = LOG_NATURAL_SOURCE_STATS, + maxValue = LOG_NATURAL_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.log(num); } @@ -554,7 +585,11 @@ public static double ln(@SqlType(StandardTypes.DOUBLE) double num) @Description("logarithm to base 2") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double log2(@SqlType(StandardTypes.DOUBLE) double num) + public static double log2(@ScalarPropagateSourceStats( + minValue = LOG2_SOURCE_STATS, + maxValue = LOG2_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.log(num) / Math.log(2); } @@ -562,7 +597,11 @@ public static double log2(@SqlType(StandardTypes.DOUBLE) double num) @Description("logarithm to base 10") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double log10(@SqlType(StandardTypes.DOUBLE) double num) + public static double log10(@ScalarPropagateSourceStats( + minValue = LOG10_SOURCE_STATS, + maxValue = LOG10_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.log10(num); } @@ -627,6 +666,7 @@ public static long modFloat(@SqlType(StandardTypes.REAL) long num1, @SqlType(Sta @Description("the constant Pi") @ScalarFunction @SqlType(StandardTypes.DOUBLE) + @ScalarFunctionConstantStats(minValue = Math.PI, maxValue = Math.PI, 
distinctValuesCount = 1, nullFraction = 0) public static double pi() { return Math.PI; @@ -635,7 +675,9 @@ public static double pi() @Description("value raised to the power of exponent") @ScalarFunction(alias = "pow") @SqlType(StandardTypes.DOUBLE) - public static double power(@SqlType(StandardTypes.DOUBLE) double num, @SqlType(StandardTypes.DOUBLE) double exponent) + public static double power(@ScalarPropagateSourceStats( + nullFraction = USE_MAX_ARGUMENT, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num, @SqlType(StandardTypes.DOUBLE) double exponent) { return Math.pow(num, exponent); } @@ -643,7 +685,10 @@ public static double power(@SqlType(StandardTypes.DOUBLE) double num, @SqlType(S @Description("converts an angle in degrees to radians") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double radians(@SqlType(StandardTypes.DOUBLE) double degrees) + public static double radians( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double degrees) { return Math.toRadians(degrees); } @@ -651,6 +696,7 @@ public static double radians(@SqlType(StandardTypes.DOUBLE) double degrees) @Description("a pseudo-random value") @ScalarFunction(alias = "rand", deterministic = false) @SqlType(StandardTypes.DOUBLE) + @ScalarFunctionConstantStats(minValue = 0, maxValue = 1, distinctValuesCount = ROW_COUNT_CONST, nullFraction = 0) public static double random() { return ThreadLocalRandom.current().nextDouble(); @@ -659,7 +705,11 @@ public static double random() @Description("a pseudo-random number between 0 and value (exclusive)") @ScalarFunction(value = "random", alias = "rand", deterministic = false) @SqlType(StandardTypes.TINYINT) - public static long randomTinyint(@SqlType(StandardTypes.TINYINT) long value) + @ScalarFunctionConstantStats(minValue = 0) + public static long randomTinyint( + @ScalarPropagateSourceStats( + maxValue = 
USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.TINYINT) long value) { checkCondition(value > 0, INVALID_FUNCTION_ARGUMENT, "bound must be positive"); return ThreadLocalRandom.current().nextInt((int) value); @@ -668,7 +718,10 @@ public static long randomTinyint(@SqlType(StandardTypes.TINYINT) long value) @Description("a pseudo-random number between 0 and value (exclusive)") @ScalarFunction(value = "random", alias = "rand", deterministic = false) @SqlType(StandardTypes.SMALLINT) - public static long randomSmallint(@SqlType(StandardTypes.SMALLINT) long value) + @ScalarFunctionConstantStats(minValue = 0) + public static long randomSmallint(@ScalarPropagateSourceStats( + maxValue = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.SMALLINT) long value) { checkCondition(value > 0, INVALID_FUNCTION_ARGUMENT, "bound must be positive"); return ThreadLocalRandom.current().nextInt((int) value); @@ -677,7 +730,10 @@ public static long randomSmallint(@SqlType(StandardTypes.SMALLINT) long value) @Description("a pseudo-random number between 0 and value (exclusive)") @ScalarFunction(value = "random", alias = "rand", deterministic = false) @SqlType(StandardTypes.INTEGER) - public static long randomInteger(@SqlType(StandardTypes.INTEGER) long value) + @ScalarFunctionConstantStats(minValue = 0) + public static long randomInteger(@ScalarPropagateSourceStats( + maxValue = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.INTEGER) long value) { checkCondition(value > 0, INVALID_FUNCTION_ARGUMENT, "bound must be positive"); return ThreadLocalRandom.current().nextInt((int) value); @@ -686,7 +742,10 @@ public static long randomInteger(@SqlType(StandardTypes.INTEGER) long value) @Description("a pseudo-random number between 0 and value (exclusive)") @ScalarFunction(alias = "rand", deterministic = false) @SqlType(StandardTypes.BIGINT) - public static long random(@SqlType(StandardTypes.BIGINT) long value) + 
@ScalarFunctionConstantStats(minValue = 0) + public static long random(@ScalarPropagateSourceStats( + maxValue = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long value) { checkCondition(value > 0, INVALID_FUNCTION_ARGUMENT, "bound must be positive"); return ThreadLocalRandom.current().nextLong(value); @@ -695,6 +754,7 @@ public static long random(@SqlType(StandardTypes.BIGINT) long value) @Description("a cryptographically secure random number between 0 and 1 (exclusive)") @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.DOUBLE) + @ScalarFunctionConstantStats(minValue = 0, maxValue = 1, distinctValuesCount = ROW_COUNT_CONST, nullFraction = 0) public static double secure_random() { SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -704,7 +764,10 @@ public static double secure_random() @Description("a cryptographically secure random number between lower and upper (exclusive)") @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.DOUBLE) - public static double secure_random(@SqlType(StandardTypes.DOUBLE) double lower, @SqlType(StandardTypes.DOUBLE) double upper) + @ScalarFunctionConstantStats(distinctValuesCount = NON_NULL_ROW_COUNT_CONST) + public static double secure_random( + @ScalarPropagateSourceStats(minValue = USE_SOURCE_STATS, nullFraction = USE_MAX_ARGUMENT) @SqlType(StandardTypes.DOUBLE) double lower, + @ScalarPropagateSourceStats(maxValue = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -716,7 +779,9 @@ public static double secure_random(@SqlType(StandardTypes.DOUBLE) double lower, @Description("a cryptographically secure random number between lower and upper (exclusive)") @ScalarFunction(value = 
"secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.TINYINT) - public static long secureRandomTinyint(@SqlType(StandardTypes.TINYINT) long lower, @SqlType(StandardTypes.TINYINT) long upper) + public static long secureRandomTinyint( + @ScalarPropagateSourceStats(minValue = USE_SOURCE_STATS, nullFraction = USE_MAX_ARGUMENT) @SqlType(StandardTypes.TINYINT) long lower, + @ScalarPropagateSourceStats(maxValue = USE_SOURCE_STATS) @SqlType(StandardTypes.TINYINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -728,7 +793,9 @@ public static long secureRandomTinyint(@SqlType(StandardTypes.TINYINT) long lowe @Description("a cryptographically secure random number between lower and upper (exclusive)") @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.SMALLINT) - public static long secureRandomSmallint(@SqlType(StandardTypes.SMALLINT) long lower, @SqlType(StandardTypes.SMALLINT) long upper) + public static long secureRandomSmallint( + @ScalarPropagateSourceStats(minValue = USE_SOURCE_STATS, nullFraction = USE_MAX_ARGUMENT) @SqlType(StandardTypes.SMALLINT) long lower, + @ScalarPropagateSourceStats(maxValue = USE_SOURCE_STATS) @SqlType(StandardTypes.SMALLINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -740,7 +807,9 @@ public static long secureRandomSmallint(@SqlType(StandardTypes.SMALLINT) long lo @Description("a cryptographically secure random number between lower and upper (exclusive)") @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.INTEGER) - public static long secureRandomInteger(@SqlType(StandardTypes.INTEGER) long lower, 
@SqlType(StandardTypes.INTEGER) long upper) + public static long secureRandomInteger( + @ScalarPropagateSourceStats(minValue = USE_SOURCE_STATS, nullFraction = USE_MAX_ARGUMENT) @SqlType(StandardTypes.INTEGER) long lower, + @ScalarPropagateSourceStats(maxValue = USE_SOURCE_STATS) @SqlType(StandardTypes.INTEGER) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -752,7 +821,9 @@ public static long secureRandomInteger(@SqlType(StandardTypes.INTEGER) long lowe @Description("a cryptographically secure random number between lower and upper (exclusive)") @ScalarFunction(value = "secure_random", alias = "secure_rand", deterministic = false) @SqlType(StandardTypes.BIGINT) - public static long secureRandomBigint(@SqlType(StandardTypes.BIGINT) long lower, @SqlType(StandardTypes.BIGINT) long upper) + public static long secureRandomBigint( + @ScalarPropagateSourceStats(minValue = USE_SOURCE_STATS, nullFraction = USE_MAX_ARGUMENT) @SqlType(StandardTypes.BIGINT) long lower, + @ScalarPropagateSourceStats(maxValue = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long upper) { checkCondition(lower < upper, INVALID_FUNCTION_ARGUMENT, "upper bound must be greater than lower bound"); SecureRandom random = SecureRandomGeneration.getNonBlocking(); @@ -1372,7 +1443,8 @@ else if (isNegative(num)) { @ScalarFunction @SqlType(StandardTypes.BIGINT) - public static long sign(@SqlType(StandardTypes.BIGINT) long num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static long sign(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long num) { return (long) Math.signum(num); } @@ -1380,7 +1452,8 @@ public static long sign(@SqlType(StandardTypes.BIGINT) long num) @Description("signum") @ScalarFunction("sign") @SqlType(StandardTypes.INTEGER) - public static long 
signInteger(@SqlType(StandardTypes.INTEGER) long num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static long signInteger(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.INTEGER) long num) { return (long) Math.signum(num); } @@ -1388,7 +1461,8 @@ public static long signInteger(@SqlType(StandardTypes.INTEGER) long num) @Description("signum") @ScalarFunction("sign") @SqlType(StandardTypes.SMALLINT) - public static long signSmallint(@SqlType(StandardTypes.SMALLINT) long num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static long signSmallint(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.SMALLINT) long num) { return (long) Math.signum(num); } @@ -1396,7 +1470,8 @@ public static long signSmallint(@SqlType(StandardTypes.SMALLINT) long num) @Description("signum") @ScalarFunction("sign") @SqlType(StandardTypes.TINYINT) - public static long signTinyint(@SqlType(StandardTypes.TINYINT) long num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static long signTinyint(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.TINYINT) long num) { return (long) Math.signum(num); } @@ -1404,7 +1479,8 @@ public static long signTinyint(@SqlType(StandardTypes.TINYINT) long num) @Description("signum") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double sign(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static double sign(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.signum(num); } @@ -1412,7 +1488,8 @@ public static double sign(@SqlType(StandardTypes.DOUBLE) double num) @Description("signum") @ScalarFunction("sign") @SqlType(StandardTypes.REAL) - public static long 
signFloat(@SqlType(StandardTypes.REAL) long num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1, distinctValuesCount = 3) + public static long signFloat(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.REAL) long num) { return floatToRawIntBits((Math.signum(intBitsToFloat((int) num)))); } @@ -1420,7 +1497,11 @@ public static long signFloat(@SqlType(StandardTypes.REAL) long num) @Description("sine") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double sin(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(minValue = -1, maxValue = 1) + public static double sin( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.sin(num); } @@ -1428,7 +1509,10 @@ public static double sin(@SqlType(StandardTypes.DOUBLE) double num) @Description("square root") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double sqrt(@SqlType(StandardTypes.DOUBLE) double num) + public static double sqrt( + @ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.sqrt(num); } @@ -1436,7 +1520,9 @@ public static double sqrt(@SqlType(StandardTypes.DOUBLE) double num) @Description("tangent") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double tan(@SqlType(StandardTypes.DOUBLE) double num) + public static double tan(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.tan(num); } @@ -1444,7 +1530,9 @@ public static double tan(@SqlType(StandardTypes.DOUBLE) double num) @Description("hyperbolic tangent") @ScalarFunction @SqlType(StandardTypes.DOUBLE) - public static double tanh(@SqlType(StandardTypes.DOUBLE) double num) + public static double 
tanh(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Math.tanh(num); } @@ -1452,7 +1540,8 @@ public static double tanh(@SqlType(StandardTypes.DOUBLE) double num) @Description("test if value is not-a-number") @ScalarFunction("is_nan") @SqlType(StandardTypes.BOOLEAN) - public static boolean isNaN(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(distinctValuesCount = 1) + public static boolean isNaN(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Double.isNaN(num); } @@ -1460,7 +1549,8 @@ public static boolean isNaN(@SqlType(StandardTypes.DOUBLE) double num) @Description("test if value is finite") @ScalarFunction @SqlType(StandardTypes.BOOLEAN) - public static boolean isFinite(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(distinctValuesCount = 1) + public static boolean isFinite(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Doubles.isFinite(num); } @@ -1468,7 +1558,8 @@ public static boolean isFinite(@SqlType(StandardTypes.DOUBLE) double num) @Description("test if value is infinite") @ScalarFunction @SqlType(StandardTypes.BOOLEAN) - public static boolean isInfinite(@SqlType(StandardTypes.DOUBLE) double num) + @ScalarFunctionConstantStats(distinctValuesCount = 1) + public static boolean isInfinite(@ScalarPropagateSourceStats(nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num) { return Double.isInfinite(num); } @@ -1476,6 +1567,8 @@ public static boolean isInfinite(@SqlType(StandardTypes.DOUBLE) double num) @Description("constant representing not-a-number") @ScalarFunction("nan") @SqlType(StandardTypes.DOUBLE) + // Note: min and max cannot be set to NaN, as that implies nullfraction = 1.0 + @ScalarFunctionConstantStats(distinctValuesCount = 1, nullFraction 
= 0) public static double NaN() { return Double.NaN; @@ -1484,6 +1577,7 @@ public static double NaN() @Description("Infinity") @ScalarFunction @SqlType(StandardTypes.DOUBLE) + @ScalarFunctionConstantStats(distinctValuesCount = 1, nullFraction = 0) public static double infinity() { return Double.POSITIVE_INFINITY; @@ -1492,7 +1586,10 @@ public static double infinity() @Description("convert a number to a string in the given base") @ScalarFunction @SqlType("varchar(64)") - public static Slice toBase(@SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.BIGINT) long radix) + public static Slice toBase( + @ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long value, @SqlType(StandardTypes.BIGINT) long radix) { checkRadix(radix); return utf8Slice(Long.toString(value, (int) radix)); @@ -1502,7 +1599,9 @@ public static Slice toBase(@SqlType(StandardTypes.BIGINT) long value, @SqlType(S @ScalarFunction @LiteralParameters("x") @SqlType(StandardTypes.BIGINT) - public static long fromBase(@SqlType("varchar(x)") Slice value, @SqlType(StandardTypes.BIGINT) long radix) + public static long fromBase(@ScalarPropagateSourceStats( + nullFraction = USE_SOURCE_STATS, + distinctValuesCount = USE_SOURCE_STATS) @SqlType("varchar(x)") Slice value, @SqlType(StandardTypes.BIGINT) long radix) { checkRadix(radix); try { diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index 3950c075e9fe4..7a09287c341ae 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -30,47 +30,62 @@ import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static 
com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.util.Failures.checkCondition; +import static com.google.common.base.MoreObjects.toStringHelper; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class ParametricScalar extends SqlScalarFunction { - private final ScalarHeader details; + private final ScalarHeader scalarHeader; private final ParametricImplementationsGroup implementations; public ParametricScalar( Signature signature, - ScalarHeader details, + ScalarHeader scalarHeader, ParametricImplementationsGroup implementations) { super(signature); - this.details = requireNonNull(details); + this.scalarHeader = requireNonNull(scalarHeader); this.implementations = requireNonNull(implementations); } @Override public SqlFunctionVisibility getVisibility() { - return details.getVisibility(); + return scalarHeader.getVisibility(); + } + + public ScalarHeader getScalarHeader() + { + return scalarHeader; } @Override public boolean isDeterministic() { - return details.isDeterministic(); + return scalarHeader.isDeterministic(); } @Override public boolean isCalledOnNullInput() { - return details.isCalledOnNullInput(); + return scalarHeader.isCalledOnNullInput(); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("signature", getSignature()) + .add("implementation", implementations) + .add("scalarHeader", scalarHeader).toString(); } @Override public String getDescription() { - return details.getDescription().isPresent() ? details.getDescription().get() : ""; + return scalarHeader.getDescription().isPresent() ? 
scalarHeader.getDescription().get() : ""; } @VisibleForTesting diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java index b6b5e33fa6c00..0959a6684a019 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java @@ -13,16 +13,24 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableMap; +import java.util.Map; import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; + public class ScalarHeader { private final Optional description; private final SqlFunctionVisibility visibility; private final boolean deterministic; private final boolean calledOnNullInput; + private final Map signatureToScalarStatsHeaders; public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput) { @@ -30,6 +38,29 @@ public ScalarHeader(Optional description, SqlFunctionVisibility visibili this.visibility = visibility; this.deterministic = deterministic; this.calledOnNullInput = calledOnNullInput; + this.signatureToScalarStatsHeaders = ImmutableMap.of(); + } + + public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput, + Map signatureToScalarStatsHeaders) + { + this.description = description; + this.visibility = visibility; + this.deterministic = deterministic; + this.calledOnNullInput = calledOnNullInput; + this.signatureToScalarStatsHeaders = signatureToScalarStatsHeaders; + } + + @Override + public String toString() + { + return 
toStringHelper(this) + .add("description:", this.description) + .add("visibility", this.visibility) + .add("deterministic", this.deterministic) + .add("calledOnNullInput", this.calledOnNullInput) + .add("signatureToScalarStatsHeadersMap", Joiner.on(" , ").withKeyValueSeparator(" -> ").join(this.signatureToScalarStatsHeaders)) + .toString(); } public Optional getDescription() @@ -51,4 +82,9 @@ public boolean isCalledOnNullInput() { return calledOnNullInput; } + + public Map getSignatureToScalarStatsHeadersMap() + { + return signatureToScalarStatsHeaders; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java index 96aa7214aaf8e..a48f59403c5cf 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/StringFunctions.java @@ -21,7 +21,9 @@ import com.facebook.presto.spi.function.Description; import com.facebook.presto.spi.function.LiteralParameters; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; import com.facebook.presto.spi.function.SqlNullable; import com.facebook.presto.spi.function.SqlType; import com.facebook.presto.type.CodePointsType; @@ -41,6 +43,12 @@ import static com.facebook.presto.common.type.Chars.trimTrailingSpaces; import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT; +import static 
com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MAX_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_TYPE_WIDTH_VARCHAR; import static com.facebook.presto.util.Failures.checkCondition; import static io.airlift.slice.SliceUtf8.countCodePoints; import static io.airlift.slice.SliceUtf8.getCodePointAt; @@ -57,7 +65,7 @@ /** * Current implementation is based on code points from Unicode and does ignore grapheme cluster boundaries. - * Therefore only some methods work correctly with grapheme cluster boundaries. + * Therefore, only some methods work correctly with grapheme cluster boundaries. */ public final class StringFunctions { @@ -90,7 +98,13 @@ public static long codepoint(@SqlType("varchar(1)") Slice slice) @ScalarFunction @LiteralParameters("x") @SqlType(StandardTypes.BIGINT) - public static long length(@SqlType("varchar(x)") Slice slice) + @ScalarFunctionConstantStats(minValue = 0) + public static long length( + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_SOURCE_STATS + ) @SqlType("varchar(x)") Slice slice) { return countCodePoints(slice); } @@ -99,7 +113,13 @@ public static long length(@SqlType("varchar(x)") Slice slice) @ScalarFunction("length") @LiteralParameters("x") @SqlType(StandardTypes.BIGINT) - public static long charLength(@LiteralParameter("x") long x, @SqlType("char(x)") Slice slice) + @ScalarFunctionConstantStats(minValue = 0) + public static long charLength(@LiteralParameter("x") long x, + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_SOURCE_STATS + ) @SqlType("char(x)") Slice slice) { return x; } @@ -108,7 +128,9 @@ public static long 
charLength(@LiteralParameter("x") long x, @SqlType("char(x)") @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType("varchar(x)") - public static Slice replace(@SqlType("varchar(x)") Slice str, @SqlType("varchar(y)") Slice search) + public static Slice replace( + @ScalarPropagateSourceStats(nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice str, + @SqlType("varchar(y)") Slice search) { return replace(str, search, Slices.EMPTY_SLICE); } @@ -118,7 +140,10 @@ public static Slice replace(@SqlType("varchar(x)") Slice str, @SqlType("varchar( @LiteralParameters({"x", "y", "z", "u"}) @Constraint(variable = "u", expression = "min(2147483647, x + z * (x + 1))") @SqlType("varchar(u)") - public static Slice replace(@SqlType("varchar(x)") Slice str, @SqlType("varchar(y)") Slice search, @SqlType("varchar(z)") Slice replace) + public static Slice replace( + @ScalarPropagateSourceStats(nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice str, + @SqlType("varchar(y)") Slice search, + @SqlType("varchar(z)") Slice replace) { // Empty search? 
if (search.length() == 0) { @@ -191,7 +216,7 @@ public static Slice replace(@SqlType("varchar(x)") Slice str, @SqlType("varchar( @ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice reverse(@SqlType("varchar(x)") Slice slice) + public static Slice reverse(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return SliceUtf8.reverse(slice); } @@ -200,7 +225,13 @@ public static Slice reverse(@SqlType("varchar(x)") Slice slice) @ScalarFunction("strpos") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BIGINT) - public static long stringPosition(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice substring) + @ScalarFunctionConstantStats(minValue = 0) + public static long stringPosition( + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice substring) { return stringPositionFromStart(string, substring, 1); } @@ -209,7 +240,14 @@ public static long stringPosition(@SqlType("varchar(x)") Slice string, @SqlType( @ScalarFunction("strpos") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BIGINT) - public static long stringPosition(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice substring, @SqlType(StandardTypes.BIGINT) long instance) + @ScalarFunctionConstantStats(minValue = 0) + public static long stringPosition( + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice substring, + @SqlType(StandardTypes.BIGINT) long instance) { return stringPositionFromStart(string, substring, instance); } @@ -218,7 +256,13 @@ public static long stringPosition(@SqlType("varchar(x)") Slice string, @SqlType( @ScalarFunction("strrpos") @LiteralParameters({"x", 
"y"}) @SqlType(StandardTypes.BIGINT) - public static long stringReversePosition(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice substring) + @ScalarFunctionConstantStats(minValue = 0) + public static long stringReversePosition( + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice substring) { return stringPositionFromEnd(string, substring, 1); } @@ -227,7 +271,14 @@ public static long stringReversePosition(@SqlType("varchar(x)") Slice string, @S @ScalarFunction("strrpos") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BIGINT) - public static long stringReversePosition(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice substring, @SqlType(StandardTypes.BIGINT) long instance) + @ScalarFunctionConstantStats(minValue = 0) + public static long stringReversePosition( + @ScalarPropagateSourceStats( + maxValue = USE_TYPE_WIDTH_VARCHAR, + distinctValuesCount = USE_TYPE_WIDTH_VARCHAR, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice substring, + @SqlType(StandardTypes.BIGINT) long instance) { return stringPositionFromEnd(string, substring, instance); } @@ -288,7 +339,8 @@ private static long stringPositionFromEnd(Slice string, Slice substring, long in @ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice substr(@SqlType("varchar(x)") Slice utf8, @SqlType(StandardTypes.BIGINT) long start) + public static Slice substr(@SqlType("varchar(x)") Slice utf8, + @SqlType(StandardTypes.BIGINT) long start) { if ((start == 0) || utf8.length() == 0) { return Slices.EMPTY_SLICE; @@ -326,7 +378,8 @@ public static Slice substr(@SqlType("varchar(x)") Slice utf8, @SqlType(StandardT @ScalarFunction("substr") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charSubstr(@SqlType("char(x)") Slice 
utf8, @SqlType(StandardTypes.BIGINT) long start) + public static Slice charSubstr(@SqlType("char(x)") Slice utf8, + @SqlType(StandardTypes.BIGINT) long start) { return substr(utf8, start); } @@ -335,7 +388,9 @@ public static Slice charSubstr(@SqlType("char(x)") Slice utf8, @SqlType(Standard @ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice substr(@SqlType("varchar(x)") Slice utf8, @SqlType(StandardTypes.BIGINT) long start, @SqlType(StandardTypes.BIGINT) long length) + public static Slice substr(@SqlType("varchar(x)") Slice utf8, + @SqlType(StandardTypes.BIGINT) long start, + @SqlType(StandardTypes.BIGINT) long length) { if (start == 0 || (length <= 0) || (utf8.length() == 0)) { return Slices.EMPTY_SLICE; @@ -384,7 +439,9 @@ public static Slice substr(@SqlType("varchar(x)") Slice utf8, @SqlType(StandardT @ScalarFunction("substr") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charSubstr(@SqlType("char(x)") Slice utf8, @SqlType(StandardTypes.BIGINT) long start, @SqlType(StandardTypes.BIGINT) long length) + public static Slice charSubstr(@SqlType("char(x)") Slice utf8, + @SqlType(StandardTypes.BIGINT) long start, + @SqlType(StandardTypes.BIGINT) long length) { return trimTrailingSpaces(substr(utf8, start, length)); } @@ -392,7 +449,9 @@ public static Slice charSubstr(@SqlType("char(x)") Slice utf8, @SqlType(Standard @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType("array(varchar(x))") - public static Block split(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice delimiter) + public static Block split( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice delimiter) { return split(string, delimiter, string.length() + 1); } @@ -400,7 +459,9 @@ public static Block split(@SqlType("varchar(x)") Slice string, @SqlType("varchar @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType("array(varchar(x))") - public static Block 
split(@SqlType("varchar(x)") Slice string, @SqlType("varchar(y)") Slice delimiter, @SqlType(StandardTypes.BIGINT) long limit) + public static Block split( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice string, + @SqlType("varchar(y)") Slice delimiter, @SqlType(StandardTypes.BIGINT) long limit) { checkCondition(limit > 0, INVALID_FUNCTION_ARGUMENT, "Limit must be positive"); checkCondition(limit <= Integer.MAX_VALUE, INVALID_FUNCTION_ARGUMENT, "Limit is too large"); @@ -491,7 +552,7 @@ public static Slice splitPart(@SqlType("varchar(x)") Slice string, @SqlType("var @ScalarFunction("ltrim") @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice leftTrim(@SqlType("varchar(x)") Slice slice) + public static Slice leftTrim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return SliceUtf8.leftTrim(slice); } @@ -500,7 +561,7 @@ public static Slice leftTrim(@SqlType("varchar(x)") Slice slice) @ScalarFunction("ltrim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charLeftTrim(@SqlType("char(x)") Slice slice) + public static Slice charLeftTrim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice) { return SliceUtf8.leftTrim(slice); } @@ -509,7 +570,7 @@ public static Slice charLeftTrim(@SqlType("char(x)") Slice slice) @ScalarFunction("rtrim") @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice rightTrim(@SqlType("varchar(x)") Slice slice) + public static Slice rightTrim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return SliceUtf8.rightTrim(slice); } @@ -518,7 +579,7 @@ public static Slice rightTrim(@SqlType("varchar(x)") Slice slice) @ScalarFunction("rtrim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charRightTrim(@SqlType("char(x)") Slice slice) + public static Slice charRightTrim(@ScalarPropagateSourceStats(propagateAllStats = true) 
@SqlType("char(x)") Slice slice) { return rightTrim(slice); } @@ -527,7 +588,7 @@ public static Slice charRightTrim(@SqlType("char(x)") Slice slice) @ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice trim(@SqlType("varchar(x)") Slice slice) + public static Slice trim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return SliceUtf8.trim(slice); } @@ -536,7 +597,7 @@ public static Slice trim(@SqlType("varchar(x)") Slice slice) @ScalarFunction("trim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charTrim(@SqlType("char(x)") Slice slice) + public static Slice charTrim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice) { return trim(slice); } @@ -545,7 +606,9 @@ public static Slice charTrim(@SqlType("char(x)") Slice slice) @ScalarFunction("ltrim") @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice leftTrim(@SqlType("varchar(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice leftTrim( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice, + @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return SliceUtf8.leftTrim(slice, codePointsToTrim); } @@ -554,7 +617,9 @@ public static Slice leftTrim(@SqlType("varchar(x)") Slice slice, @SqlType(CodePo @ScalarFunction("ltrim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charLeftTrim(@SqlType("char(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice charLeftTrim( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice, + @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return leftTrim(slice, codePointsToTrim); } @@ -563,7 +628,9 @@ public static Slice charLeftTrim(@SqlType("char(x)") Slice slice, @SqlType(CodeP @ScalarFunction("rtrim") @LiteralParameters("x") @SqlType("varchar(x)") - public static 
Slice rightTrim(@SqlType("varchar(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice rightTrim( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice, + @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return SliceUtf8.rightTrim(slice, codePointsToTrim); } @@ -572,7 +639,7 @@ public static Slice rightTrim(@SqlType("varchar(x)") Slice slice, @SqlType(CodeP @ScalarFunction("rtrim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charRightTrim(@SqlType("char(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice charRightTrim(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return trimTrailingSpaces(rightTrim(slice, codePointsToTrim)); } @@ -581,7 +648,9 @@ public static Slice charRightTrim(@SqlType("char(x)") Slice slice, @SqlType(Code @ScalarFunction("trim") @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice trim(@SqlType("varchar(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice trim( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice, + @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return SliceUtf8.trim(slice, codePointsToTrim); } @@ -590,7 +659,9 @@ public static Slice trim(@SqlType("varchar(x)") Slice slice, @SqlType(CodePoints @ScalarFunction("trim") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charTrim(@SqlType("char(x)") Slice slice, @SqlType(CodePointsType.NAME) int[] codePointsToTrim) + public static Slice charTrim( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice, + @SqlType(CodePointsType.NAME) int[] codePointsToTrim) { return trimTrailingSpaces(trim(slice, codePointsToTrim)); } @@ -640,7 +711,7 @@ private static int safeCountCodePoints(Slice slice) 
@ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice lower(@SqlType("varchar(x)") Slice slice) + public static Slice lower(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return toLowerCase(slice); } @@ -649,7 +720,7 @@ public static Slice lower(@SqlType("varchar(x)") Slice slice) @ScalarFunction("lower") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charLower(@SqlType("char(x)") Slice slice) + public static Slice charLower(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice) { return lower(slice); } @@ -658,7 +729,7 @@ public static Slice charLower(@SqlType("char(x)") Slice slice) @ScalarFunction @LiteralParameters("x") @SqlType("varchar(x)") - public static Slice upper(@SqlType("varchar(x)") Slice slice) + public static Slice upper(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return toUpperCase(slice); } @@ -667,7 +738,7 @@ public static Slice upper(@SqlType("varchar(x)") Slice slice) @ScalarFunction("upper") @LiteralParameters("x") @SqlType("char(x)") - public static Slice charUpper(@SqlType("char(x)") Slice slice) + public static Slice charUpper(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("char(x)") Slice slice) { return upper(slice); } @@ -729,7 +800,13 @@ private static Slice pad(Slice text, long targetLength, Slice padString, int pad @ScalarFunction("lpad") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.VARCHAR) - public static Slice leftPad(@SqlType("varchar(x)") Slice text, @SqlType(StandardTypes.BIGINT) long targetLength, @SqlType("varchar(y)") Slice padString) + public static Slice leftPad( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + avgRowSize = SUM_ARGUMENTS, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice text, + @SqlType(StandardTypes.BIGINT) long targetLength, + @SqlType("varchar(y)") Slice padString) { 
return pad(text, targetLength, padString, 0); } @@ -738,7 +815,13 @@ public static Slice leftPad(@SqlType("varchar(x)") Slice text, @SqlType(Standard @ScalarFunction("rpad") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.VARCHAR) - public static Slice rightPad(@SqlType("varchar(x)") Slice text, @SqlType(StandardTypes.BIGINT) long targetLength, @SqlType("varchar(y)") Slice padString) + public static Slice rightPad( + @ScalarPropagateSourceStats( + distinctValuesCount = USE_SOURCE_STATS, + avgRowSize = SUM_ARGUMENTS, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice text, + @SqlType(StandardTypes.BIGINT) long targetLength, + @SqlType("varchar(y)") Slice padString) { return pad(text, targetLength, padString, text.length()); } @@ -747,7 +830,13 @@ public static Slice rightPad(@SqlType("varchar(x)") Slice text, @SqlType(Standar @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BIGINT) - public static long levenshteinDistance(@SqlType("varchar(x)") Slice left, @SqlType("varchar(y)") Slice right) + @ScalarFunctionConstantStats(minValue = 0, avgRowSize = 8) + public static long levenshteinDistance( + @ScalarPropagateSourceStats( + maxValue = MAX_TYPE_WIDTH_VARCHAR, + distinctValuesCount = MAX_TYPE_WIDTH_VARCHAR, + nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice left, + @SqlType("varchar(y)") Slice right) { int[] leftCodePoints = castToCodePoints(left); int[] rightCodePoints = castToCodePoints(right); @@ -799,7 +888,13 @@ public static long levenshteinDistance(@SqlType("varchar(x)") Slice left, @SqlTy @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BIGINT) - public static long hammingDistance(@SqlType("varchar(x)") Slice left, @SqlType("varchar(y)") Slice right) + @ScalarFunctionConstantStats(minValue = 0, avgRowSize = 8) + public static long hammingDistance( + @ScalarPropagateSourceStats( + maxValue = MAX_TYPE_WIDTH_VARCHAR, + distinctValuesCount = MAX_TYPE_WIDTH_VARCHAR, + nullFraction = 
USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice left, + @SqlType("varchar(y)") Slice right) { int distance = 0; int leftPosition = 0; @@ -830,7 +925,9 @@ public static long hammingDistance(@SqlType("varchar(x)") Slice left, @SqlType(" @ScalarFunction @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.VARCHAR) - public static Slice normalize(@SqlType("varchar(x)") Slice slice, @SqlType("varchar(y)") Slice form) + public static Slice normalize( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice, + @SqlType("varchar(y)") Slice form) { Normalizer.Form targetForm; try { @@ -845,7 +942,8 @@ public static Slice normalize(@SqlType("varchar(x)") Slice slice, @SqlType("varc @Description("decodes the UTF-8 encoded string") @ScalarFunction @SqlType(StandardTypes.VARCHAR) - public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice) + public static Slice fromUtf8( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.VARBINARY) Slice slice) { return SliceUtf8.fixInvalidUtf8(slice); } @@ -854,7 +952,9 @@ public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice) @ScalarFunction @LiteralParameters("x") @SqlType(StandardTypes.VARCHAR) - public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType("varchar(x)") Slice replacementCharacter) + public static Slice fromUtf8( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.VARBINARY) Slice slice, + @SqlType("varchar(x)") Slice replacementCharacter) { int count = countCodePoints(replacementCharacter); if (count > 1) { @@ -879,7 +979,9 @@ public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice, @Sql @Description("decodes the UTF-8 encoded string") @ScalarFunction @SqlType(StandardTypes.VARCHAR) - public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice, @SqlType(StandardTypes.BIGINT) long replacementCodePoint) + public static Slice fromUtf8( + 
@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.VARBINARY) Slice slice, + @SqlType(StandardTypes.BIGINT) long replacementCodePoint) { if (replacementCodePoint > MAX_CODE_POINT || Character.getType((int) replacementCodePoint) == SURROGATE) { throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Invalid replacement character"); @@ -891,7 +993,7 @@ public static Slice fromUtf8(@SqlType(StandardTypes.VARBINARY) Slice slice, @Sql @ScalarFunction @LiteralParameters("x") @SqlType(StandardTypes.VARBINARY) - public static Slice toUtf8(@SqlType("varchar(x)") Slice slice) + public static Slice toUtf8(@ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice) { return slice; } @@ -902,7 +1004,12 @@ public static Slice toUtf8(@SqlType("varchar(x)") Slice slice) @LiteralParameters({"x", "y", "u"}) @Constraint(variable = "u", expression = "x + y") @SqlType("char(u)") - public static Slice concat(@LiteralParameter("x") Long x, @SqlType("char(x)") Slice left, @SqlType("char(y)") Slice right) + public static Slice concat(@LiteralParameter("x") Long x, + @ScalarPropagateSourceStats( + nullFraction = USE_MAX_ARGUMENT, + avgRowSize = SUM_ARGUMENTS, + distinctValuesCount = NON_NULL_ROW_COUNT) @SqlType("char(x)") Slice left, + @SqlType("char(y)") Slice right) { int rightLength = right.length(); if (rightLength == 0) { @@ -924,7 +1031,8 @@ public static Slice concat(@LiteralParameter("x") Long x, @SqlType("char(x)") Sl @ScalarFunction("starts_with") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BOOLEAN) - public static Boolean startsWith(@SqlType("varchar(x)") Slice x, @SqlType("varchar(y)") Slice y) + @ScalarFunctionConstantStats(distinctValuesCount = 2) + public static Boolean startsWith(@ScalarPropagateSourceStats(nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice x, @SqlType("varchar(y)") Slice y) { if (x.length() < y.length()) { return false; @@ -938,7 +1046,8 @@ public static Boolean 
startsWith(@SqlType("varchar(x)") Slice x, @SqlType("varch @ScalarFunction("ends_with") @LiteralParameters({"x", "y"}) @SqlType(StandardTypes.BOOLEAN) - public static Boolean endsWith(@SqlType("varchar(x)") Slice x, @SqlType("varchar(y)") Slice y) + @ScalarFunctionConstantStats(distinctValuesCount = 2) + public static Boolean endsWith(@ScalarPropagateSourceStats(nullFraction = USE_MAX_ARGUMENT) @SqlType("varchar(x)") Slice x, @SqlType("varchar(y)") Slice y) { if (x.length() < y.length()) { return false; diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index eb2f5a3ae346f..1b8ef7892c992 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -17,23 +17,32 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; +import com.facebook.presto.operator.scalar.ScalarHeader; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.CodegenScalarFunction; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; import com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.google.common.collect.ImmutableList; +import 
com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; @@ -99,11 +108,38 @@ private static List findScalarsInFunctionSetClass(Class< return builder.build(); } + private static Optional getScalarStatsHeader(Method annotated) + { + Optional scalarStatsHeader; + ScalarFunctionConstantStats constantStatsAnnotation = + annotated.getAnnotation(ScalarFunctionConstantStats.class); + List params = + Arrays.stream(annotated.getParameters()) + .filter(param -> param.getAnnotation(SqlType.class) != null) + .collect(Collectors.toList()); + // Map of (function argument position index) -> (ScalarPropagateSourceStats annotation) + ImmutableMap.Builder argumentIndexToStatsAnnotationMapBuilder = new ImmutableMap.Builder<>(); + + IntStream.range(0, params.size()) + .filter(paramIndex -> params.get(paramIndex).getAnnotation(ScalarPropagateSourceStats.class) != null) + .forEachOrdered(paramIndex -> argumentIndexToStatsAnnotationMapBuilder.put(paramIndex, + params.get(paramIndex).getAnnotation(ScalarPropagateSourceStats.class))); + + Map argumentIndexToStatsAnnotation = argumentIndexToStatsAnnotationMapBuilder.build(); + scalarStatsHeader = Optional.ofNullable(constantStatsAnnotation) + .map(statsAnnotation -> new ScalarStatsHeader(statsAnnotation, argumentIndexToStatsAnnotation)); + if (!argumentIndexToStatsAnnotation.isEmpty() && !scalarStatsHeader.isPresent()) { + scalarStatsHeader = Optional.of(new ScalarStatsHeader(argumentIndexToStatsAnnotation)); 
+ } + return scalarStatsHeader; + } + private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor) { ScalarImplementationHeader header = scalar.getHeader(); Map signatures = new HashMap<>(); + ImmutableMap.Builder signatureToStatsHeaderMapBuilder = new ImmutableMap.Builder<>(); for (Method method : scalar.getMethods()) { ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header, method, constructor); if (!signatures.containsKey(implementation.getSpecializedSignature())) { @@ -119,6 +155,8 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc ParametricScalarImplementation.Builder builder = signatures.get(implementation.getSpecializedSignature()); builder.addChoices(implementation); } + Optional scalarStatsHeader = getScalarStatsHeader(method); + scalarStatsHeader.ifPresent(statsHeader -> signatureToStatsHeaderMapBuilder.put(implementation.getSignature().canonicalization(), statsHeader)); } ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); @@ -131,7 +169,11 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc header.getOperatorType().ifPresent(operatorType -> validateOperator(operatorType, scalarSignature.getReturnType(), scalarSignature.getArgumentTypes())); - return new ParametricScalar(scalarSignature, header.getHeader(), implementations); + ScalarHeader scalarHeader = header.getHeader(); + ScalarHeader headerWithStats = + new ScalarHeader(scalarHeader.getDescription(), scalarHeader.getVisibility(), scalarHeader.isDeterministic(), + scalarHeader.isCalledOnNullInput(), signatureToStatsHeaderMapBuilder.build()); + return new ParametricScalar(scalarSignature, headerWithStats, implementations); } private static class ScalarHeaderAndMethods diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java 
b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 88d1b3248cd76..8bd78582c184d 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -293,7 +293,7 @@ public class FeaturesConfig private boolean removeCrossJoinWithSingleConstantRow = true; private CreateView.Security defaultViewSecurityMode = DEFINER; private boolean useHistograms; - + private boolean isScalarFunctionStatsPropagationEnabled; private boolean isInlineProjectionsOnValuesEnabled; private boolean eagerPlanValidationEnabled; @@ -2947,6 +2947,19 @@ public FeaturesConfig setRemoveCrossJoinWithSingleConstantRow(boolean removeCros return this; } + public boolean isScalarFunctionStatsPropagationEnabled() + { + return isScalarFunctionStatsPropagationEnabled; + } + + @Config("optimizer.scalar-function-stats-propagation-enabled") + @ConfigDescription("Respect scalar function statistics annotation for cost-based calculations in the optimizer") + public FeaturesConfig setScalarFunctionStatsPropagationEnabled(boolean scalarFunctionStatsPropagation) + { + this.isScalarFunctionStatsPropagationEnabled = scalarFunctionStatsPropagation; + return this; + } + public boolean isUseHistograms() { return useHistograms; diff --git a/presto-main/src/main/java/com/facebook/presto/testing/TestngUtils.java b/presto-main/src/main/java/com/facebook/presto/testing/TestngUtils.java index 1c89e72664956..969bfd21749c0 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/TestngUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/TestngUtils.java @@ -31,4 +31,16 @@ private TestngUtils() {} }, builder -> builder.toArray(new Object[][] {})); } + + public static Collector toDataProviderFromArray() + { + return Collector.of( + ArrayList::new, + ArrayList::add, + (left, right) -> { + left.addAll(right); + return left; + }, + builder -> 
builder.toArray(new Object[][] {})); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java index fd5beebb602b7..ccdefea2c1bc6 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java +++ b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java @@ -15,6 +15,8 @@ import java.util.stream.DoubleStream; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; public final class MoreMath @@ -79,6 +81,23 @@ public static double max(double... values) .getAsDouble(); } + /** + * Returns the minimum value of the arguments. Returns NaN if there are no arguments or all arguments are NaN. + */ + public static double minExcludingNaNs(double... values) + { + double min = NaN; + for (double v : values) { + if (isFinite(v)) { + if (isNaN(min)) { + min = v; + } + min = Math.min(min, v); + } + } + return min; + } + public static double rangeMin(double left, double right) { if (isNaN(left)) { diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java new file mode 100644 index 0000000000000..a93952b778ae3 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java @@ -0,0 +1,312 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.BuiltInFunctionHandle; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.StatsPropagationBehavior; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.lang.annotation.Annotation; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.cost.ScalarStatsAnnotationProcessor.computeStatsFromAnnotations; +import static com.facebook.presto.spi.function.FunctionKind.SCALAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.NON_NULL_ROW_COUNT_CONST; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.ROW_COUNT; +import static 
com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MAX_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MIN_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; +import static org.testng.Assert.assertEquals; + +public class TestScalarStatsAnnotationProcessor +{ + private static final VariableStatsEstimate STATS_ESTIMATE_FINITE = VariableStatsEstimate.builder() + .setLowValue(1.0) + .setHighValue(120.0) + .setNullsFraction(0.1) + .setAverageRowSize(15.0) + .setDistinctValuesCount(23.0) + .build(); + private static final ScalarFunctionConstantStats CONSTANT_STATS_UNKNOWN = createScalarFunctionConstantStatsInstance(NEGATIVE_INFINITY, POSITIVE_INFINITY, NaN, NaN, NaN); + private static final VariableStatsEstimate STATS_ESTIMATE_UNKNOWN = VariableStatsEstimate.unknown(); + private static final List STATS_ESTIMATE_LIST = ImmutableList.of(STATS_ESTIMATE_FINITE, STATS_ESTIMATE_FINITE); + private static final List STATS_ESTIMATE_LIST_WITH_UNKNOWN = ImmutableList.of(STATS_ESTIMATE_FINITE, STATS_ESTIMATE_UNKNOWN); + private static final TypeSignature VARCHAR_TYPE_10 = createVarcharType(10).getTypeSignature(); + private static final List TWO_ARGUMENTS = ImmutableList.of( + new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(10)), + new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(10))); + + @Test + public void testComputeStatsFromAnnotationsConstantStatsTakePrecedence() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10); + CallExpression callExpression = + new 
CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), + ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "y", VARCHAR))); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(1, 10, 0.1, 2.3, 25), + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(true, UNKNOWN, ROW_COUNT, UNKNOWN, UNKNOWN, UNKNOWN))); + VariableStatsEstimate actualStats = computeStatsFromAnnotations(callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader, 1000); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setNullsFraction(0.1) + .setAverageRowSize(2.3) + .setDistinctValuesCount(25) + .build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testComputeStatsFromAnnotationsNaNSourceStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + VARCHAR_TYPE_10, VARCHAR_TYPE_10); + CallExpression callExpression = + new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), TWO_ARGUMENTS); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(true, USE_SOURCE_STATS, USE_MAX_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS, NON_NULL_ROW_COUNT))); + VariableStatsEstimate actualStats = computeStatsFromAnnotations(callExpression, STATS_ESTIMATE_LIST_WITH_UNKNOWN, scalarStatsHeader, 1000); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .buildFrom(VariableStatsEstimate.unknown()) + .setDistinctValuesCount(1000) + .setAverageRowSize(10.0).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testComputeStatsFromAnnotationsTypeWidthBoundaryConditions() + { + VariableStatsEstimate statsEstimateLarge = + VariableStatsEstimate.builder() + .setNullsFraction(0.0) + .setAverageRowSize(8.0) + 
.setDistinctValuesCount(Double.MAX_VALUE - 1) + .build(); + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + createVarcharType(VarcharType.MAX_LENGTH).getTypeSignature(), createVarcharType(VarcharType.MAX_LENGTH).getTypeSignature()); + + List largeVarcharArguments = ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(VarcharType.MAX_LENGTH)), + new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(VarcharType.MAX_LENGTH))); + CallExpression callExpression = new CallExpression("test", new BuiltInFunctionHandle(signature), createUnboundedVarcharType(), largeVarcharArguments); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(false, USE_SOURCE_STATS, SUM_ARGUMENTS, SUM_ARGUMENTS, SUM_ARGUMENTS, MAX_TYPE_WIDTH_VARCHAR))); + VariableStatsEstimate actualStats = + computeStatsFromAnnotations(callExpression, ImmutableList.of(statsEstimateLarge, statsEstimateLarge), scalarStatsHeader, Double.MAX_VALUE - 1); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .builder() + .setNullsFraction(0.0) + .setDistinctValuesCount(VarcharType.MAX_LENGTH) + .setAverageRowSize(16.0).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testComputeStatsFromAnnotationsTypeWidthBoundaryConditions2() + { + VariableStatsEstimate statsEstimateLarge = + VariableStatsEstimate.builder() + .setLowValue(Double.MIN_VALUE) + .setHighValue(Double.MAX_VALUE) + .setNullsFraction(0.0) + .setAverageRowSize(8.0) + .setDistinctValuesCount(Double.MAX_VALUE - 1) + .build(); + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + DoubleType.DOUBLE.getTypeSignature(), DoubleType.DOUBLE.getTypeSignature()); + + List doubleArguments = ImmutableList.of(new 
VariableReferenceExpression(Optional.empty(), "x", DoubleType.DOUBLE), + new VariableReferenceExpression(Optional.empty(), "y", DoubleType.DOUBLE)); + CallExpression callExpression = new CallExpression("test", new BuiltInFunctionHandle(signature), createUnboundedVarcharType(), doubleArguments); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(false, + USE_MIN_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS, SUM_ARGUMENTS, + USE_MAX_ARGUMENT))); + VariableStatsEstimate actualStats = + computeStatsFromAnnotations(callExpression, ImmutableList.of(statsEstimateLarge, statsEstimateLarge), scalarStatsHeader, Double.MAX_VALUE - 1); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .builder() + .setLowValue(Double.MIN_VALUE) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.0) + .setAverageRowSize(16.0) + .setDistinctValuesCount(Double.MAX_VALUE - 1).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testComputeStatsFromAnnotationsConstantStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10); + CallExpression callExpression = + new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), ImmutableList.of()); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(0, 1, 0.1, 8, NON_NULL_ROW_COUNT_CONST), + ImmutableMap.of()); + VariableStatsEstimate actualStats = computeStatsFromAnnotations(callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader, 1000); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(1) + .setNullsFraction(0.1) + .setAverageRowSize(8.0) + .setDistinctValuesCount(900) + .build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void 
testComputeStatsFromAnnotationsConstantNDVWithNullFractionFromArgumentStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, VARCHAR_TYPE_10, VARCHAR_TYPE_10); + CallExpression callExpression = new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), TWO_ARGUMENTS); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(0, 1, NaN, NaN, NON_NULL_ROW_COUNT_CONST), + ImmutableMap.of(0, createScalarPropagateSourceStatsInstance(false, UNKNOWN, UNKNOWN, SUM_ARGUMENTS, USE_SOURCE_STATS, UNKNOWN))); + VariableStatsEstimate actualStats = computeStatsFromAnnotations(callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader, 1000); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(1) + .setNullsFraction(0.1) + .setAverageRowSize(10) + .setDistinctValuesCount(900) + .build(); + assertEquals(actualStats, expectedStats); + } + + private static ScalarFunctionConstantStats createScalarFunctionConstantStatsInstance( + double min, double max, double nullFraction, double avgRowSize, + double distinctValuesCount) + { + return new ScalarFunctionConstantStats() + { + @Override + public Class annotationType() + { + return ScalarFunctionConstantStats.class; + } + + @Override + public double minValue() + { + return min; + } + + @Override + public double maxValue() + { + return max; + } + + @Override + public double distinctValuesCount() + { + return distinctValuesCount; + } + + @Override + public double nullFraction() + { + return nullFraction; + } + + @Override + public double avgRowSize() + { + return avgRowSize; + } + }; + } + + private ScalarPropagateSourceStats createScalarPropagateSourceStatsInstance( + Boolean propagateAllStats, + StatsPropagationBehavior minValue, + StatsPropagationBehavior maxValue, + StatsPropagationBehavior avgRowSize, + StatsPropagationBehavior 
nullFraction, + StatsPropagationBehavior distinctValuesCount) + { + return new ScalarPropagateSourceStats() + { + @Override + public Class annotationType() + { + return ScalarPropagateSourceStats.class; + } + + @Override + public boolean propagateAllStats() + { + return propagateAllStats; + } + + @Override + public StatsPropagationBehavior minValue() + { + return minValue; + } + + @Override + public StatsPropagationBehavior maxValue() + { + return maxValue; + } + + @Override + public StatsPropagationBehavior distinctValuesCount() + { + return distinctValuesCount; + } + + @Override + public StatsPropagationBehavior avgRowSize() + { + return avgRowSize; + } + + @Override + public StatsPropagationBehavior nullFraction() + { + return nullFraction; + } + }; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java index c7951ffbeebb1..60ff50d69c034 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java @@ -14,8 +14,18 @@ package com.facebook.presto.cost; import com.facebook.presto.Session; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.function.LiteralParameters; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StatsPropagationBehavior; import 
com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; @@ -34,16 +44,23 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.common.type.TinyintType.TINYINT; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; @@ -51,18 +68,57 @@ import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.POSITIVE_INFINITY; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; public class TestScalarStatsCalculator { + public static final Map SESSION_CONFIG = ImmutableMap.of(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, "true"); private static final Map DEFAULT_SYMBOL_TYPES = ImmutableMap.of( "a", BIGINT, "x", BIGINT, "y", BIGINT, "all_null", BIGINT); - + private static final PlanNodeStatsEstimate TWO_ARGUMENTS_BIGINT_SOURCE_STATS = 
PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .build()) + .setOutputRowCount(10) + .build(); + + private static final PlanNodeStatsEstimate BIGINT_SOURCE_STATS = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), + VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .build()) + .setOutputRowCount(10) + .build(); + private static final PlanNodeStatsEstimate VARCHAR_SOURCE_STATS_10_ROWS = PlanNodeStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(20)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(14) + .build()) + .setOutputRowCount(10) + .build(); + + private static final Map TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP = ImmutableMap.of("x", BIGINT, "y", BIGINT); + private final SqlParser sqlParser = new SqlParser(); private ScalarStatsCalculator calculator; private Session session; - private final SqlParser sqlParser = new SqlParser(); private TestingRowExpressionTranslator translator; @BeforeClass @@ -73,52 +129,289 @@ public void setUp() translator = new TestingRowExpressionTranslator(MetadataManager.createTestMetadataManager()); } + @Test + public void testStatsPropagationForCustomAdd() + { + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + TWO_ARGUMENTS_BIGINT_SOURCE_STATS, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + 
.distinctValuesCount(4) + .lowValue(-3) + .highValue(15) + .nullsFraction(0.3) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForUnknownSourceStats() + { + PlanNodeStatsEstimate statsWithUnknowns = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), + VariableStatsEstimate.unknown()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + statsWithUnknowns, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + + PlanNodeStatsEstimate varcharStatsUnknown = PlanNodeStatsEstimate.buildFrom(PlanNodeStatsEstimate.unknown()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + varcharStatsUnknown, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .lowValue(0.0) + .highValue(20.0) + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomStrLen() + { + PlanNodeStatsEstimate varcharStats100Rows = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(20)), VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(14) + .build()) + .setOutputRowCount(100) + .build(); + + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + varcharStats100Rows, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .distinctValuesCount(20.0) + .lowValue(0.0) + .highValue(20.0) + 
.nullsFraction(0.1) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + VARCHAR_SOURCE_STATS_10_ROWS, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .lowValue(0.0) + .highValue(20.0) + .distinctValuesCountUnknown() // When computed NDV is > output row count, it is set to unknown. + .nullsFraction(0.1) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomPrng() + { + assertCalculate(SESSION_CONFIG, + expression("custom_prng(x, y)"), + TWO_ARGUMENTS_BIGINT_SOURCE_STATS, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .lowValue(-1) + .highValue(5) + .distinctValuesCount(10) + .nullsFraction(0.0) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomStringEditDistance() + { + PlanNodeStatsEstimate.Builder sourceStatsBuilder = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(10)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.213) + .setAverageRowSize(9.44) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(20)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(6) + .setNullsFraction(0.4) + .setAverageRowSize(19.333) + .build()); + PlanNodeStatsEstimate sourceStats10Rows = sourceStatsBuilder.setOutputRowCount(10).build(); + PlanNodeStatsEstimate sourceStats100Rows = sourceStatsBuilder.setOutputRowCount(100).build(); + Map referenceNameToVarcharType = ImmutableMap.of("x", createVarcharType(10), "y", createVarcharType(20)); + Map referenceNameToUnboundedVarcharType = ImmutableMap.of("x", createVarcharType(10), "y", VARCHAR); + assertCalculate(SESSION_CONFIG, + expression("custom_str_edit_distance(x, y)"), + sourceStats10Rows, + TypeProvider.viewOf(referenceNameToVarcharType)) + .lowValue(0) + .highValue(20) + .distinctValuesCountUnknown() + 
.nullsFraction(0.4) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + expression("custom_str_edit_distance(x, y)"), + sourceStats100Rows, + TypeProvider.viewOf(referenceNameToVarcharType)) + .distinctValuesCount(20) + .lowValue(0) + .highValue(20) + .nullsFraction(0.4) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + expression("custom_str_edit_distance(x, y)"), + sourceStats100Rows, + TypeProvider.viewOf(referenceNameToUnboundedVarcharType)) + .lowValue(0) + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomIsNull() + { + assertCalculate(SESSION_CONFIG, + expression("custom_is_null(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT))) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(3.19) + .nullsFraction(0.0) + .averageRowSize(1.0); + assertCalculate(SESSION_CONFIG, + expression("custom_is_null(x)"), + VARCHAR_SOURCE_STATS_10_ROWS, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(10)))) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(2.0) + .nullsFraction(0.0) + .averageRowSize(1.0); + } + + @Test + public void testConstantStatsBoundaryConditions() + { + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null2(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null3(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null4(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null5(x)"), + 
BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null6(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertCalculate(SESSION_CONFIG, + expression("custom_is_null7(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT))) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForSourceStatsBoundaryConditions() + { + PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(-7) + .setDistinctValuesCount(10) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(-1) + .setDistinctValuesCount(10) + .setNullsFraction(1.0) + .build()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + sourceStats, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .distinctValuesCount(10) + .lowValueUnknown() + .highValue(-8) + .nullsFraction(1.0) + .averageRowSize(8.0); + } + @Test public void testLiteral() { assertCalculate(new GenericLiteral("TINYINT", "7")) + .averageRowSize(TINYINT.getFixedSize()) .distinctValuesCount(1.0) .lowValue(7) .highValue(7) .nullsFraction(0.0); assertCalculate(new GenericLiteral("SMALLINT", "8")) + .averageRowSize(SMALLINT.getFixedSize()) .distinctValuesCount(1.0) .lowValue(8) .highValue(8) .nullsFraction(0.0); assertCalculate(new GenericLiteral("INTEGER", "9")) + .averageRowSize(INTEGER.getFixedSize()) .distinctValuesCount(1.0) .lowValue(9) .highValue(9) .nullsFraction(0.0); assertCalculate(new GenericLiteral("BIGINT", Long.toString(Long.MAX_VALUE))) + 
.averageRowSize(BIGINT.getFixedSize()) .distinctValuesCount(1.0) .lowValue(Long.MAX_VALUE) .highValue(Long.MAX_VALUE) .nullsFraction(0.0); assertCalculate(new DoubleLiteral("7.5")) + .averageRowSize(DOUBLE.getFixedSize()) .distinctValuesCount(1.0) .lowValue(7.5) .highValue(7.5) .nullsFraction(0.0); assertCalculate(new DecimalLiteral("75.5")) + .averageRowSize(8) .distinctValuesCount(1.0) .lowValue(75.5) .highValue(75.5) .nullsFraction(0.0); assertCalculate(new StringLiteral("blah")) + .averageRowSize(4) .distinctValuesCount(1.0) .lowValueUnknown() .highValueUnknown() .nullsFraction(0.0); assertCalculate(new NullLiteral()) + .dataSizeUnknown() .distinctValuesCount(0.0) .lowValueUnknown() .highValueUnknown() @@ -524,8 +817,163 @@ public void testCoalesceExpression() .averageRowSize(2.0); } + private VariableStatsAssertion assertCalculate( + Map sessionConfigs, + Expression scalarExpression, + PlanNodeStatsEstimate inputStatistics, + TypeProvider types) + { + MetadataManager metadata = createTestMetadataManager(); + List functions = new FunctionListBuilder() + .scalars(CustomFunctions.class) + .getFunctions(); + Session.SessionBuilder sessionBuilder = testSessionBuilder(); + for (Map.Entry entry : sessionConfigs.entrySet()) { + sessionBuilder.setSystemProperty(entry.getKey(), entry.getValue()); + } + Session session1 = sessionBuilder.build(); + metadata.getFunctionAndTypeManager().registerBuiltInFunctions(functions); + ScalarStatsCalculator statsCalculator = new ScalarStatsCalculator(metadata); + TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(metadata); + RowExpression scalarRowExpression = translator.translate(scalarExpression, types); + VariableStatsEstimate rowExpressionVariableStatsEstimate = statsCalculator.calculate(scalarRowExpression, inputStatistics, session1); + return VariableStatsAssertion.assertThat(rowExpressionVariableStatsEstimate); + } + private Expression expression(String sqlExpression) { return 
rewriteIdentifiersToSymbolReferences(sqlParser.createExpression(sqlExpression)); } + + public static final class CustomFunctions + { + private CustomFunctions() {} + + @ScalarFunction(value = "custom_add", calledOnNullInput = false) + @ScalarFunctionConstantStats(avgRowSize = 8.0) + @SqlType(StandardTypes.BIGINT) + public static long customAdd( + @ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.SUM_ARGUMENTS, + distinctValuesCount = StatsPropagationBehavior.USE_MAX_ARGUMENT, + minValue = StatsPropagationBehavior.SUM_ARGUMENTS, + maxValue = StatsPropagationBehavior.SUM_ARGUMENTS) @SqlType(StandardTypes.BIGINT) long x, + @SqlType(StandardTypes.BIGINT) long y) + { + return x + y; + } + + @ScalarFunction(value = "custom_is_null", calledOnNullInput = true) + @LiteralParameters("x") + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = 2.0, nullFraction = 0.0) + public static boolean customIsNullVarchar(@SqlNullable @SqlType("varchar(x)") Slice slice) + { + return slice == null; + } + + @ScalarFunction(value = "custom_is_null", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = 3.19, nullFraction = 0.0) + public static boolean customIsNullBigint(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_str_len") + @SqlType(StandardTypes.BIGINT) + @LiteralParameters("x") + @ScalarFunctionConstantStats(minValue = 0) + public static long customStrLength( + @ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.USE_SOURCE_STATS, + distinctValuesCount = StatsPropagationBehavior.USE_TYPE_WIDTH_VARCHAR, + maxValue = StatsPropagationBehavior.USE_TYPE_WIDTH_VARCHAR) @SqlType("varchar(x)") Slice value) + { + return value.length(); + } + + @ScalarFunction(value = "custom_str_edit_distance") + @SqlType(StandardTypes.BIGINT) + 
@LiteralParameters({"x", "y"}) + @ScalarFunctionConstantStats(minValue = 0) + public static long customStrEditDistance( + @ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.USE_MAX_ARGUMENT, + distinctValuesCount = StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR, + maxValue = StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR) @SqlType("varchar(x)") Slice str1, + @SqlType("varchar(y)") Slice str2) + { + return 100; + } + + @ScalarFunction(value = "custom_prng", calledOnNullInput = true) + @SqlType(StandardTypes.BIGINT) + @LiteralParameters("x") + @ScalarFunctionConstantStats(nullFraction = 0) + public static long customPrng( + @SqlNullable + @ScalarPropagateSourceStats( + propagateAllStats = false, + distinctValuesCount = StatsPropagationBehavior.ROW_COUNT, + minValue = StatsPropagationBehavior.USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) Long min, + @ScalarPropagateSourceStats( + propagateAllStats = false, + maxValue = StatsPropagationBehavior.USE_SOURCE_STATS + ) @SqlNullable @SqlType(StandardTypes.BIGINT) Long max) + { + return (long) ((Math.random() * (max - min)) + min); + } + + @ScalarFunction(value = "custom_is_null2", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = -3.19, nullFraction = 0.0) + public static boolean customIsNullBigint2(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null3", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(minValue = -3.19, maxValue = -6.19, nullFraction = 0.0) + public static boolean customIsNullBigint3(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null4", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(nullFraction = 1.1) + public static boolean 
customIsNullBigint4(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value)
+        {
+            return value == null;
+        }
+
+        @ScalarFunction(value = "custom_is_null5", calledOnNullInput = true)
+        @SqlType(StandardTypes.BOOLEAN)
+        @ScalarFunctionConstantStats(nullFraction = -1)
+        public static boolean customIsNullBigint5(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value)
+        {
+            return value == null;
+        }
+
+        @ScalarFunction(value = "custom_is_null6", calledOnNullInput = true)
+        @SqlType(StandardTypes.BOOLEAN)
+        @ScalarFunctionConstantStats(avgRowSize = -1)
+        public static boolean customIsNullBigint6(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value)
+        {
+            return value == null;
+        }
+
+        @ScalarFunction(value = "custom_is_null7", calledOnNullInput = true)
+        @SqlType(StandardTypes.BOOLEAN)
+        @ScalarFunctionConstantStats(avgRowSize = 8)
+        public static boolean customIsNullBigint7(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value)
+        {
+            return value == null;
+        }
+    }
 }
diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java
new file mode 100644
index 0000000000000..a440377722828
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.operator.scalar;
+
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.metadata.SqlScalarFunction;
+import com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser;
+import com.facebook.presto.spi.function.Description;
+import com.facebook.presto.spi.function.LiteralParameters;
+import com.facebook.presto.spi.function.ScalarFunction;
+import com.facebook.presto.spi.function.ScalarFunctionConstantStats;
+import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
+import com.facebook.presto.spi.function.ScalarStatsHeader;
+import com.facebook.presto.spi.function.Signature;
+import com.facebook.presto.spi.function.SqlType;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import io.airlift.slice.Slice;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.Map;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Checks that {@link ScalarFunctionConstantStats} and {@link ScalarPropagateSourceStats} annotations
+ * are picked up by {@link ScalarFromAnnotationsParser} and exposed through the parsed scalar header.
+ */
+public class TestStatsAnnotationScalarFunctions
+        extends AbstractTestFunctions
+{
+    public TestStatsAnnotationScalarFunctions()
+    {
+    }
+
+    protected TestStatsAnnotationScalarFunctions(FeaturesConfig config)
+    {
+        super(config);
+    }
+
+    @Description("Functions with stats annotation")
+    public static class TestScalarFunction
+    {
+        // Declares only a function-level constant stat (avgRowSize).
+        @SqlType(StandardTypes.BOOLEAN)
+        @LiteralParameters("x")
+        @ScalarFunction
+        @ScalarFunctionConstantStats(avgRowSize = 2)
+        public static boolean fun1(@SqlType("varchar(x)") Slice slice)
+        {
+            return true;
+        }
+
+        // Declares constant stats plus a per-argument propagation annotation on argument 0.
+        @SqlType(StandardTypes.BOOLEAN)
+        @LiteralParameters("x")
+        @ScalarFunction
+        @ScalarFunctionConstantStats(avgRowSize = 2, distinctValuesCount = 2.0)
+        public static boolean fun2(
+                @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType("varchar(x)") Slice slice)
+        {
+            return true;
+        }
+    }
+
+    /**
+     * Parses {@link TestScalarFunction} and asserts the stats header recorded for each function's
+     * canonical signature; {@code fun2} is additionally checked for its distinct-values constant and
+     * its argument-level propagation annotation.
+     */
+    @Test
+    public void testAnnotations()
+    {
+        List<SqlScalarFunction> sqlScalarFunctions = ScalarFromAnnotationsParser.parseFunctionDefinitions(TestScalarFunction.class);
+        assertEquals(sqlScalarFunctions.size(), 2);
+        for (SqlScalarFunction function : sqlScalarFunctions) {
+            assertTrue(function instanceof ParametricScalar);
+            ParametricScalar parametricScalar = (ParametricScalar) function;
+            Signature signature = parametricScalar.getSignature().canonicalization();
+            Map<Signature, ScalarStatsHeader> scalarStatsHeaderMap = parametricScalar.getScalarHeader().getSignatureToScalarStatsHeadersMap();
+            ScalarStatsHeader scalarStatsHeader = scalarStatsHeaderMap.get(signature);
+            assertTrue(scalarStatsHeader != null, "no stats header registered for " + signature);
+            assertEquals(scalarStatsHeader.getAvgRowSize(), 2);
+            if (function.getSignature().getName().getObjectName().equals("fun2")) {
+                assertEquals(scalarStatsHeader.getDistinctValuesCount(), 2);
+                Map<Integer, ScalarPropagateSourceStats> argumentStatsActual = scalarStatsHeader.getArgumentStats();
+                assertEquals(argumentStatsActual.size(), 1);
+                assertTrue(argumentStatsActual.get(0).propagateAllStats());
+            }
+        }
+    }
+}
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
index ebe971197a369..c0bf37ab12ee8 100644
--- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
+++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java
@@ -257,7 +257,8 @@ public void testDefaults()
                 .setUseHistograms(false)
                 .setInlineProjectionsOnValues(false)
                 .setEagerPlanValidationEnabled(false)
-                .setEagerPlanValidationThreadPoolSize(20));
+                .setEagerPlanValidationThreadPoolSize(20)
+                .setScalarFunctionStatsPropagationEnabled(false));
     }

     @Test
@@ -464,6 +465,7 @@ public void testExplicitPropertyMappings()
                 .put("optimizer.inline-projections-on-values", "true")
                 .put("eager-plan-validation-enabled", "true")
                 .put("eager-plan-validation-thread-pool-size", "2")
+                .put("optimizer.scalar-function-stats-propagation-enabled", "true")
                 .build();

         FeaturesConfig expected = new FeaturesConfig()
@@ -667,7 +669,8 @@
public void testExplicitPropertyMappings()
                 .setUseHistograms(true)
                 .setInlineProjectionsOnValues(true)
                 .setEagerPlanValidationEnabled(true)
-                .setEagerPlanValidationThreadPoolSize(2);
+                .setEagerPlanValidationThreadPoolSize(2)
+                .setScalarFunctionStatsPropagationEnabled(true);

         assertFullMapping(properties, expected);
     }
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestStatsPropagation.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestStatsPropagation.java
new file mode 100644
index 0000000000000..85d569e55888e
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/optimizations/TestStatsPropagation.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.presto.sql.planner.optimizations;
+
+import com.facebook.presto.cost.PlanNodeStatsEstimate;
+import com.facebook.presto.cost.VariableStatsEstimate;
+import com.facebook.presto.spi.WarningCollector;
+import com.facebook.presto.sql.Optimizer;
+import com.facebook.presto.sql.planner.Plan;
+import com.facebook.presto.sql.planner.assertions.BasePlanTest;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.testing.TestngUtils;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.intellij.lang.annotations.Language;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.util.List;
+import java.util.function.Predicate;
+
+import static com.facebook.presto.SystemSessionProperties.SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED;
+import static java.lang.Double.isFinite;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Plans join queries that use scalar functions, with scalar-function stats propagation enabled,
+ * and checks that every plan node still reports a known row count and finite variable statistics.
+ */
+public class TestStatsPropagation
+        extends BasePlanTest
+{
+    // Dedicated runner created with SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED=true; released in closeQueryRunner().
+    private LocalQueryRunner queryRunner;
+
+    /**
+     * Plans {@code sql} with the runner's optimizers and asserts that the stats estimate of every
+     * plan node satisfies {@code statsChecker}; the SQL text is used as the failure message.
+     */
+    private void assertPlanHasExpectedStats(Predicate<PlanNodeStatsEstimate> statsChecker, @Language("SQL") String sql)
+    {
+        List<PlanOptimizer> optimizers = queryRunner.getPlanOptimizers(true);
+        queryRunner.inTransaction(queryRunner.getDefaultSession(), transactionSession -> {
+            Plan actualPlanResult = queryRunner.createPlan(
+                    transactionSession,
+                    sql,
+                    optimizers,
+                    Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED,
+                    WarningCollector.NOOP);
+
+            assertTrue(actualPlanResult.getStatsAndCosts().getStats().values().stream().allMatch(statsChecker), sql);
+            return null;
+        });
+    }
+
+    /**
+     * Asserts that every variable statistic of every plan node produced for {@code sql} satisfies {@code statsChecker}.
+     */
+    private void assertPlanHasExpectedVariableStats(Predicate<VariableStatsEstimate> statsChecker, String sql)
+    {
+        assertPlanHasExpectedStats(planNodeStatsEstimate -> planNodeStatsEstimate.getVariableStatistics().values().stream().allMatch(statsChecker), sql);
+    }
+
+    @BeforeClass
+    public final void init()
+            throws Exception
+    {
+        queryRunner = createQueryRunner(ImmutableMap.of(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, "true"));
+    }
+
+    /**
+     * Closes the runner created in {@link #init()} so it does not outlive this test class.
+     */
+    @AfterClass(alwaysRun = true)
+    public final void closeQueryRunner()
+    {
+        if (queryRunner != null) {
+            queryRunner.close();
+            queryRunner = null;
+        }
+    }
+
+    @DataProvider(name = "queriesWithStringFunctionsInJoinClause")
+    public Object[][] queriesWithStringFunctionsInJoinClause()
+    {
+        return ImmutableList.of(
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and reverse(trim(l.comment)) = reverse(rtrim(ltrim(l.comment)))",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and lower(l.comment) = upper(l.comment)",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and ltrim(lpad(l.comment, 10, ' ')) = rtrim(rpad(l.comment, 10, ' '))",
+                //"SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and substr(lower(l.comment), 2) = 'us'",
+                // "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and l.comment LIKE '%u'",
+                // "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and l.comment LIKE '%u%'",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and levenshtein_distance(l.comment, 'no') = 2",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and hamming_distance(l.comment, 'no') = 2",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and normalize(l.comment, NFC) = 'us'",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and from_utf8(to_utf8(l.comment)) = 'us'",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and starts_with(o.orderstatus, l.comment)",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and ends_with(o.orderstatus, l.comment)",
+                // "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and concat(o.orderstatus, l.comment) LIKE '%new us%'",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and levenshtein_distance(l.comment, 'no') > 2",
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and levenshtein_distance(l.comment, 'no') < 20")
+                .stream().collect(TestngUtils.toDataProvider());
+    }
+
+    @DataProvider(name = "queriesWithMathFunctionsInJoinClause")
+    public Object[][] queriesWithMathFunctionsInJoinClause()
+    {
+        return ImmutableList.of(
+                "SELECT 1 FROM lineitem l, orders o WHERE l.orderkey=o.orderkey and l.discount = (SELECT random() FROM nation n where n.nationkey=1)",
+                "SELECT 1 FROM lineitem l, orders o WHERE l.orderkey=o.orderkey and log10(o.totalprice) > 1",
+                // "SELECT 1 FROM lineitem l, orders o WHERE l.orderkey=o.orderkey and is_nan(o.totalprice)", // failing due to source stats missing for orderkey.
+                "SELECT 1 FROM orders o, lineitem as l WHERE o.orderkey = l.orderkey and year(o.orderdate) <> year(l.shipdate) ")
+                .stream().collect(TestngUtils.toDataProvider());
+    }
+
+    @Test(dataProvider = "queriesWithStringFunctionsInJoinClause")
+    public void testStatsPropagationScalarStringFunction(@Language("SQL") String query)
+    {
+        ensurePlanNodesHaveStats(query);
+    }
+
+    @Test(dataProvider = "queriesWithMathFunctionsInJoinClause")
+    public void testStatsPropagationScalarMathFunction(@Language("SQL") String query)
+    {
+        ensurePlanNodesHaveStats(query);
+    }
+
+    /**
+     * Asserts that each plan node has a known output row count and that all variable stats have finite NDV and null fraction.
+     */
+    private void ensurePlanNodesHaveStats(@Language("SQL") String query)
+    {
+        assertPlanHasExpectedStats(planNodeStatsEstimate -> !planNodeStatsEstimate.isOutputRowCountUnknown(), query);
+        assertPlanHasExpectedVariableStats(stats -> isFinite(stats.getDistinctValuesCount()), query);
+        assertPlanHasExpectedVariableStats(stats -> isFinite(stats.getNullsFraction()), query);
+    }
+}
diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java
index a82a744b32fbe..bc769a45833d0 100644
--- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java
+++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java
@@ -20,11 +20,13 @@

 import java.util.ArrayList;
 import
java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;

 import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor;
 import static com.facebook.presto.spi.function.FunctionVersion.notVersioned;
+import static java.util.Collections.emptyMap;
 import static java.util.Collections.unmodifiableList;
 import static java.util.Objects.requireNonNull;

@@ -42,6 +44,7 @@ public class FunctionMetadata
     private final boolean calledOnNullInput;
     private final FunctionVersion version;
     private final ComplexTypeFunctionDescriptor descriptor;
+    private final Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders;

     public FunctionMetadata(
             QualifiedObjectName name,
@@ -52,7 +55,26 @@ public FunctionMetadata(
             boolean deterministic,
             boolean calledOnNullInput)
     {
-        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned());
+        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType,
+                deterministic, calledOnNullInput, notVersioned(), emptyMap());
+    }
+
+    /**
+     * Creates non-versioned metadata (no operator type, argument names or language) that also
+     * carries the given per-signature scalar stats headers.
+     */
+    public FunctionMetadata(
+            QualifiedObjectName name,
+            List<TypeSignature> argumentTypes,
+            TypeSignature returnType,
+            FunctionKind functionKind,
+            FunctionImplementationType implementationType,
+            boolean deterministic,
+            boolean calledOnNullInput,
+            Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders)
+    {
+        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType,
+                deterministic, calledOnNullInput, notVersioned(), signatureToScalarStatsHeaders);
     }

     public FunctionMetadata(
@@ -65,7 +87,27 @@ public FunctionMetadata(
             boolean calledOnNullInput,
             ComplexTypeFunctionDescriptor functionDescriptor)
     {
-        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor);
+        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType,
+                deterministic, calledOnNullInput, notVersioned(), functionDescriptor, emptyMap());
+    }
+
+    /**
+     * Same as the non-versioned constructor taking a function descriptor, but also carries the
+     * given per-signature scalar stats headers.
+     */
+    public FunctionMetadata(
+            QualifiedObjectName name,
+            List<TypeSignature> argumentTypes,
+            TypeSignature returnType,
+            FunctionKind functionKind,
+            FunctionImplementationType implementationType,
+            boolean deterministic,
+            boolean calledOnNullInput,
+            ComplexTypeFunctionDescriptor functionDescriptor,
+            Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders)
+    {
+        this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType,
+                deterministic, calledOnNullInput, notVersioned(), functionDescriptor, signatureToScalarStatsHeaders);
     }

     public FunctionMetadata(
@@ -80,7 +122,8 @@ public FunctionMetadata(
             boolean calledOnNullInput,
             FunctionVersion version)
     {
-        this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, calledOnNullInput, version);
+        this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic,
+                calledOnNullInput, version, emptyMap());
     }

     public FunctionMetadata(
@@ -97,7 +140,29 @@ public FunctionMetadata(
             ComplexTypeFunctionDescriptor functionDescriptor)
     {
         this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic,
-                calledOnNullInput, version, functionDescriptor);
+                calledOnNullInput, version, functionDescriptor, emptyMap());
+    }
+
+    /**
+     * Creates versioned metadata with argument names, language and function descriptor that also
+     * carries the given per-signature scalar stats headers.
+     */
+    public FunctionMetadata(
+            QualifiedObjectName name,
+            List<TypeSignature> argumentTypes,
+            List<String> argumentNames,
+            TypeSignature returnType,
+            FunctionKind functionKind,
+            Language language,
+            FunctionImplementationType implementationType,
+            boolean deterministic,
+            boolean calledOnNullInput,
+            FunctionVersion version,
+            ComplexTypeFunctionDescriptor functionDescriptor,
+            Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders)
+    {
+        this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic,
+                calledOnNullInput, version, functionDescriptor, signatureToScalarStatsHeaders);
     }

     public FunctionMetadata(
@@ -109,7 +174,8 @@ public FunctionMetadata(
             boolean deterministic,
             boolean calledOnNullInput)
     {
-        this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned());
+        this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(),
+                implementationType, deterministic, calledOnNullInput, notVersioned(), emptyMap());
     }

     public FunctionMetadata(
@@ -122,7 +188,8 @@ public FunctionMetadata(
             boolean calledOnNullInput,
             ComplexTypeFunctionDescriptor functionDescriptor)
     {
-        this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor);
+        this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(),
+                implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor, emptyMap());
     }

     private FunctionMetadata(
@@ -136,7 +203,8 @@ private FunctionMetadata(
             FunctionImplementationType implementationType,
             boolean deterministic,
             boolean calledOnNullInput,
-            FunctionVersion version)
+            FunctionVersion version,
+            Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders)
     {
         this(
                 name,
@@ -150,7 +218,8 @@ private FunctionMetadata(
                 deterministic,
                 calledOnNullInput,
                 version,
-                defaultFunctionDescriptor());
+                defaultFunctionDescriptor(),
+                signatureToScalarStatsHeaders);
     }

     private FunctionMetadata(
@@ -165,7 +234,8 @@ private FunctionMetadata(
             boolean deterministic,
             boolean calledOnNullInput,
             FunctionVersion version,
-            ComplexTypeFunctionDescriptor functionDescriptor)
+            ComplexTypeFunctionDescriptor functionDescriptor,
+            Map<Signature, ScalarStatsHeader> signatureToScalarStatsHeaders)
     {
         this.name = requireNonNull(name, "name is null");
         this.operatorType = requireNonNull(operatorType, "operatorType is null");
@@ -185,7 +255,9 @@ private FunctionMetadata(
                 functionDescriptor.getArgumentIndicesContainingMapOrArray(),
                 functionDescriptor.getOutputToInputTransformationFunction(),
                 argumentTypes);
+        this.signatureToScalarStatsHeaders = requireNonNull(signatureToScalarStatsHeaders, "signatureToScalarStatsHeaders is null");
     }
+
     public FunctionKind getFunctionKind()
     {
         return functionKind;
@@ -246,6 +318,25 @@ public ComplexTypeFunctionDescriptor getDescriptor()
         return descriptor;
     }

+    /**
+     * Returns {@code true} when at least one scalar stats header is registered for this function.
+     */
+    public boolean hasStatsHeader()
+    {
+        return !signatureToScalarStatsHeaders.isEmpty();
+    }
+
+    /**
+     * Looks up the scalar stats header registered for the given signature.
+     *
+     * @param signature signature used as the lookup key
+     * @return the matching header, or {@link Optional#empty()} when none is registered
+     */
+    public Optional<ScalarStatsHeader> getScalarStatsHeader(Signature signature)
+    {
+        return Optional.ofNullable(signatureToScalarStatsHeaders.get(signature));
+    }
+
     @Override
     public boolean equals(Object obj)
     {
@@ -267,12 +358,14 @@ public boolean equals(Object obj)
                 Objects.equals(this.deterministic, other.deterministic) &&
                 Objects.equals(this.calledOnNullInput, other.calledOnNullInput) &&
                 Objects.equals(this.version, other.version) &&
-                Objects.equals(this.descriptor, other.descriptor);
+                Objects.equals(this.descriptor, other.descriptor) &&
+                Objects.equals(this.signatureToScalarStatsHeaders, other.signatureToScalarStatsHeaders);
     }

     @Override
     public int hashCode()
     {
-        return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version, descriptor);
+        return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version,
+                descriptor,
signatureToScalarStatsHeaders); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java new file mode 100644 index 0000000000000..19dc8dc1828a8 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * By default, a function is just a “black box” that the database system knows very little about the behavior of. + * However, that means that queries using the function may be executed much less efficiently than they could be. + * It is possible to supply additional knowledge that helps the planner optimize function calls. + * Scalar functions are straight forward to optimize and can have impact on the overall query performance. + * Use this annotation to provide information regarding how this function impacts following query statistics. + *

+ * A function may take one or more input columns or constants as parameters. Precise stats may depend on the input + * parameters. This annotation does not cover all the possible cases and allows constant values for the following fields. + * A value of Double.NaN implies unknown; minValue and maxValue use negative/positive infinity instead. + *

+ */ +@Retention(RUNTIME) +@Target(METHOD) +public @interface ScalarFunctionConstantStats +{ + // Min max value is Infinity if unknown. + double minValue() default Double.NEGATIVE_INFINITY; + double maxValue() default Double.POSITIVE_INFINITY; + + /** + * A constant value for Distinct values count regardless of `input column`'s source stats. + * e.g. a perfectly random generator may result in distinctValuesCount of `ScalarFunctionStatsUtils.ROW_COUNT`. + */ + double distinctValuesCount() default Double.NaN; + + /** + * A constant value for nullFraction, e.g. is_null(Slice) will alter column's null fraction + * value to 0.0, regardless of input column's source stats. + */ + double nullFraction() default Double.NaN; + + /** + * A constant value for `avgRowSize` e.g. a function like md5 may produce a + * constant row size. + */ + double avgRowSize() default Double.NaN; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java new file mode 100644 index 0000000000000..ddfe2571a9fcb --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target(PARAMETER) +public @interface ScalarPropagateSourceStats +{ + boolean propagateAllStats() default false; + + StatsPropagationBehavior minValue() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior maxValue() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior distinctValuesCount() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior avgRowSize() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior nullFraction() default StatsPropagationBehavior.UNKNOWN; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java new file mode 100644 index 0000000000000..9c65b4cbb0fb5 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import java.util.Map; + +public class ScalarStatsHeader +{ + private final Map argumentStatsResolver; + private final double min; + private final double max; + private final double distinctValuesCount; + private final double nullFraction; + private final double avgRowSize; + + private ScalarStatsHeader(Map argumentStatsResolver, + double min, + double max, + double distinctValuesCount, + double nullFraction, + double avgRowSize) + { + this.min = min; + this.max = max; + this.argumentStatsResolver = argumentStatsResolver; + this.distinctValuesCount = distinctValuesCount; + this.nullFraction = nullFraction; + this.avgRowSize = avgRowSize; + } + + public ScalarStatsHeader(ScalarFunctionConstantStats methodConstantStats, Map argumentStatsResolver) + { + this(argumentStatsResolver, + methodConstantStats.minValue(), + methodConstantStats.maxValue(), + methodConstantStats.distinctValuesCount(), + methodConstantStats.nullFraction(), + methodConstantStats.avgRowSize()); + } + + public ScalarStatsHeader(Map argumentStatsResolver) + { + this(argumentStatsResolver, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN, Double.NaN, Double.NaN); + } + + @Override + public String toString() + { + return String.format("distinctValuesCount: %g , nullFraction: %g, avgRowSize: %g, min: %g, max: %g", + distinctValuesCount, nullFraction, avgRowSize, min, max); + } + + /* + * Get stats annotation for each of the scalar function argument, where key is the index of the position + * of functions' argument and value is the ScalarPropagateSourceStats annotation. 
+ */ + public Map getArgumentStats() + { + return argumentStatsResolver; + } + + public double getMin() + { + return min; + } + + public double getMax() + { + return max; + } + + public double getAvgRowSize() + { + return avgRowSize; + } + + public double getNullFraction() + { + return nullFraction; + } + + public double getDistinctValuesCount() + { + return distinctValuesCount; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java index 024ba43035bd0..d39d08033ccd3 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java @@ -168,6 +168,17 @@ public String toString() "(" + String.join(",", argumentTypes.stream().map(TypeSignature::toString).collect(toList())) + "):" + returnType; } + /* + * Canonical (normalized i.e. erased type size bounds) form of signature instance. + */ + public Signature canonicalization() + { + return new Signature(this.name, this.kind, new TypeSignature(this.returnType.getBase(), emptyList()), + argumentTypes + .stream() + .map(argumentTypeSignature -> new TypeSignature(argumentTypeSignature.getBase(), emptyList())).collect(toList())); + } + /* * similar to T extends MyClass, if Java supported varargs wildcards */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java new file mode 100644 index 0000000000000..e6a1fb0dd9b18 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static java.util.Collections.unmodifiableSet; + +public enum StatsPropagationBehavior +{ + /** Use the max value across all arguments to derive the new stats value */ + USE_MAX_ARGUMENT, + /** Use the min value across all arguments to derive the new stats value */ + USE_MIN_ARGUMENT, + /** Sum the stats value of all arguments to derive the new stats value */ + SUM_ARGUMENTS, + /** Propagate the source stats as-is */ + USE_SOURCE_STATS, + /** calculate logarithm with base 10 of the arguments source stats */ + LOG10_SOURCE_STATS, + /** calculate logarithm with base 2 of the arguments source stats */ + LOG2_SOURCE_STATS, + /** calculate natural logarithm of the arguments source stats */ + LOG_NATURAL_SOURCE_STATS, + // Following stats are independent of source stats. + /** Use the value of output row count. */ + ROW_COUNT, + /** Use the value of row_count * (1 - null_fraction). */ + NON_NULL_ROW_COUNT, + /** use the value of TYPE_WIDTH in varchar(TYPE_WIDTH) */ + USE_TYPE_WIDTH_VARCHAR, + /** Take max of type width of arguments with varchar type. */ + MAX_TYPE_WIDTH_VARCHAR, + /** Stats are unknown and thus no action is performed. */ + UNKNOWN; + /* + * Stats are multi argument when their value is calculated by operating on stats from source stats or other properties of the all the arguments. 
+ */ + private static final Set MULTI_ARGUMENT_STATS = + unmodifiableSet( + new HashSet<>(Arrays.asList(MAX_TYPE_WIDTH_VARCHAR, USE_MAX_ARGUMENT, USE_MIN_ARGUMENT, SUM_ARGUMENTS))); + private static final Set SOURCE_STATS_DEPENDENT_STATS = + unmodifiableSet( + new HashSet<>(Arrays.asList(USE_MAX_ARGUMENT, USE_MIN_ARGUMENT, SUM_ARGUMENTS, USE_SOURCE_STATS))); + + public static final class Constants + { + public static final int ROW_COUNT_CONST = -1; + public static final int NON_NULL_ROW_COUNT_CONST = -10; + } + + public boolean isMultiArgumentStat() + { + return MULTI_ARGUMENT_STATS.contains(this); + } + + public boolean isSourceStatsDependentStats() + { + return SOURCE_STATS_DEPENDENT_STATS.contains(this); + } +}