crates/connect/src/functions/mod.rs (53 additions, 1 deletion)
@@ -1108,7 +1108,21 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
-gen_func!(sum_distinct, [col: Column], "Returns the sum of distinct values in the expression.");

/// "Returns the sum of distinct values in the expression."
pub fn sum_distinct(col: impl Into<Column>) -> Column {
Column::from(spark::Expression {
expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
spark::expression::UnresolvedFunction {
function_name: "sum".to_string(),
arguments: VecExpression::from_iter(vec![col]).into(),
is_distinct: true,
is_user_defined_function: false,
},
)),
})
}

gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
gen_func!(variance, [col: Column], "Alias for var_samp");
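
A note on why this one is hand-written: `gen_func!` evidently emits `is_distinct: false`, so a distinct aggregate needs its own builder. Any other distinct aggregate would take the same shape; a minimal sketch under that assumption (the `avg_distinct` name is hypothetical, not part of this diff):

// Hypothetical sketch (not in this PR): the same UnresolvedFunction pattern
// with `is_distinct: true`, assuming the module's existing spark protos,
// Column, and VecExpression helpers.
pub fn avg_distinct(col: impl Into<Column>) -> Column {
    Column::from(spark::Expression {
        expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
            spark::expression::UnresolvedFunction {
                // Spark applies DISTINCT via the is_distinct flag below.
                function_name: "avg".to_string(),
                arguments: VecExpression::from_iter(vec![col]).into(),
                is_distinct: true,
                is_user_defined_function: false,
            },
        )),
    })
}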
@@ -2546,4 +2560,42 @@ mod tests {
assert_eq!(expected, res);
Ok(())
}

// Test aggregate functions
#[tokio::test]
async fn test_func_sum_distinct() -> Result<(), SparkError> {
let spark = setup().await;
let select_func = |df: DataFrame| {
df.select([sum_distinct(col("value")).alias("sum")])
.collect()
};
let record_batch_func =
|col: ArrayRef| RecordBatch::try_from_iter_with_nullable(vec![("sum", col, true)]);

let df = spark
.sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
.await?;
let res = select_func(df).await?;
let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
let expected = record_batch_func(expected_col)?;
assert_eq!(expected, res);

let df = spark
.sql("SELECT * FROM VALUES (1), (2), (3), (null), (1), (2), (3) AS data(value)")
.await?;
let res = select_func(df).await?;
let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
let expected = record_batch_func(expected_col)?;
assert_eq!(expected, res);

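        // An untyped all-NULL column: Spark resolves sum over NullType to
        // DoubleType, so the result is a single NULL double (hence the
        // Float64Array below, unlike the Int64Array cases above).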
let df = spark
.sql("SELECT * FROM VALUES (null), (null), (null) AS data(value)")
.await?;
let res = select_func(df).await?;
let expected_col: ArrayRef = Arc::new(Float64Array::from(vec![None]));
let expected = record_batch_func(expected_col)?;
assert_eq!(expected, res);

Ok(())
}
}
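
For reference, a usage sketch mirroring the test above (hypothetical, not part of the diff; it assumes the test module's `setup()` helper and imports):

// Hypothetical usage (not in this PR): sum_distinct builds the same plan
// as SQL's sum(DISTINCT value), so duplicate values are counted once.
let spark = setup().await;
let df = spark
    .sql("SELECT * FROM VALUES (1), (2), (2), (3) AS data(value)")
    .await?;
let res = df
    .select([sum_distinct(col("value")).alias("sum")])
    .collect()
    .await?;
// Expected: a single Int64 row containing 6 (1 + 2 + 3).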