Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ public boolean isArithmeticOperator()
return this.equals(ADD) || this.equals(SUBTRACT) || this.equals(MULTIPLY) || this.equals(DIVIDE) || this.equals(MODULUS);
}

/**
 * Tells whether this operator is one of the hashing operators,
 * i.e. equal to {@code HASH_CODE} or {@code XX_HASH_64}.
 *
 * @return {@code true} when this operator equals either hashing operator
 */
public boolean isHashOperator()
{
    if (this.equals(HASH_CODE)) {
        return true;
    }
    return this.equals(XX_HASH_64);
}

public static Optional<OperatorType> tryGetOperatorType(QualifiedObjectName operatorName)
{
return Optional.ofNullable(OPERATOR_TYPES.get(operatorName));
Expand Down
11 changes: 11 additions & 0 deletions presto-docs/src/main/sphinx/admin/properties.rst
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,17 @@ This can also be specified on a per-query basis using the ``confidence_based_bro
Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. This can also be specified
on a per-query basis using the ``treat-low-confidence-zero-estimation-as-unknown`` session property.

``optimizer.scalar-function-stats-propagation-enabled``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


* **Type:** ``boolean``
* **Default value:** ``false``

Enable propagation of statistics through scalar functions using annotations. The annotations describe
the statistics characteristics of a scalar function. When set to ``true``, statistics are propagated according to these annotations.
This can also be specified on a per-query basis using the ``scalar_function_stats_propagation_enabled`` session property.

``optimizer.retry-query-with-history-based-optimization``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ public final class SystemSessionProperties
public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms";
public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns";
public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values";
public static final String SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED = "scalar_function_stats_propagation_enabled";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -2077,6 +2078,10 @@ public SystemSessionProperties(
booleanProperty(INLINE_PROJECTIONS_ON_VALUES,
"Whether to evaluate project node on values node",
featuresConfig.getInlineProjectionsOnValues(),
false),
booleanProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED,
"whether or not to respect stats propagation annotation for scalar functions (or UDF)",
featuresConfig.isScalarFunctionStatsPropagationEnabled(),
false));
}

Expand Down Expand Up @@ -3414,4 +3419,9 @@ public static boolean isInlineProjectionsOnValues(Session session)
{
return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class);
}

/**
 * Reads the {@code scalar_function_stats_propagation_enabled} system session
 * property from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean shouldEnableScalarFunctionStatsPropagation(Session session)
{
    boolean propagationEnabled = session.getSystemProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, Boolean.class);
    return propagationEnabled;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.cost;

import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
import com.facebook.presto.spi.function.ScalarStatsHeader;
import com.facebook.presto.spi.function.StatsPropagationBehavior;
import com.facebook.presto.spi.relation.CallExpression;

import java.util.List;
import java.util.Map;

import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.firstFiniteValue;
import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.getReturnTypeWidth;
import static com.facebook.presto.cost.ScalarStatsCalculatorUtils.getTypeWidth;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.NON_NULL_ROW_COUNT_CONST;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.Constants.ROW_COUNT_CONST;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS;
import static com.facebook.presto.util.MoreMath.max;
import static com.facebook.presto.util.MoreMath.min;
import static com.facebook.presto.util.MoreMath.minExcludingNaNs;
import static com.facebook.presto.util.MoreMath.nearlyEqual;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Double.NaN;
import static java.lang.Double.isFinite;
import static java.lang.Double.isNaN;

/**
 * Static helpers that build a {@link VariableStatsEstimate} for a scalar
 * function call from the statistics metadata carried by a
 * {@link ScalarStatsHeader} and its per-argument
 * {@link ScalarPropagateSourceStats} entries.
 */
public final class ScalarStatsAnnotationProcessor
{
    private ScalarStatsAnnotationProcessor()
    {
    }

