From 19eccaa2fe0a7c4f8e4da5bfa36dcb867cee5fa1 Mon Sep 17 00:00:00 2001 From: James Xu Date: Wed, 10 Dec 2025 20:16:46 +0800 Subject: [PATCH 1/3] [AURON #1745] Introduce auron-spark-tests submodule for correctness testing --- .github/workflows/tpcds-reusable.yml | 10 + .github/workflows/tpcds.yml | 1 + auron-build.sh | 15 + auron-spark-tests/common/pom.xml | 133 +++++++ .../src/test/resources/log4j.properties | 46 +++ .../src/test/resources/log4j2.properties | 68 ++++ .../auron/utils/SQLQueryTestSettings.scala | 27 ++ .../auron/utils/SparkTestSettings.scala | 208 ++++++++++ .../spark/sql/SparkExpressionTestsBase.scala | 357 ++++++++++++++++++ .../spark/sql/SparkQueryTestsBase.scala | 273 ++++++++++++++ .../org/apache/spark/sql/SparkTestsBase.scala | 106 ++++++ .../sql/SparkTestsSharedSessionBase.scala | 82 ++++ .../spark/utils/DebuggableThreadUtils.scala | 51 +++ auron-spark-tests/pom.xml | 92 +++++ auron-spark-tests/spark33/pom.xml | 146 +++++++ .../auron/utils/AuronSparkTestSettings.scala | 39 ++ .../spark/sql/AuronStringFunctionsSuite.scala | 19 + pom.xml | 37 ++ 18 files changed, 1710 insertions(+) create mode 100644 auron-spark-tests/common/pom.xml create mode 100644 auron-spark-tests/common/src/test/resources/log4j.properties create mode 100644 auron-spark-tests/common/src/test/resources/log4j2.properties create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SQLQueryTestSettings.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SparkTestSettings.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkExpressionTestsBase.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkQueryTestsBase.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsBase.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsSharedSessionBase.scala create mode 100644 auron-spark-tests/common/src/test/scala/org/apache/spark/utils/DebuggableThreadUtils.scala create mode 100644 auron-spark-tests/pom.xml create mode 100644 auron-spark-tests/spark33/pom.xml create mode 100644 auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala create mode 100644 auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronStringFunctionsSuite.scala diff --git a/.github/workflows/tpcds-reusable.yml b/.github/workflows/tpcds-reusable.yml index 7617b9d29..eab67d144 100644 --- a/.github/workflows/tpcds-reusable.yml +++ b/.github/workflows/tpcds-reusable.yml @@ -24,6 +24,11 @@ on: description: 'Maven profile id to resolve sparkVersion (e.g., spark-3.5)' required: true type: string + sparktests: + description: 'Whether to enable spark correctness tests' + required: false + type: string + default: 'false' hadoop-profile: description: 'Hadoop profile (e.g., hadoop2.7, hadoop3)' required: true @@ -183,6 +188,11 @@ jobs: CMD="$CMD --uniffle $UNIFFLE_NUMBER" fi + SPARK_TESTS="${{ inputs.sparktests }}" + if [[ "$SPARK_TESTS" == "true" ]]; then + CMD="$CMD --sparktests true" + fi + echo "Running: $CMD" exec $CMD diff --git a/.github/workflows/tpcds.yml b/.github/workflows/tpcds.yml index 588744ee3..87ac607f0 100644 --- a/.github/workflows/tpcds.yml +++ b/.github/workflows/tpcds.yml @@ -60,6 +60,7 @@ jobs: with: sparkver: spark-3.3 hadoop-profile: 'hadoop3' + sparktests: true test-spark-34-jdk11: name: Test spark-3.4 diff --git a/auron-build.sh b/auron-build.sh index 
191c9ec8a..0fabf830b 100755 --- a/auron-build.sh +++ b/auron-build.sh @@ -53,6 +53,7 @@ print_help() { echo " --release Activate release profile" echo " --clean Clean before build (default: true)" echo " --skiptests Skip unit tests (default: true)" + echo " --sparktests Run spark tests (default: false)" echo " --docker Build in Docker environment (default: false)" IFS=','; echo " --image Docker image to use (e.g. ${SUPPORTED_OS_IMAGES[*]}, default: ${SUPPORTED_OS_IMAGES[*]:0:1})"; unset IFS IFS=','; echo " --sparkver Specify Spark version (e.g. ${SUPPORTED_SPARK_VERSIONS[*]})"; unset IFS @@ -124,6 +125,7 @@ PRE_PROFILE=false RELEASE_PROFILE=false CLEAN=true SKIP_TESTS=true +SPARK_TESTS=false SPARK_VER="" FLINK_VER="" SCALA_VER="" @@ -187,6 +189,15 @@ while [[ $# -gt 0 ]]; do exit 1 fi ;; + --sparktests) + if [[ -n "$2" && "$2" =~ ^(true|false)$ ]]; then + SPARK_TESTS="$2" + shift 2 + else + echo "ERROR: --sparktests requires true/false" >&2 + exit 1 + fi + ;; --sparkver) if [[ -n "$2" && "$2" != -* ]]; then SPARK_VER="$2" @@ -353,6 +364,10 @@ else BUILD_ARGS+=("package") fi +if [[ "$SPARK_TESTS" == true ]]; then + BUILD_ARGS+=("-Pspark-tests") +fi + if [[ "$PRE_PROFILE" == true ]]; then BUILD_ARGS+=("-Ppre") fi diff --git a/auron-spark-tests/common/pom.xml b/auron-spark-tests/common/pom.xml new file mode 100644 index 000000000..02eefdcb7 --- /dev/null +++ b/auron-spark-tests/common/pom.xml @@ -0,0 +1,133 @@ + + + + 4.0.0 + + + org.apache.auron + auron-spark-tests + ${project.version} + ../pom.xml + + + auron-spark-tests-common + jar + Auron Spark Test Common + + + + org.apache.auron + spark-extension_${scalaVersion} + ${project.version} + + + org.apache.spark + spark-core_${scalaVersion} + test + + + org.apache.spark + spark-catalyst_${scalaVersion} + test + + + org.apache.spark + spark-core_${scalaVersion} + test-jar + + + org.apache.spark + spark-sql_${scalaVersion} + test-jar + test + + + org.apache.spark + spark-catalyst_${scalaVersion} + test-jar + test + + + org.apache.spark + spark-hive_${scalaVersion} + test-jar + test + + + org.scalatestplus + scalatestplus-scalacheck_${scalaVersion} + test + + + + + + + org.apache.maven.plugins + maven-resources-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + + + + compile + + compile + + + + + org.scalastyle + scalastyle-maven-plugin + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + org.scalatest + scalatest-maven-plugin + + + test + + test + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + + test-jar + + test-compile + + + + + target/scala-${scalaVersion}/classes + target/scala-${scalaVersion}/test-classes + + diff --git a/auron-spark-tests/common/src/test/resources/log4j.properties b/auron-spark-tests/common/src/test/resources/log4j.properties new file mode 100644 index 000000000..4c4853a57 --- /dev/null +++ b/auron-spark-tests/common/src/test/resources/log4j.properties @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the file core/target/unit-tests.log +log4j.rootLogger=WARN, CA + +#Console Appender +log4j.appender.CA=org.apache.log4j.ConsoleAppender +log4j.appender.CA.layout=org.apache.log4j.PatternLayout +log4j.appender.CA.layout.ConversionPattern=%d{HH:mm:ss.SSS} %p %c: %m%n +log4j.appender.CA.Threshold=DEBUG +log4j.appender.CA.follow=true + +# Some packages are noisy for no good reason. +log4j.additivity.org.apache.parquet.hadoop.ParquetRecordReader=false +log4j.logger.org.apache.parquet.hadoop.ParquetRecordReader=OFF + +log4j.additivity.org.apache.parquet.hadoop.ParquetOutputCommitter=false +log4j.logger.org.apache.parquet.hadoop.ParquetOutputCommitter=OFF + +log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false +log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF + +log4j.additivity.org.apache.hadoop.hive.metastore.RetryingHMSHandler=false +log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF + +log4j.additivity.hive.ql.metadata.Hive=false +log4j.logger.hive.ql.metadata.Hive=OFF + +# Parquet related logging +log4j.logger.org.apache.parquet.CorruptStatistics=ERROR +log4j.logger.parquet.CorruptStatistics=ERROR diff --git a/auron-spark-tests/common/src/test/resources/log4j2.properties b/auron-spark-tests/common/src/test/resources/log4j2.properties new file mode 100644 index 000000000..1723186a2 --- /dev/null +++ b/auron-spark-tests/common/src/test/resources/log4j2.properties @@ -0,0 +1,68 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Set everything to be logged to the file core/target/unit-tests.log +rootLogger.level = warn +rootLogger.appenderRef.stdout.ref = STDOUT + +#Console Appender +appender.console.type = Console +appender.console.name = STDOUT +appender.console.target = SYSTEM_OUT +appender.console.layout.type = PatternLayout +appender.console.layout.pattern = %d{HH:mm:ss.SSS} %p %c: %m%n%ex +appender.console.filter.threshold.type = ThresholdFilter +appender.console.filter.threshold.level = debug + +#File Appender +#appender.file.type = File +#appender.file.name = File +#appender.file.fileName = target/unit-tests.log +#appender.file.layout.type = PatternLayout +#appender.file.layout.pattern = %d{HH:mm:ss.SSS} %t %p %c{1}: %m%n%ex + +# Set the logger level of File Appender to WARN +# appender.file.filter.threshold.type = ThresholdFilter +# appender.file.filter.threshold.level = info + +# Some packages are noisy for no good reason. +logger.parquet_recordreader.name = org.apache.parquet.hadoop.ParquetRecordReader +logger.parquet_recordreader.additivity = false +logger.parquet_recordreader.level = off + +logger.parquet_outputcommitter.name = org.apache.parquet.hadoop.ParquetOutputCommitter +logger.parquet_outputcommitter.additivity = false +logger.parquet_outputcommitter.level = off + +logger.hadoop_lazystruct.name = org.apache.hadoop.hive.serde2.lazy.LazyStruct +logger.hadoop_lazystruct.additivity = false +logger.hadoop_lazystruct.level = off + +logger.hadoop_retryinghmshandler.name = org.apache.hadoop.hive.metastore.RetryingHMSHandler +logger.hadoop_retryinghmshandler.additivity = false +logger.hadoop_retryinghmshandler.level = off + +logger.hive_metadata.name = hive.ql.metadata.Hive +logger.hive_metadata.additivity = false +logger.hive_metadata.level = off + +# Parquet related logging +logger.parquet1.name = org.apache.parquet.CorruptStatistics +logger.parquet1.level = error + +logger.parquet2.name = parquet.CorruptStatistics +logger.parquet2.level = error diff --git a/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SQLQueryTestSettings.scala b/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SQLQueryTestSettings.scala new file mode 100644 index 000000000..979f26579 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SQLQueryTestSettings.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.auron.utils + +trait SQLQueryTestSettings { + def getResourceFilePath: String + + def getSupportedSQLQueryTests: Set[String] + + def getOverwriteSQLQueryTests: Set[String] + + def getIgnoredSQLQueryTests: List[String] = List.empty +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SparkTestSettings.scala b/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SparkTestSettings.scala new file mode 100644 index 000000000..5ca56c129 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/auron/utils/SparkTestSettings.scala @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.auron.utils + +import java.util + +import scala.collection.JavaConverters._ +import scala.reflect.ClassTag + +/** + * Test settings which can enable/disable some tests on demand (e.g., when Auron has a known + * correctness bug that is not fixed yet). + */ +abstract class SparkTestSettings { + private val AURON_TEST: String = "Auron - " + private val enabledSuites: java.util.Map[String, SuiteSettings] = new util.HashMap() + + protected def enableSuite[T: ClassTag]: SuiteSettings = { + enableSuite(implicitly[ClassTag[T]].runtimeClass.getCanonicalName) + } + + protected def enableSuite(suiteName: String): SuiteSettings = { + if (enabledSuites.containsKey(suiteName)) { + throw new IllegalArgumentException("Duplicated suite name: " + suiteName) + } + val suiteSettings = new SuiteSettings + enabledSuites.put(suiteName, suiteSettings) + suiteSettings + } + + private[utils] def shouldRun(suiteName: String, testName: String): Boolean = { + if (!enabledSuites.containsKey(suiteName)) { + return false + } + + val suiteSettings = enabledSuites.get(suiteName) + suiteSettings.disableReason match { + case Some(_) => return false + case _ => // continue + } + + val inclusion = suiteSettings.inclusion.asScala + val exclusion = suiteSettings.exclusion.asScala + + if (inclusion.isEmpty && exclusion.isEmpty) { + // default to run all cases under this suite + return true + } + + if (inclusion.nonEmpty && exclusion.nonEmpty) { + // error + throw new IllegalStateException( + s"Do not use include and exclude conditions on the same test case: $suiteName:$testName") + } + + if (inclusion.nonEmpty) { + // include mode + val isIncluded = inclusion.exists(_.isIncluded(testName)) + return isIncluded + } + + if (exclusion.nonEmpty) { + // exclude mode + val isExcluded = exclusion.exists(_.isExcluded(testName)) + return !isExcluded + } + + throw new IllegalStateException("Unreachable code from shouldRun") + } + + /** + * Settings for each test suite. + * + * Each test suite consists of many tests. With SuiteSettings we can control which tests to + * include/exclude.
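+ * A suite can also be disabled entirely via SuiteSettings#disable, in which case shouldRun + * returns false for every test in that suite.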
+ */ + final protected class SuiteSettings { + private[utils] val inclusion: util.List[IncludeBase] = new util.ArrayList() + private[utils] val exclusion: util.List[ExcludeBase] = new util.ArrayList() + + private[utils] var disableReason: Option[String] = None + + def include(testNames: String*): SuiteSettings = { + inclusion.add(Include(testNames: _*)) + this + } + + def exclude(testNames: String*): SuiteSettings = { + exclusion.add(Exclude(testNames: _*)) + this + } + + def includeByPrefix(prefixes: String*): SuiteSettings = { + inclusion.add(IncludeByPrefix(prefixes: _*)) + this + } + def excludeByPrefix(prefixes: String*): SuiteSettings = { + exclusion.add(ExcludeByPrefix(prefixes: _*)) + this + } + + def disable(reason: String): SuiteSettings = { + disableReason = disableReason match { + case Some(r) => throw new IllegalArgumentException("Disable reason already set: " + r) + case None => Some(reason) + } + this + } + } + + object SuiteSettings { + implicit class SuiteSettingsImplicits(settings: SuiteSettings) { + def includeAuronTest(testName: String*): SuiteSettings = { + settings.include(testName.map(AURON_TEST + _): _*) + settings + } + + def excludeAuronTest(testName: String*): SuiteSettings = { + settings.exclude(testName.map(AURON_TEST + _): _*) + settings + } + + def includeAuronTestsByPrefix(prefixes: String*): SuiteSettings = { + settings.includeByPrefix(prefixes.map(AURON_TEST + _): _*) + settings + } + + def excludeAuronTestsByPrefix(prefixes: String*): SuiteSettings = { + settings.excludeByPrefix(prefixes.map(AURON_TEST + _): _*) + settings + } + + def includeAllAuronTests(): SuiteSettings = { + settings.include(AURON_TEST) + settings + } + + def excludeAllAuronTests(): SuiteSettings = { + settings.exclude(AURON_TEST) + settings + } + } + } + + protected trait IncludeBase { + def isIncluded(testName: String): Boolean + } + + protected trait ExcludeBase { + def isExcluded(testName: String): Boolean + } + + private case class Include(testNames: String*) extends IncludeBase { + val nameSet: Set[String] = Set(testNames: _*) + override def isIncluded(testName: String): Boolean = nameSet.contains(testName) + } + + private case class Exclude(testNames: String*) extends ExcludeBase { + val nameSet: Set[String] = Set(testNames: _*) + override def isExcluded(testName: String): Boolean = nameSet.contains(testName) + } + + private case class IncludeByPrefix(prefixes: String*) extends IncludeBase { + override def isIncluded(testName: String): Boolean = { + if (prefixes.exists(prefix => testName.startsWith(prefix))) { + return true + } + false + } + } + + private case class ExcludeByPrefix(prefixes: String*) extends ExcludeBase { + override def isExcluded(testName: String): Boolean = { + if (prefixes.exists(prefix => testName.startsWith(prefix))) { + return true + } + false + } + } + + def getSQLQueryTestSettings: SQLQueryTestSettings +} + +object SparkTestSettings { + val instance: SparkTestSettings = Class + .forName("org.apache.auron.utils.AuronSparkTestSettings") + .getDeclaredConstructor() + .newInstance() + .asInstanceOf[SparkTestSettings] + + def shouldRun(suiteName: String, testName: String): Boolean = { + instance.shouldRun(suiteName, testName: String) + } +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkExpressionTestsBase.scala b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkExpressionTestsBase.scala new file mode 100644 index 000000000..14490aef0 --- /dev/null +++ 
b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkExpressionTestsBase.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import java.io.File + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.commons.io.FileUtils +import org.apache.commons.math3.util.Precision +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, ConvertToLocalRelation, NullPropagation} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.execution.auron.plan.NativeProjectBase +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String +import org.scalactic.TripleEqualsSupport.Spread + +/** + * Base trait for all Spark expression tests. + */ +trait SparkExpressionTestsBase + extends SparkFunSuite + with ExpressionEvalHelper + with SparkTestsBase { + val SUPPORTED_DATA_TYPES = TypeCollection( + BooleanType, + ByteType, + ShortType, + IntegerType, + LongType, + FloatType, + DoubleType, + DecimalType, + StringType, + BinaryType, + DateType, + TimestampType, + ArrayType, + StructType, + MapType) + + override def beforeAll(): Unit = { + // Prepare working paths. + val basePathDir = new File(basePath) + if (basePathDir.exists()) { + FileUtils.forceDelete(basePathDir) + } + FileUtils.forceMkdir(basePathDir) + FileUtils.forceMkdir(new File(warehouse)) + FileUtils.forceMkdir(new File(metaStorePathAbsolute)) + + super.beforeAll() + initializeSession() + _spark.sparkContext.setLogLevel("WARN") + } + + override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + protected def initializeSession(): Unit = { + if (_spark == null) { + val sparkBuilder = SparkSession + .builder() + .appName("Auron-UT") + .master(s"local[2]") + // Avoid static evaluation for literal input by spark catalyst. 
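+ // Excluding ConvertToLocalRelation/ConstantFolding/NullPropagation keeps literal-only + // expressions unevaluated until physical planning, so the native project node under test + // is actually exercised instead of being folded away by the optimizer.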
+ .config( + SQLConf.OPTIMIZER_EXCLUDED_RULES.key, + ConvertToLocalRelation.ruleName + + "," + ConstantFolding.ruleName + "," + NullPropagation.ruleName) + + for ((key, value) <- sparkConfList) { + sparkBuilder.config(key, value) + } + + _spark = sparkBuilder + .getOrCreate() + } + } + + protected var _spark: SparkSession = null + + override protected def checkEvaluation( + expression: => Expression, + expected: Any, + inputRow: InternalRow = EmptyRow): Unit = { + + if (canConvertToDataFrame(inputRow)) { + val resolver = ResolveTimeZone + val expr = resolver.resolveTimeZones(expression) + assert(expr.resolved) + + auronCheckExpression(expr, expected, inputRow) + } else { + logWarning( + "Skipping evaluation - Nonempty inputRow cannot be converted to DataFrame " + + "due to complex/unsupported types.\n") + } + } + + def auronCheckExpression(expression: Expression, expected: Any, inputRow: InternalRow): Unit = { + val df = if (inputRow != EmptyRow && inputRow != InternalRow.empty) { + convertInternalRowToDataFrame(inputRow) + } else { + val schema = StructType(StructField("a", IntegerType, nullable = true) :: Nil) + val empData = Seq(Row(1)) + _spark.createDataFrame(_spark.sparkContext.parallelize(empData), schema) + } + val resultDF = df.select(Column(expression)) + val result = resultDF.collect() + + if (checkDataTypeSupported(expression) && + expression.children.forall(checkDataTypeSupported)) { + val projectExec = resultDF.queryExecution.executedPlan.collect { + case p: NativeProjectBase => p + } + + if (projectExec.size == 1) { + logInfo("Offload to native backend in the test.\n") + } else { + logInfo("Not supported in Auron, fall back to vanilla spark in the test.\n") + shouldNotFallback() + } + } else { + logInfo("Has unsupported data type, fall back to vanilla spark.\n") + shouldNotFallback() + } + + if (!(checkResult(result.head.get(0), expected, expression.dataType, expression.nullable) + || checkResult( + CatalystTypeConverters.createToCatalystConverter(expression.dataType)( + result.head.get(0) + ), // decimal precision is wrong from value + CatalystTypeConverters.convertToCatalyst(expected), + expression.dataType, + expression.nullable))) { + val input = if (inputRow == EmptyRow) "" else s", input: $inputRow" + fail( + s"Incorrect evaluation: $expression, " + + s"actual: ${result.head.get(0)}, " + + s"expected: $expected$input") + } + } + + /** + * Sort map data by key and return the sorted key array and value array. + * + * @param input + * input map data. + * @param kt + * key type. + * @param vt + * value type. + * @return + * the sorted key array and value array. 
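+ * (sorting makes the comparison deterministic, since MapData guarantees no key order)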
+ */ + private def getSortedArrays( + input: MapData, + kt: DataType, + vt: DataType): (ArrayData, ArrayData) = { + val keyArray = input.keyArray().toArray[Any](kt) + val valueArray = input.valueArray().toArray[Any](vt) + val newMap = (keyArray.zip(valueArray)).toMap + val sortedMap = mutable.SortedMap(newMap.toSeq: _*)(TypeUtils.getInterpretedOrdering(kt)) + (new GenericArrayData(sortedMap.keys.toArray), new GenericArrayData(sortedMap.values.toArray)) + } + + def isNaNOrInf(num: Double): Boolean = { + num.isNaN || num.isInfinite + } + + override protected def checkResult( + result: Any, + expected: Any, + exprDataType: DataType, + exprNullable: Boolean): Boolean = { + val dataType = UserDefinedType.sqlType(exprDataType) + + // The result is null for a non-nullable expression + assert(result != null || exprNullable, "exprNullable should be true if result is null") + (result, expected) match { + case (result: Array[Byte], expected: Array[Byte]) => + java.util.Arrays.equals(result, expected) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) + case (result: InternalRow, expected: InternalRow) => + val st = dataType.asInstanceOf[StructType] + assert(result.numFields == st.length && expected.numFields == st.length) + st.zipWithIndex.forall { case (f, i) => + checkResult( + result.get(i, f.dataType), + expected.get(i, f.dataType), + f.dataType, + f.nullable) + } + case (result: ArrayData, expected: ArrayData) => + result.numElements == expected.numElements && { + val ArrayType(et, cn) = dataType.asInstanceOf[ArrayType] + var isSame = true + var i = 0 + while (isSame && i < result.numElements) { + isSame = checkResult(result.get(i, et), expected.get(i, et), et, cn) + i += 1 + } + isSame + } + case (result: MapData, expected: MapData) => + val MapType(kt, vt, vcn) = dataType.asInstanceOf[MapType] + checkResult( + getSortedArrays(result, kt, vt)._1, + getSortedArrays(expected, kt, vt)._1, + ArrayType(kt, containsNull = false), + exprNullable = false) && checkResult( + getSortedArrays(result, kt, vt)._2, + getSortedArrays(expected, kt, vt)._2, + ArrayType(vt, vcn), + exprNullable = false) + case (result: Double, expected: Double) => + if ((isNaNOrInf(result) || isNaNOrInf(expected)) + || (result == -0.0) || (expected == -0.0)) { + java.lang.Double.doubleToRawLongBits(result) == + java.lang.Double.doubleToRawLongBits(expected) + } else { + Precision.equalsWithRelativeTolerance(result, expected, 0.00001d) + } + case (result: Float, expected: Float) => + if (expected.isNaN) result.isNaN else expected == result + case (result: Row, expected: InternalRow) => result.toSeq == expected.toSeq(result.schema) + case _ => + result == expected + } + } + + private def checkDataTypeSupported(expr: Expression): Boolean = { + SUPPORTED_DATA_TYPES.acceptsType(expr.dataType) + } + + /** + * Placeholder for future fallback checks. + * + * TODO: Implement logic to verify that no unexpected fallbacks occur during expression + * evaluation. Currently, this method is intentionally left empty because the Auron engine has + * many legitimate fallback cases that are not yet fully handled. Once fallback handling is + * stabilized and the expected cases are well defined, implement assertions or checks here to + * ensure that only allowed fallbacks occur. + */ + private def shouldNotFallback(): Unit = {} + + /** + * Whether the input row can be converted to a DataFrame. + * + * Currently only GenericInternalRow instances containing atomic values are supported.
Complex types + * like Map, Array, nested InternalRows are not supported. + * + * @param inputRow + * the input row to be converted. + * @return + * true if the row can be converted to a DataFrame, false otherwise. + */ + private def canConvertToDataFrame(inputRow: InternalRow): Boolean = { + if (inputRow == EmptyRow || inputRow == InternalRow.empty) { + return true + } + + if (!inputRow.isInstanceOf[GenericInternalRow]) { + return false + } + + val values = inputRow.asInstanceOf[GenericInternalRow].values + for (value <- values) { + value match { + case _: MapData => return false + case _: ArrayData => return false + case _: InternalRow => return false + case _ => + } + } + true + } + + private def convertInternalRowToDataFrame(inputRow: InternalRow): DataFrame = { + val structFieldSeq = new ArrayBuffer[StructField]() + val values = inputRow match { + case genericInternalRow: GenericInternalRow => + genericInternalRow.values + case _ => throw new UnsupportedOperationException("Unsupported InternalRow.") + } + + values.foreach { + case boolean: java.lang.Boolean => + structFieldSeq.append(StructField("bool", BooleanType, boolean == null)) + case byte: java.lang.Byte => + structFieldSeq.append(StructField("i8", ByteType, byte == null)) + case short: java.lang.Short => + structFieldSeq.append(StructField("i16", ShortType, short == null)) + case integer: java.lang.Integer => + structFieldSeq.append(StructField("i32", IntegerType, integer == null)) + case long: java.lang.Long => + structFieldSeq.append(StructField("i64", LongType, long == null)) + case float: java.lang.Float => + structFieldSeq.append(StructField("fp32", FloatType, float == null)) + case double: java.lang.Double => + structFieldSeq.append(StructField("fp64", DoubleType, double == null)) + case utf8String: UTF8String => + structFieldSeq.append(StructField("str", StringType, utf8String == null)) + case byteArr: Array[Byte] => + structFieldSeq.append(StructField("vbin", BinaryType, byteArr == null)) + case decimal: Decimal => + structFieldSeq.append( + StructField("dec", DecimalType(decimal.precision, decimal.scale), decimal == null)) + case _ => + // for null + structFieldSeq.append(StructField("n", IntegerType, nullable = true)) + } + + _spark.internalCreateDataFrame( + _spark.sparkContext.parallelize(Seq(inputRow)), + StructType(structFieldSeq.toSeq)) + } +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkQueryTestsBase.scala b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkQueryTestsBase.scala new file mode 100644 index 000000000..6c6d839a2 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkQueryTestsBase.scala @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.sql + +import java.io.File +import java.util.TimeZone + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.FileUtils +import org.apache.commons.math3.util.Precision +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.util.sideBySide +import org.apache.spark.sql.execution.SQLExecution +import org.scalatest.Assertions + +/** + * Basic trait for all Spark query tests. + */ +trait SparkQueryTestsBase extends QueryTest with SparkTestsSharedSessionBase { + private def prepareWorkDir(): Unit = { + // prepare working paths + val basePathDir = new File(basePath) + if (basePathDir.exists()) { + FileUtils.forceDelete(basePathDir) + } + FileUtils.forceMkdir(basePathDir) + FileUtils.forceMkdir(new File(warehouse)) + FileUtils.forceMkdir(new File(metaStorePathAbsolute)) + } + + override def beforeAll(): Unit = { + prepareWorkDir() + super.beforeAll() + + spark.sparkContext.setLogLevel("WARN") + } + + override protected def checkAnswer(df: => DataFrame, expectedAnswer: Seq[Row]): Unit = { + assertEmptyMissingInput(df) + AuronQueryTestUtil.checkAnswer(df, expectedAnswer) + } +} + +object AuronQueryTestUtil extends Assertions { + + /** + * Runs the plan and makes sure the answer matches the expected result. + * + * @param df + * the DataFrame to be executed + * @param expectedAnswer + * the expected result in a Seq of Rows. + * @param checkToRDD + * whether to verify deserialization to an RDD. This runs the query twice. + */ + def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row], checkToRDD: Boolean = true): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer, checkToRDD) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + /** + * Runs the plan and makes sure the answer matches the expected result. If there was exception + * during the execution or the contents of the DataFrame does not match the expected result, an + * error message will be returned. Otherwise, a None will be returned. + * + * @param df + * the DataFrame to be executed + * @param expectedAnswer + * the expected result in a Seq of Rows. + * @param checkToRDD + * whether to verify deserialization to an RDD. This runs the query twice. + */ + def getErrorMessageInCheckAnswer( + df: DataFrame, + expectedAnswer: Seq[Row], + checkToRDD: Boolean = true): Option[String] = { + val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty + if (checkToRDD) { + SQLExecution.withSQLConfPropagated(df.sparkSession) { + df.rdd.count() // Also attempt to deserialize as an RDD [SPARK-15791] + } + } + + val sparkAnswer = + try df.collect().toSeq + catch { + case e: Exception => + val errorMessage = + s""" + |Exception thrown while executing query: + |${df.queryExecution} + |== Exception == + |$e + |${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + sameRows(expectedAnswer, sparkAnswer, isSorted).map { results => + s""" + |Results do not match for query: + |Timezone: ${TimeZone.getDefault} + |Timezone Env: ${sys.env.getOrElse("TZ", "")} + | + |${df.queryExecution} + |== Results == + |$results + """.stripMargin + } + } + + def prepareAnswer(answer: Seq[Row], isSorted: Boolean): Seq[Row] = { + // Converts data to types that we can do equality comparison using Scala collections. + // For BigDecimal type, the Scala type has a better definition of equality test (similar to + // Java's java.math.BigDecimal.compareTo). 
+ // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for + // equality test. + val converted: Seq[Row] = answer.map(prepareRow) + if (!isSorted) converted.sortBy(_.toString()) else converted + } + + // We need to call prepareRow recursively to handle schemas with struct types. + def prepareRow(row: Row): Row = { + Row.fromSeq(row.toSeq.map { + case null => null + case bd: java.math.BigDecimal => BigDecimal(bd) + // Equality of WrappedArray differs for AnyVal and AnyRef in Scala 2.12.2+ + case seq: Seq[_] => + seq.map { + case b: java.lang.Byte => b.byteValue + case s: java.lang.Short => s.shortValue + case i: java.lang.Integer => i.intValue + case l: java.lang.Long => l.longValue + case f: java.lang.Float => f.floatValue + case d: java.lang.Double => d.doubleValue + case x => x + } + // Convert array to Seq for easy equality check. + case b: Array[_] => b.toSeq + case r: Row => prepareRow(r) + case o => o + }) + } + + private def genError( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): String = { + val getRowType: Option[Row] => String = row => + row + .map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }) + .getOrElse("struct<>") + + s""" + |== Results == + |${sideBySide( + s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: + prepareAnswer(expectedAnswer, isSorted).map(_.toString()), + s"== Auron Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: + prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} + """.stripMargin + } + + def includesRows(expectedRows: Seq[Row], sparkAnswer: Seq[Row]): Option[String] = { + if (!prepareAnswer(expectedRows, true).toSet.subsetOf( + prepareAnswer(sparkAnswer, true).toSet)) { + return Some(genError(expectedRows, sparkAnswer, true)) + } + None + } + + private def compare(obj1: Any, obj2: Any): Boolean = (obj1, obj2) match { + case (null, null) => true + case (null, _) => false + case (_, null) => false + case (a: Array[_], b: Array[_]) => + a.length == b.length && a.zip(b).forall { case (l, r) => compare(l, r) } + case (a: Map[_, _], b: Map[_, _]) => + a.size == b.size && a.keys.forall { aKey => + b.keys.find(bKey => compare(aKey, bKey)).exists(bKey => compare(a(aKey), b(bKey))) + } + case (a: Iterable[_], b: Iterable[_]) => + a.size == b.size && a.zip(b).forall { case (l, r) => compare(l, r) } + case (a: Product, b: Product) => + compare(a.productIterator.toSeq, b.productIterator.toSeq) + case (a: Row, b: Row) => + compare(a.toSeq, b.toSeq) + // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0. 
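+ // Finite values fall through to a relative tolerance of 1e-5, absorbing small + // floating-point differences between the native engine and the JVM.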
+ case (a: Double, b: Double) => + if ((isNaNOrInf(a) || isNaNOrInf(b)) || (a == -0.0) || (b == -0.0)) { + java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b) + } else { + Precision.equalsWithRelativeTolerance(a, b, 0.00001d) + } + case (a: Float, b: Float) => + java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b) + case (a, b) => a == b + } + + def isNaNOrInf(num: Double): Boolean = { + num.isNaN || num.isInfinite || num.isNegInfinity || num.isPosInfinity + } + + def sameRows( + expectedAnswer: Seq[Row], + sparkAnswer: Seq[Row], + isSorted: Boolean = false): Option[String] = { + // modify method 'compare' + if (!compare(prepareAnswer(expectedAnswer, isSorted), prepareAnswer(sparkAnswer, isSorted))) { + return Some(genError(expectedAnswer, sparkAnswer, isSorted)) + } + None + } + + /** + * Runs the plan and makes sure the answer is within absTol of the expected result. + * + * @param actualAnswer + * the actual result in a [[Row]]. + * @param expectedAnswer + * the expected result in a [[Row]]. + * @param absTol + * the absolute tolerance between actual and expected answers. + */ + protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, absTol: Double) = { + require( + actualAnswer.length == expectedAnswer.length, + s"actual answer length ${actualAnswer.length} != " + + s"expected answer length ${expectedAnswer.length}") + + // TODO: support other numeric types besides Double + // TODO: support struct types? + actualAnswer.toSeq.zip(expectedAnswer.toSeq).foreach { + case (actual: Double, expected: Double) => + assert( + math.abs(actual - expected) < absTol, + s"actual answer $actual not within $absTol of correct answer $expected") + case (actual, expected) => + assert(actual == expected, s"$actual did not equal $expected") + } + } + + def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): Unit = { + getErrorMessageInCheckAnswer(df, expectedAnswer.asScala.toSeq) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsBase.scala b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsBase.scala new file mode 100644 index 000000000..f67928342 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsBase.scala @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import scala.collection.mutable + +import org.scalactic.source.Position +import org.scalatest.Tag +import org.scalatest.funsuite.AnyFunSuiteLike + +import org.apache.auron.utils.SparkTestSettings + +/** + * Base trait for all Spark tests. 
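+ * + * Test execution is gated through [[org.apache.auron.utils.SparkTestSettings]]: a test is + * registered only if its suite is enabled there and its name passes the include/exclude + * filters; everything else is registered as ignored.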
+ */ +trait SparkTestsBase extends AnyFunSuiteLike { + protected val IGNORE_ALL: String = "IGNORE_ALL" + protected val AURON_TEST: String = "Auron - " + + protected val rootPath: String = getClass.getResource("/").getPath + protected val basePath: String = rootPath + "unit-tests-working-home" + protected val warehouse: String = basePath + "/spark-warehouse" + protected val metaStorePathAbsolute: String = basePath + "/meta" + + /** + * Returns a sequence of test names to be blacklisted (i.e., skipped) for this test suite. + * + * Any test whose name appears in the returned sequence will never be run, regardless of backend + * test settings. This method can be overridden by subclasses to specify which tests to skip. + * + * Special behavior: If the sequence contains the value of `IGNORE_ALL` (case-insensitive), then + * all tests in the suite will be skipped. + * + * @return + * a sequence of test names to blacklist, or a sequence containing `IGNORE_ALL` to skip all + * tests + */ + def testNameBlackList: Seq[String] = Seq() + + val sparkConfList = { + val conf = mutable.Map[String, String]() + conf += ("spark.driver.memory" -> "1G") + conf += ("spark.sql.adaptive.enabled" -> "true") + conf += ("spark.sql.shuffle.partitions" -> "1") + conf += ("spark.sql.files.maxPartitionBytes" -> "134217728") + conf += ("spark.ui.enabled" -> "false") + conf += ("auron.ui.enabled" -> "false") + conf += ("spark.auron.enable" -> "true") + conf += ("spark.executor.memory" -> "1G") + conf += ("spark.executor.memoryOverhead" -> "1G") + conf += ("spark.memory.offHeap.enabled" -> "false") + conf += ("spark.sql.extensions" -> "org.apache.spark.sql.auron.AuronSparkSessionExtension") + conf += ("spark.shuffle.manager" -> + "org.apache.spark.sql.execution.auron.shuffle.AuronShuffleManager") + conf += ("spark.unsafe.exceptionOnMemoryLeak" -> "true") + // Avoid the code size overflow error in Spark code generation. + conf += ("spark.sql.codegen.wholeStage" -> "false") + + conf + } + + protected def shouldRun(testName: String): Boolean = { + if (testNameBlackList.exists(_.equalsIgnoreCase(IGNORE_ALL))) { + return false + } + + if (testNameBlackList.contains(testName)) { + return false + } + + SparkTestSettings.shouldRun(getClass.getCanonicalName, testName) + } + + protected def testAuron(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + test(AURON_TEST + testName, testTag: _*)(testFun) + } + + protected def ignoreAuron(testName: String, testTag: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.ignore(AURON_TEST + testName, testTag: _*)(testFun) + } + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + if (shouldRun(testName)) { + super.test(testName, testTags: _*)(testFun) + } else { + super.ignore(testName, testTags: _*)(testFun) + } + } +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsSharedSessionBase.scala b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsSharedSessionBase.scala new file mode 100644 index 000000000..8aedeeb00 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/spark/sql/SparkTestsSharedSessionBase.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkConf +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.test.SharedSparkSession + +trait SparkTestsSharedSessionBase extends SharedSparkSession with SparkTestsBase { + override def sparkConf: SparkConf = { + val conf = super.sparkConf + .setAppName("Auron-UT") + .set("spark.sql.warehouse.dir", warehouse) + + for ((key, value) <- sparkConfList) { + conf.set(key, value) + } + + conf + } + + /** + * Get all the child (and descendant) plans of the given plans. + * + * @param plans + * the input plans. + * @return + * all the child (and descendant) plans + */ + private def getChildrenPlan(plans: Seq[SparkPlan]): Seq[SparkPlan] = { + if (plans.isEmpty) { + return Seq() + } + + val inputPlans: Seq[SparkPlan] = plans.map { + case stage: QueryStageExec => stage.plan + case plan => plan + } + + var newChildren: Seq[SparkPlan] = Seq() + inputPlans.foreach { plan => + newChildren = newChildren ++ getChildrenPlan(plan.children) + // To avoid duplication of WholeStageCodegenXXX and its children. + if (!plan.nodeName.startsWith("WholeStageCodegen")) { + newChildren = newChildren :+ plan + } + } + newChildren + } + + /** + * Get all executed plan nodes of a DataFrame. + * + * @param df + * the DataFrame. + * @return + * a sequence of executed plan nodes. + */ + def getExecutedPlan(df: DataFrame): Seq[SparkPlan] = { + df.queryExecution.executedPlan match { + case exec: AdaptiveSparkPlanExec => + getChildrenPlan(Seq(exec.executedPlan)) + case plan => + getChildrenPlan(Seq(plan)) + } + } +} diff --git a/auron-spark-tests/common/src/test/scala/org/apache/spark/utils/DebuggableThreadUtils.scala b/auron-spark-tests/common/src/test/scala/org/apache/spark/utils/DebuggableThreadUtils.scala new file mode 100644 index 000000000..c128347d9 --- /dev/null +++ b/auron-spark-tests/common/src/test/scala/org/apache/spark/utils/DebuggableThreadUtils.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.spark.utils + +import scala.util.{Failure, Success, Try} + +import org.apache.spark.util.ThreadUtils + +object DebuggableThreadUtils { + + /** + * Applies a function to each element of the input sequence in parallel, logging any failures. + * + * @param in + * The input sequence of elements to process. + * @param prefix + * The prefix to use for thread names. + * @param maxThreads + * The maximum number of threads to use for parallel processing. + * @param f + * The function to apply to each element of the input sequence. + * @return + * A sequence containing the results of applying the function to each input element. + */ + def parmap[I, O](in: Seq[I], prefix: String, maxThreads: Int)(f: I => O): Seq[O] = { + ThreadUtils.parmap(in, prefix, maxThreads) { i => + Try(f(i)) match { + case Success(result) => result + case Failure(exception) => + // scalastyle:off println + println(s"Test failed for case: ${i.toString}: ${exception.getMessage}") + // scalastyle:on println + throw exception + } + } + } +} diff --git a/auron-spark-tests/pom.xml b/auron-spark-tests/pom.xml new file mode 100644 index 000000000..75e88fdf8 --- /dev/null +++ b/auron-spark-tests/pom.xml @@ -0,0 +1,92 @@ + + + + 4.0.0 + + + org.apache.auron + auron-parent_${scalaVersion} + ${project.version} + ../pom.xml + + + auron-spark-tests + pom + Auron Spark Test Parent + + + common + + + + + org.apache.spark + spark-core_${scalaVersion} + ${sparkVersion} + test + + + org.apache.spark + spark-catalyst_${scalaVersion} + ${sparkVersion} + test + + + org.apache.spark + spark-sql_${scalaVersion} + ${sparkVersion} + test + + + org.apache.spark + spark-hive_${scalaVersion} + ${sparkVersion} + test + + + + + + + + net.alchim31.maven + scala-maven-plugin + + true + + -Xss128m + + + + + org.scalastyle + scalastyle-maven-plugin + + + + + + + + spark-3.3 + + spark33 + + + + diff --git a/auron-spark-tests/spark33/pom.xml b/auron-spark-tests/spark33/pom.xml new file mode 100644 index 000000000..efbd9bde3 --- /dev/null +++ b/auron-spark-tests/spark33/pom.xml @@ -0,0 +1,146 @@ + + + + 4.0.0 + + + org.apache.auron + auron-spark-tests + ${project.version} + ../pom.xml + + + auron-spark-tests-spark33 + jar + Auron Spark Test for Spark 3.3 + + + + org.apache.auron + spark-extension_${scalaVersion} + ${project.version} + + + org.apache.auron + spark-extension-shims-spark_${scalaVersion} + ${project.version} + + + org.apache.auron + auron-spark-tests-common + ${project.version} + test-jar + + + net.bytebuddy + byte-buddy + + + net.bytebuddy + byte-buddy-agent + + + org.apache.arrow + arrow-memory-unsafe + + + org.apache.spark + spark-core_${scalaVersion} + test-jar + test + + + org.apache.spark + spark-catalyst_${scalaVersion} + test-jar + test + + + org.apache.spark + spark-sql_${scalaVersion} + test-jar + test + + + org.apache.spark + spark-tags_${scalaVersion} + test-jar + test + + + org.scalatestplus + scalatestplus-scalacheck_${scalaVersion} + test + + + + + + + org.apache.maven.plugins + maven-resources-plugin + + + net.alchim31.maven + scala-maven-plugin + + + org.apache.maven.plugins + maven-compiler-plugin + + + org.scalastyle + scalastyle-maven-plugin + + + org.apache.maven.plugins + maven-checkstyle-plugin + + + org.scalatest + scalatest-maven-plugin + + . 
+ + + + test + + test + + + + + + org.apache.maven.plugins + maven-jar-plugin + + + prepare-test-jar + + test-jar + + test-compile + + + + + target/scala-${scalaVersion}/classes + target/scala-${scalaVersion}/test-classes + + diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala new file mode 100644 index 000000000..2bccdc086 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.auron.utils + +import org.apache.spark.sql._ + +class AuronSparkTestSettings extends SparkTestSettings { + { + // Use Arrow's unsafe implementation. + System.setProperty("arrow.allocation.manager.type", "Unsafe") + } + + enableSuite[AuronStringFunctionsSuite] + // See https://github.com/apache/auron/issues/1724 + .exclude("string / binary substring function") + + // Will be implemented in the future. + override def getSQLQueryTestSettings = new SQLQueryTestSettings { + override def getResourceFilePath: String = ??? + + override def getSupportedSQLQueryTests: Set[String] = ??? + + override def getOverwriteSQLQueryTests: Set[String] = ??? + } +} diff --git a/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronStringFunctionsSuite.scala b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronStringFunctionsSuite.scala new file mode 100644 index 000000000..20c0c9bf4 --- /dev/null +++ b/auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronStringFunctionsSuite.scala @@ -0,0 +1,19 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql + +class AuronStringFunctionsSuite extends StringFunctionsSuite with SparkQueryTestsBase diff --git a/pom.xml b/pom.xml index d46edea1d..276782e5b 100644 --- a/pom.xml +++ b/pom.xml @@ -146,6 +146,18 @@ + + org.apache.spark + spark-hive_${scalaVersion} + ${sparkVersion} + test-jar + + + org.apache.arrow + * + + + org.apache.spark spark-sql_${scalaVersion} @@ -217,6 +229,12 @@ scalatest_${scalaVersion} ${scalaTestVersion} + + org.scalatestplus + scalatestplus-scalacheck_${scalaVersion} + 3.1.0.0-RC2 + test + org.apache.spark spark-core_${scalaVersion} @@ -224,6 +242,11 @@ test-jar test + + org.apache.spark + spark-catalyst_${scalaVersion} + ${sparkVersion} + org.apache.spark spark-catalyst_${scalaVersion} @@ -244,6 +267,13 @@ test-jar test + + org.apache.spark + spark-tags_${scalaVersion} + ${sparkVersion} + test-jar + test + @@ -888,5 +918,12 @@ 1.9.2 + + + spark-tests + + auron-spark-tests + + From ef086db9049c1ba77221e414388e1046ce9a17e5 Mon Sep 17 00:00:00 2001 From: James Xu Date: Fri, 26 Dec 2025 08:40:59 +0800 Subject: [PATCH 2/3] Trigger CI From bd64b4822a9825e6099d003b2719e994ea6c43ce Mon Sep 17 00:00:00 2001 From: James Xu Date: Fri, 26 Dec 2025 11:22:18 +0800 Subject: [PATCH 3/3] Trigger CI
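Note for reviewers: a minimal sketch of how a follow-up suite would plug into this framework; the suite and test names below are hypothetical and not part of this patch. A port mixes a vanilla Spark suite into SparkQueryTestsBase and is then gated in AuronSparkTestSettings. Note that include and exclude conditions are mutually exclusive for a suite, since shouldRun throws an IllegalStateException if both are set.

// Hypothetical port of another vanilla Spark suite in the spark33 module.
class AuronDateFunctionsSuite extends DateFunctionsSuite with SparkQueryTestsBase

// In AuronSparkTestSettings, enable it and carve out known gaps:
enableSuite[AuronDateFunctionsSuite]
  .exclude("function to_date") // hypothetical known-issue test, tracked in a GitHub issue
// ...or run only a matching subset instead (never combine include with exclude):
// enableSuite[AuronDateFunctionsSuite].includeByPrefix("date_add")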