From 9cbac4fe3dd6e04ef3c93eca35b02dd5763f8bf4 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Thu, 11 Sep 2025 23:09:19 +0900
Subject: [PATCH 1/2] Fix sum_distinct

---
 crates/connect/src/functions/mod.rs | 52 ++++++++++++++++++++++++++++-
 1 file changed, 51 insertions(+), 1 deletion(-)

diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index 4884939..64f7bd1 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -1108,7 +1108,19 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
 gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
 gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
 gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
-gen_func!(sum_distinct, [col: Column], "Returns the sum of distinct values in the expression.");
+/// "Returns the sum of distinct values in the expression."
+pub fn sum_distinct(col: impl Into<Column>) -> Column {
+    Column::from(spark::Expression {
+        expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
+            spark::expression::UnresolvedFunction {
+                function_name: "sum".to_string(),
+                arguments: VecExpression::from_iter(vec![col]).into(),
+                is_distinct: true,
+                is_user_defined_function: false,
+            },
+        )),
+    })
+}
 gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
 gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
 gen_func!(variance, [col: Column], "Alias for var_samp");
@@ -2546,4 +2558,42 @@ mod tests {
         assert_eq!(expected, res);
         Ok(())
     }
+
+    // Test aggregate functions
+    #[tokio::test]
+    async fn test_func_sum_distinct() -> Result<(), SparkError> {
+        let spark = setup().await;
+        let select_func = |df: DataFrame| {
+            df.select([sum_distinct(col("value")).alias("sum")])
+                .collect()
+        };
+        let record_batch_func =
+            |col: ArrayRef| RecordBatch::try_from_iter_with_nullable(vec![("sum", col, true)]);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (null), (1), (2), (3) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (null), (null), (null) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Float64Array::from(vec![None]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        Ok(())
+    }
 }

From c8e5385b4f3a5450236c77b90671369513f0f193 Mon Sep 17 00:00:00 2001
From: Kousuke Saruta
Date: Wed, 1 Oct 2025 01:09:09 +0900
Subject: [PATCH 2/2] Fix style

---
 crates/connect/src/functions/mod.rs | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index 64f7bd1..89892ba 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -1108,6 +1108,7 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
 gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
 gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
 gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
+
 /// "Returns the sum of distinct values in the expression."
 pub fn sum_distinct(col: impl Into<Column>) -> Column {
     Column::from(spark::Expression {
@@ -1121,6 +1122,7 @@ pub fn sum_distinct(col: impl Into<Column>) -> Column {
         )),
     })
 }
+
 gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
 gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
 gen_func!(variance, [col: Column], "Alias for var_samp");