    /**
     * Computes the output statistics of a scalar function call.
     *
     * <p>Null fraction, min and max start from the values declared on the header;
     * distinct values count and average row size start as {@code NaN}. For every
     * entry in {@code scalarStatsHeader.getArgumentStats()} each statistic is then
     * combined, through {@code firstFiniteValue}, with the value derived from that
     * argument's propagation behaviour. If either bound ends up {@code NaN}, both
     * bounds are reported as {@code NaN}.
     *
     * @param callExpression the function call being estimated; its argument and return types are used for width-based statistics
     * @param sourceStats statistics of the call's arguments, indexed by argument position
     * @param scalarStatsHeader header-level constants and per-argument propagation rules
     * @param outputRowCount estimated number of output rows
     * @return the statistics estimate for the call's result
     */
    public static VariableStatsEstimate computeStatsFromAnnotations(
            CallExpression callExpression,
            List<VariableStatsEstimate> sourceStats,
            ScalarStatsHeader scalarStatsHeader,
            double outputRowCount)
    {
        double nullFraction = scalarStatsHeader.getNullFraction();
        double distinctValuesCount = NaN;
        double averageRowSize = NaN;
        double maxValue = scalarStatsHeader.getMax();
        double minValue = scalarStatsHeader.getMin();

        // These projections depend only on sourceStats, so build them once
        // instead of once per annotated argument.
        List<Double> sourceNullFractions = sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(toImmutableList());
        List<Double> sourceDistinctValuesCounts = sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(toImmutableList());
        List<Double> sourceAverageRowSizes = sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(toImmutableList());
        List<Double> sourceHighValues = sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(toImmutableList());
        List<Double> sourceLowValues = sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(toImmutableList());

        for (Map.Entry<Integer, ScalarPropagateSourceStats> argumentStatsEntry : scalarStatsHeader.getArgumentStats().entrySet()) {
            ScalarPropagateSourceStats argumentStats = argumentStatsEntry.getValue();
            boolean propagateAllStats = argumentStats.propagateAllStats();
            int argumentIndex = argumentStatsEntry.getKey();

            // Order matters: the null fraction is updated first because the
            // statistics computed below receive the updated value.
            nullFraction = min(
                    firstFiniteValue(
                            nullFraction,
                            processSingleArgumentStatistic(
                                    outputRowCount,
                                    nullFraction,
                                    callExpression,
                                    sourceNullFractions,
                                    argumentIndex,
                                    applyPropagateAllStats(propagateAllStats, argumentStats.nullFraction()))),
                    1.0);

            distinctValuesCount = firstFiniteValue(
                    distinctValuesCount,
                    processSingleArgumentStatistic(
                            outputRowCount,
                            nullFraction,
                            callExpression,
                            sourceDistinctValuesCounts,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, argumentStats.distinctValuesCount())));

            StatsPropagationBehavior averageRowSizeBehavior = applyPropagateAllStats(propagateAllStats, argumentStats.avgRowSize());
            averageRowSize = minExcludingNaNs(
                    firstFiniteValue(
                            averageRowSize,
                            processSingleArgumentStatistic(
                                    outputRowCount,
                                    nullFraction,
                                    callExpression,
                                    sourceAverageRowSizes,
                                    argumentIndex,
                                    averageRowSizeBehavior)),
                    getReturnTypeWidth(callExpression, averageRowSizeBehavior));

            maxValue = firstFiniteValue(
                    maxValue,
                    processSingleArgumentStatistic(
                            outputRowCount,
                            nullFraction,
                            callExpression,
                            sourceHighValues,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, argumentStats.maxValue())));

