diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index 4884939..89892ba 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -1108,7 +1108,21 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
 gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
 gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
 gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
-gen_func!(sum_distinct, [col: Column], "Returns the sum of distinct values in the expression.");
+
+/// Returns the sum of distinct values in the expression.
+pub fn sum_distinct(col: impl Into<Column>) -> Column {
+    Column::from(spark::Expression {
+        expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
+            spark::expression::UnresolvedFunction {
+                function_name: "sum".to_string(),
+                arguments: VecExpression::from_iter(vec![col]).into(),
+                is_distinct: true,
+                is_user_defined_function: false,
+            },
+        )),
+    })
+}
+
 gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
 gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
 gen_func!(variance, [col: Column], "Alias for var_samp");
@@ -2546,4 +2560,42 @@ mod tests {
         assert_eq!(expected, res);
         Ok(())
     }
+
+    // Test aggregate functions
+    #[tokio::test]
+    async fn test_func_sum_distinct() -> Result<(), SparkError> {
+        let spark = setup().await;
+        let select_func = |df: DataFrame| {
+            df.select([sum_distinct(col("value")).alias("sum")])
+                .collect()
+        };
+        let record_batch_func =
+            |col: ArrayRef| RecordBatch::try_from_iter_with_nullable(vec![("sum", col, true)]);
+
+        // Duplicates are ignored: sum(DISTINCT value) over (1, 2, 3, 1, 2, 3) is 6.
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        // Nulls are skipped, so the result is unchanged.
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (null), (1), (2), (3) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        // An all-null column sums to null.
+        let df = spark
+            .sql("SELECT * FROM VALUES (null), (null), (null) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Float64Array::from(vec![None]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        Ok(())
+    }
 }
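
For context, a minimal end-to-end usage sketch of the new function. This is hedged: the connection URL is illustrative and the import paths are my assumption about the crate layout, not part of this diff; only `sum_distinct`, `col`, `alias`, and `collect` are the APIs exercised by the change above.

```rust
use spark_connect_rs::functions::{col, sum_distinct};
use spark_connect_rs::{SparkSession, SparkSessionBuilder};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical Spark Connect endpoint; adjust to your environment.
    let spark: SparkSession = SparkSessionBuilder::remote("sc://127.0.0.1:15002")
        .build()
        .await?;

    let df = spark
        .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2) AS data(value)")
        .await?;

    // DISTINCT is applied before the aggregation, so duplicates are
    // ignored: 1 + 2 + 3 = 6.
    let batch = df
        .select([sum_distinct(col("value")).alias("sum")])
        .collect()
        .await?;
    println!("{batch:?}");

    Ok(())
}
```

The function plan-builds `sum` with `is_distinct: true` rather than introducing a separate `sum_distinct` function name, which matches how Spark Connect expresses `sum(DISTINCT ...)` in its `UnresolvedFunction` message.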