2 changes: 1 addition & 1 deletion .github/workflows/build_and_test.yml
@@ -57,4 +57,4 @@ jobs:
run: |
export SPARK_REMOTE=sc://localhost
cd build
ctest -LE integration --output-on-failure
ctest -LE dbrx --output-on-failure
6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -122,10 +122,10 @@ macro(add_spark_test TEST_NAME SOURCE_FILE)
add_test(NAME ${TEST_NAME} COMMAND ${TEST_NAME})

# -------------------------------------------------------
# Automatically tag databricks tests as "integration"
# Automatically tag databricks tests as "dbrx"
# -------------------------------------------------------
if(${TEST_NAME} MATCHES "databricks")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "integration")
set_tests_properties(${TEST_NAME} PROPERTIES LABELS "dbrx")
endif()

set_tests_properties(${TEST_NAME}
@@ -143,7 +143,7 @@ add_spark_test(test_dataframe_reader tests/dataframe_reader.cpp)
add_spark_test(test_dataframe_writer tests/dataframe_writer.cpp)

# -----------------------------------------------------------------------
# The following tests will be labeled "integration" automatically
# The following tests will be labeled "dbrx" automatically
# Since these specifically require an active Databricks connection,
# we can conditionally exclude them.
# -----------------------------------------------------------------------
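CTest selects tests by this label: `ctest -LE dbrx` (as used in the CI workflow and the pre-push hook below) skips the Databricks-only tests, while `ctest -L dbrx` runs only those tests when a Databricks connection is available locally.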
24 changes: 24 additions & 0 deletions docs/API_REFERENCE.md
@@ -564,4 +564,28 @@ filtered_df.show();
| Stella | 24 | 79000 |
| Brooklyn | 23 | 75000 |
+----------------------+----------------------+----------------------+
```

### Collect

```cpp
auto df = spark->read()
.option("header", "true")
.option("inferSchema", "true")
.csv("datasets/people.csv");

auto rows = df.collect();

for (auto &row : rows) {
std::cout << row << std::endl;
}
```

**Output:**

```
Row(name='John', age=25, salary=100000)
Row(name='Alice', age=30, salary=85000)
...

```
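
Individual fields can also be read from a collected `Row`. The sketch below continues from the `df` above and uses the accessors exercised in `tests/dataframe.cpp` (`operator[]` returning a variant-style value, `get_long()`, and the public `column_names` member); treat it as illustrative rather than an exhaustive API listing.

```cpp
auto rows = df.collect();

// operator[] looks a value up by column name; std::get extracts the typed value.
auto name = std::get<std::string>(rows[0]["name"]);

// get_long() reads an integer column directly.
auto age = rows[0].get_long("age");

// Column names are carried on each row.
for (const auto &col : rows[0].column_names) {
    std::cout << col << " ";
}
std::cout << std::endl;

std::cout << name << " is " << age << std::endl;
```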
2 changes: 1 addition & 1 deletion hooks/pre-push
@@ -34,7 +34,7 @@ cd build
echo -e "${BLUE}Running tests (ctest)...${NC}"
echo ""

if ! ctest --verbose -LE integration --test-arguments=--gtest_color=yes; then
if ! ctest --verbose -LE dbrx --test-arguments=--gtest_color=yes; then
echo ""
echo -e "${RED}*********************************************${NC}"
echo -e "${RED} TESTS FAILED${NC}"
101 changes: 88 additions & 13 deletions src/dataframe.cpp
@@ -8,6 +8,8 @@
#include <arrow/api.h>
#include <arrow/ipc/api.h>
#include <arrow/io/memory.h>
#include <arrow/ipc/reader.h>
#include <arrow/table.h>

#include <grpcpp/grpcpp.h>
#include <spark/connect/base.grpc.pb.h>
@@ -441,12 +443,12 @@ void DataFrame::printSchema() const

DataFrame DataFrame::select(const std::vector<std::string> &cols)
{
spark::connect::Plan new_plan;
spark::connect::Plan plan;

// ---------------------------------------------------------------------
// Use the pointer to the root to ensure we are copying the content
// ---------------------------------------------------------------------
auto *project = new_plan.mutable_root()->mutable_project();
auto *project = plan.mutable_root()->mutable_project();

// ---------------------------------------------------------------------
// Copy the entire relation tree from the previous plan
@@ -462,16 +464,16 @@ DataFrame DataFrame::select(const std::vector<std::string> &cols)
expr->mutable_unresolved_attribute()->set_unparsed_identifier(col_name);
}

return DataFrame(stub_, new_plan, session_id_, user_id_);
return DataFrame(stub_, plan, session_id_, user_id_);
}

DataFrame DataFrame::limit(int n)
{
spark::connect::Plan new_plan;
auto *limit_rel = new_plan.mutable_root()->mutable_limit();
spark::connect::Plan plan;
auto *limit_rel = plan.mutable_root()->mutable_limit();
*limit_rel->mutable_input() = this->plan_.root();
limit_rel->set_limit(n);
return DataFrame(stub_, new_plan, session_id_, user_id_);
return DataFrame(stub_, plan, session_id_, user_id_);
}

std::vector<spark::sql::types::Row> DataFrame::take(int n)
@@ -698,24 +700,28 @@ DataFrame DataFrame::dropDuplicates()

DataFrame DataFrame::dropDuplicates(const std::vector<std::string> &subset)
{
spark::connect::Plan new_plan;
spark::connect::Plan plan;

auto *relation = new_plan.mutable_root()->mutable_deduplicate();
auto *relation = plan.mutable_root()->mutable_deduplicate();

if (this->plan_.has_root())
{
relation->mutable_input()->CopyFrom(this->plan_.root());
}

if (subset.empty()) {
relation->set_all_columns_as_keys(true);
} else {
for (const auto &col_name : subset) {
if (subset.empty())
{
relation->set_all_columns_as_keys(true);
}
else
{
for (const auto &col_name : subset)
{
relation->add_column_names(col_name);
}
}

return DataFrame(stub_, new_plan, session_id_, user_id_);
return DataFrame(stub_, plan, session_id_, user_id_);
}

DataFrame DataFrame::drop_duplicates()
@@ -726,4 +732,73 @@ DataFrame DataFrame::drop_duplicates()
DataFrame DataFrame::drop_duplicates(const std::vector<std::string> &subset)
{
return dropDuplicates(subset);
}

std::vector<spark::sql::types::Row> DataFrame::collect()
{
std::vector<spark::sql::types::Row> results;

spark::connect::ExecutePlanRequest request;
request.set_session_id(session_id_);
request.mutable_user_context()->set_user_id(user_id_);
*request.mutable_plan() = plan_;

grpc::ClientContext context;
auto stream = stub_->ExecutePlan(&context, request);

spark::connect::ExecutePlanResponse response;
std::vector<std::string> col_names;
bool schema_initialized = false;

while (stream->Read(&response))
{
if (response.has_arrow_batch())
{
const auto &batch_proto = response.arrow_batch();

auto buffer = std::make_shared<arrow::Buffer>(
reinterpret_cast<const uint8_t *>(batch_proto.data().data()),
batch_proto.data().size());
arrow::io::BufferReader buffer_reader(buffer);

auto reader_result = arrow::ipc::RecordBatchStreamReader::Open(&buffer_reader);
if (!reader_result.ok())
continue;
auto reader = reader_result.ValueOrDie();

std::shared_ptr<arrow::RecordBatch> batch;
while (reader->ReadNext(&batch).ok() && batch)
{
if (!schema_initialized)
{
for (int i = 0; i < batch->num_columns(); ++i)
{
col_names.push_back(batch->column_name(i));
}
schema_initialized = true;
}

for (int64_t i = 0; i < batch->num_rows(); ++i)
{
spark::sql::types::Row row;
row.column_names = col_names;

for (int j = 0; j < batch->num_columns(); ++j)
{
row.values.push_back(
spark::sql::types::arrayValueToVariant(batch->column(j), i));
}
results.push_back(std::move(row));
}
}
}
}

auto status = stream->Finish();
if (!status.ok())
{
throw std::runtime_error("gRPC Error during collect: " + status.error_message());
}

return results;
}
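
Since `collect()` raises `std::runtime_error` when `stream->Finish()` reports a gRPC failure, callers talking to a remote cluster may want to guard the call. A minimal sketch, assuming `df` was built from an active session:

```cpp
try {
    auto rows = df.collect();
    std::cout << "Collected " << rows.size() << " rows" << std::endl;
} catch (const std::runtime_error &e) {
    // Raised when the gRPC stream terminates with a non-OK status.
    std::cerr << "collect() failed: " << e.what() << std::endl;
}
```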
43 changes: 35 additions & 8 deletions src/dataframe.h
@@ -1,7 +1,9 @@
#pragma once

#include <string>
#include <memory>
#include <string>
#include <vector>

#include <spark/connect/base.grpc.pb.h>
#include <spark/connect/relations.pb.h>

Expand Down Expand Up @@ -132,25 +134,50 @@ class DataFrame
DataFrameWriter write();

/**
* @brief Returns a new DataFrame with duplicate rows removed - equivalent to distinct() function
* @brief Returns a new DataFrame with duplicate rows removed - equivalent to `distinct()` function
*/

DataFrame dropDuplicates();
/**
* @brief Returns a new DataFrame with duplicate rows removed,
* considering only the given subset of columns - equivalent to distinct() function
* considering only the given subset of columns - equivalent to `distinct()` function
*/
DataFrame dropDuplicates(const std::vector<std::string>& subset);
DataFrame dropDuplicates(const std::vector<std::string> &subset);

/**
* @brief Alias for dropDuplicates().
* @brief Alias for `dropDuplicates()`.
*/
DataFrame drop_duplicates();

/**
* @brief Alias for dropDuplicates(subset).
* @brief Alias for `dropDuplicates(subset)`.
*/
DataFrame drop_duplicates(const std::vector<std::string>& subset);
DataFrame drop_duplicates(const std::vector<std::string> &subset);

/**
* @brief Returns all the records as a list of `Row`
* @example
* SparkSession spark(...);
* auto df = spark.read()
* .option("header", "true");
* .option("inferSchema", "true");
* .csv("datasets/people.csv");
*
* auto rows = df.collect();
*
* for (auto &row : rows) {
* std::cout << row << std::endl;
* }
*
* // ------------------------------------------
* // Output:
* // Row(name='John', age=25, salary=100000)
* // Row(name='Alice', age=30, salary=85000)
* // ...
* // ------------------------------------------
* @returns A list of rows.
*/
std::vector<spark::sql::types::Row> collect();

private:
std::shared_ptr<spark::connect::SparkConnectService::Stub> stub_;
46 changes: 44 additions & 2 deletions tests/dataframe.cpp
@@ -273,12 +273,12 @@ TEST_F(SparkIntegrationTest, WhereFilter)
.csv("datasets/people.csv");

auto filtered_df = df.where("age < 25");
filtered_df.show();
EXPECT_NO_THROW(filtered_df.show());
EXPECT_LT(filtered_df.count(), df.count());
}

TEST_F(SparkIntegrationTest, DropDuplicates)
{
// R - raw string literal.
auto df = spark->sql(R"(
SELECT *
FROM VALUES
@@ -296,6 +296,48 @@ TEST_F(SparkIntegrationTest, DropDuplicates)
auto deduped = df.dropDuplicates();
deduped.show();

EXPECT_EQ(deduped.count(), 5);

auto subset_deduped = df.dropDuplicates({"age"});
subset_deduped.show();

EXPECT_EQ(subset_deduped.count(), 2);
}

TEST_F(SparkIntegrationTest, DataFrameCollect)
{
auto df = spark->read()
.option("header", "true")
.option("inferSchema", "true")
.csv("datasets/people.csv");

auto rows = df.collect();

EXPECT_EQ(df.count(), rows.size());

// --------------------------------------------
// Check the first row: e.g., "John", 25
// Refer to: /datasets/people.csv
//
// Sample: Row(name='John', age=25, salary=100000)
// --------------------------------------------
EXPECT_EQ(std::get<std::string>(rows[0]["name"]), "John");
EXPECT_EQ(rows[0].get_long("age"), 25);

// --------------------------------------------
// Check the second row: e.g., "Andy", 30
// Refer to: /datasets/people.csv
//
// Sample: Row(name='Alice', age=30, salary=85000)
// --------------------------------------------
EXPECT_EQ(std::get<std::string>(rows[1]["name"]), "Alice");
EXPECT_EQ(rows[1].get_long("age"), 30);

EXPECT_EQ(rows[0].column_names[0], "name");
EXPECT_EQ(rows[0].column_names[1], "age");
EXPECT_EQ(rows[0].column_names[2], "salary");

auto cols = df.columns();
EXPECT_EQ(cols.size(), 3);
EXPECT_STREQ(cols[0].c_str(), "name");
}
2 changes: 1 addition & 1 deletion tests/serverless_databricks_cluster.cpp
@@ -69,7 +69,7 @@ TEST_F(SparkIntegrationTest, DatabricksNycTaxiAnalysis)
"GROUP BY pickup_zip "
"ORDER BY total_trips DESC");

df.show(10);
df.show(20);

ASSERT_GT(df.count(), 0) << "The taxi dataset should not be empty.";
}