            minValue = firstFiniteValue(
                    minValue,
                    processSingleArgumentStatistic(
                            outputRowCount,
                            nullFraction,
                            callExpression,
                            sourceLowValues,
                            argumentIndex,
                            applyPropagateAllStats(propagateAllStats, argumentStats.minValue())));
        }

        // A range with only one known bound is reported as fully unknown.
        if (isNaN(maxValue) || isNaN(minValue)) {
            minValue = NaN;
            maxValue = NaN;
        }

        return VariableStatsEstimate.builder()
                .setLowValue(minValue)
                .setHighValue(maxValue)
                .setNullsFraction(nullFraction)
                .setAverageRowSize(firstFiniteValue(scalarStatsHeader.getAvgRowSize(), averageRowSize, getReturnTypeWidth(callExpression, UNKNOWN)))
                .setDistinctValuesCount(processDistinctValuesCount(outputRowCount, nullFraction, scalarStatsHeader.getDistinctValuesCount(), distinctValuesCount))
                .build();
    }

    /**
     * Resolves the final distinct values count.
     *
     * <p>A finite header constant that is nearly equal to
     * {@code NON_NULL_ROW_COUNT_CONST} or {@code ROW_COUNT_CONST} is replaced by the
     * corresponding row-count-based value. The header value and the propagated value
     * are then combined through {@code firstFiniteValue}; a result larger than
     * {@code outputRowCount} is reported as {@code NaN}.
     *
     * @param outputRowCount estimated number of output rows
     * @param nullFraction null fraction computed for the call's result
     * @param distinctValuesCountFromConstant distinct values count declared on the header
     * @param distinctValuesCount distinct values count derived from the arguments
     * @return the resolved distinct values count, or {@code NaN} when it exceeds the row count
     */
    private static double processDistinctValuesCount(double outputRowCount, double nullFraction, double distinctValuesCountFromConstant, double distinctValuesCount)
    {
        if (isFinite(distinctValuesCountFromConstant)) {
            if (nearlyEqual(distinctValuesCountFromConstant, NON_NULL_ROW_COUNT_CONST, 0.1)) {
                distinctValuesCountFromConstant = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
            }
            else if (nearlyEqual(distinctValuesCountFromConstant, ROW_COUNT_CONST, 0.1)) {
                distinctValuesCountFromConstant = outputRowCount;
            }
        }
        double distinctValuesCountFinal = firstFiniteValue(distinctValuesCountFromConstant, distinctValuesCount);
        if (distinctValuesCountFinal > outputRowCount) {
            distinctValuesCountFinal = NaN;
        }
        return distinctValuesCountFinal;
    }

    /**
     * Computes one statistic value according to the given propagation behaviour.
     *
     * <p>Multi-argument behaviours walk over every entry of {@code sourceStats};
     * all other behaviours look only at the annotated argument, the row count or
     * the null fraction. Behaviours with no matching case leave the result as
     * {@code NaN}.
     *
     * @param outputRowCount estimated number of output rows
     * @param nullFraction null fraction used by {@code NON_NULL_ROW_COUNT}
     * @param callExpression the call whose argument types are used for width-based behaviours
     * @param sourceStats one value of the statistic per argument, indexed by argument position
     * @param sourceStatsArgumentIndex index of the argument on which the
     *         {@code ScalarPropagateSourceStats} annotation was applied
     * @param operation the propagation behaviour to apply
     * @return the derived value, or {@code NaN} when no case applies
     */
    private static double processSingleArgumentStatistic(
            double outputRowCount,
            double nullFraction,
            CallExpression callExpression,
            List<Double> sourceStats,
            int sourceStatsArgumentIndex,
            StatsPropagationBehavior operation)
    {
        double statValue = NaN;
        if (operation.isMultiArgumentStat()) {
            for (int i = 0; i < sourceStats.size(); i++) {
                // Seed the accumulator from the first argument when it is usable;
                // otherwise fall through to the per-behaviour combination below.
                if (i == 0 && operation.isSourceStatsDependentStats() && isFinite(sourceStats.get(i))) {
                    statValue = sourceStats.get(i);
                }
                else {
                    switch (operation) {
                        case MAX_TYPE_WIDTH_VARCHAR:
                            // NOTE(review): this assigns the width of argument i, so after the loop
                            // the value reflects the last argument visited rather than a running
                            // maximum - confirm this is intended.
                            statValue = getTypeWidth(callExpression.getArguments().get(i).getType());
                            break;
                        case USE_MIN_ARGUMENT:
                            statValue = min(statValue, sourceStats.get(i));
                            break;
                        case USE_MAX_ARGUMENT:
                            statValue = max(statValue, sourceStats.get(i));
                            break;
                        case SUM_ARGUMENTS:
                            statValue = statValue + sourceStats.get(i);
                            break;
                        default:
                            // Other behaviours contribute nothing here.
                            break;
                    }
                }
            }
        }
        else {
            switch (operation) {
                case USE_SOURCE_STATS:
                    statValue = sourceStats.get(sourceStatsArgumentIndex);
                    break;
                case ROW_COUNT:
                    statValue = outputRowCount;
                    break;
                case NON_NULL_ROW_COUNT:
                    statValue = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
                    break;
                case USE_TYPE_WIDTH_VARCHAR:
                    statValue = getTypeWidth(callExpression.getArguments().get(sourceStatsArgumentIndex).getType());
                    break;
                case LOG10_SOURCE_STATS:
                    statValue = Math.log10(sourceStats.get(sourceStatsArgumentIndex));
                    break;
                case LOG2_SOURCE_STATS:
                    // log2(x) computed as ln(x) / ln(2)
                    statValue = Math.log(sourceStats.get(sourceStatsArgumentIndex)) / Math.log(2);
                    break;
                case LOG_NATURAL_SOURCE_STATS:
                    statValue = Math.log(sourceStats.get(sourceStatsArgumentIndex));
                    break;
                default:
                    // Behaviours without a case here leave the statistic as NaN.
                    break;
            }
        }
        return statValue;
    }

    /**
     * Substitutes {@code USE_SOURCE_STATS} for an {@code UNKNOWN} behaviour when
     * the argument is flagged to propagate all of its statistics.
     *
     * @param propagateAllStats whether the argument propagates all statistics
     * @param operation the behaviour declared for a specific statistic
     * @return {@code USE_SOURCE_STATS} if {@code operation} is {@code UNKNOWN} and
     *         {@code propagateAllStats} is set; otherwise {@code operation} unchanged
     */
    private static StatsPropagationBehavior applyPropagateAllStats(
            boolean propagateAllStats, StatsPropagationBehavior operation)
    {
        if (operation == UNKNOWN && propagateAllStats) {
            return USE_SOURCE_STATS;
        }
        return operation;
    }
}
Loading