diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 73557c1..9b75ff0 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -57,4 +57,4 @@ jobs:
       run: |
         export SPARK_REMOTE=sc://localhost
         cd build
-        ctest -LE integration --output-on-failure
+        ctest -LE dbrx --output-on-failure
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c00424a..b2b26e8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -122,10 +122,10 @@ macro(add_spark_test TEST_NAME SOURCE_FILE)
   add_test(NAME ${TEST_NAME} COMMAND ${TEST_NAME})
 
   # -------------------------------------------------------
-  # Automatically tag databricks tests as "integration"
+  # Automatically tag databricks tests as "dbrx"
   # -------------------------------------------------------
   if(${TEST_NAME} MATCHES "databricks")
-    set_tests_properties(${TEST_NAME} PROPERTIES LABELS "integration")
+    set_tests_properties(${TEST_NAME} PROPERTIES LABELS "dbrx")
   endif()
 
   set_tests_properties(${TEST_NAME}
@@ -143,7 +143,7 @@ add_spark_test(test_dataframe_reader tests/dataframe_reader.cpp)
 add_spark_test(test_dataframe_writer tests/dataframe_writer.cpp)
 
 # -----------------------------------------------------------------------
-# The following tests will be labeled "integration" automatically
+# The following tests will be labeled "dbrx" automatically
 # Since these specifically require an active Databricks connection,
 # we can conditionally exclude them.
 # -----------------------------------------------------------------------
diff --git a/docs/API_REFERENCE.md b/docs/API_REFERENCE.md
index 95139cd..a68d501 100644
--- a/docs/API_REFERENCE.md
+++ b/docs/API_REFERENCE.md
@@ -564,4 +564,28 @@ filtered_df.show();
 | Stella               | 24                   | 79000                |
 | Brooklyn             | 23                   | 75000                |
 +----------------------+----------------------+----------------------+
 ```
+
+### Collect
+
+```cpp
+auto df = spark->read()
+              .option("header", "true")
+              .option("inferSchema", "true")
+              .csv("datasets/people.csv");
+
+auto rows = df.collect();
+
+for (auto &row : rows) {
+    std::cout << row << std::endl;
+}
+```
+
+**Output:**
+
+```
+Row(name='John', age=25, salary=100000)
+Row(name='Alice', age=30, salary=85000)
+...
+```
\ No newline at end of file
diff --git a/hooks/pre-push b/hooks/pre-push
index 98ef00f..85acec0 100644
--- a/hooks/pre-push
+++ b/hooks/pre-push
@@ -34,7 +34,7 @@ cd build
 
 echo -e "${BLUE}Running tests (ctest)...${NC}"
 echo ""
 
-if ! ctest --verbose -LE integration --test-arguments=--gtest_color=yes; then
+if ! ctest --verbose -LE dbrx --test-arguments=--gtest_color=yes; then
     echo ""
     echo -e "${RED}*********************************************${NC}"
     echo -e "${RED} TESTS FAILED${NC}"
diff --git a/src/dataframe.cpp b/src/dataframe.cpp
index 421e2a5..565fe9e 100644
--- a/src/dataframe.cpp
+++ b/src/dataframe.cpp
@@ -8,6 +8,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 
@@ -441,12 +443,12 @@ void DataFrame::printSchema() const
 
 DataFrame DataFrame::select(const std::vector<std::string> &cols)
 {
-    spark::connect::Plan new_plan;
+    spark::connect::Plan plan;
 
     // ---------------------------------------------------------------------
     // Use the pointer to the root to ensure we are copying the content
     // ---------------------------------------------------------------------
-    auto *project = new_plan.mutable_root()->mutable_project();
+    auto *project = plan.mutable_root()->mutable_project();
 
     // ---------------------------------------------------------------------
     // Copy the entire relation tree from the previous plan
@@ -462,16 +464,16 @@ DataFrame DataFrame::select(const std::vector<std::string> &cols)
         expr->mutable_unresolved_attribute()->set_unparsed_identifier(col_name);
     }
 
-    return DataFrame(stub_, new_plan, session_id_, user_id_);
+    return DataFrame(stub_, plan, session_id_, user_id_);
 }
 
 DataFrame DataFrame::limit(int n)
 {
-    spark::connect::Plan new_plan;
-    auto *limit_rel = new_plan.mutable_root()->mutable_limit();
+    spark::connect::Plan plan;
+    auto *limit_rel = plan.mutable_root()->mutable_limit();
     *limit_rel->mutable_input() = this->plan_.root();
     limit_rel->set_limit(n);
-    return DataFrame(stub_, new_plan, session_id_, user_id_);
+    return DataFrame(stub_, plan, session_id_, user_id_);
 }
 
 std::vector<spark::sql::types::Row> DataFrame::take(int n)
@@ -698,24 +700,28 @@ DataFrame DataFrame::dropDuplicates()
 
 DataFrame DataFrame::dropDuplicates(const std::vector<std::string> &subset)
 {
-    spark::connect::Plan new_plan;
+    spark::connect::Plan plan;
 
-    auto *relation = new_plan.mutable_root()->mutable_deduplicate();
+    auto *relation = plan.mutable_root()->mutable_deduplicate();
 
     if (this->plan_.has_root())
     {
         relation->mutable_input()->CopyFrom(this->plan_.root());
     }
 
-    if (subset.empty()) {
-        relation->set_all_columns_as_keys(true);
-    } else {
-        for (const auto &col_name : subset) {
+    if (subset.empty())
+    {
+        relation->set_all_columns_as_keys(true);
+    }
+    else
+    {
+        for (const auto &col_name : subset)
+        {
             relation->add_column_names(col_name);
         }
     }
 
-    return DataFrame(stub_, new_plan, session_id_, user_id_);
+    return DataFrame(stub_, plan, session_id_, user_id_);
 }
 
 DataFrame DataFrame::drop_duplicates()
@@ -726,4 +732,73 @@ DataFrame DataFrame::drop_duplicates()
 DataFrame DataFrame::drop_duplicates(const std::vector<std::string> &subset)
 {
     return dropDuplicates(subset);
+}
+
+std::vector<spark::sql::types::Row> DataFrame::collect()
+{
+    std::vector<spark::sql::types::Row> results;
+
+    spark::connect::ExecutePlanRequest request;
+    request.set_session_id(session_id_);
+    request.mutable_user_context()->set_user_id(user_id_);
+    *request.mutable_plan() = plan_;
+
+    grpc::ClientContext context;
+    auto stream = stub_->ExecutePlan(&context, request);
+
+    spark::connect::ExecutePlanResponse response;
+    std::vector<std::string> col_names;
+    bool schema_initialized = false;
+
+    while (stream->Read(&response))
+    {
+        if (response.has_arrow_batch())
+        {
+            const auto &batch_proto = response.arrow_batch();
+
+            // Wrap the serialized Arrow IPC payload so it can be read as a stream.
+            auto buffer = std::make_shared<arrow::Buffer>(
+                reinterpret_cast<const uint8_t *>(batch_proto.data().data()),
+                batch_proto.data().size());
+            arrow::io::BufferReader buffer_reader(buffer);
+
+            auto reader_result =
+                arrow::ipc::RecordBatchStreamReader::Open(&buffer_reader);
+            if (!reader_result.ok())
+                continue;
+            auto reader = reader_result.ValueOrDie();
+
+            std::shared_ptr<arrow::RecordBatch> batch;
+            while (reader->ReadNext(&batch).ok() && batch)
+            {
+                // Capture the column names once from the first batch.
+                if (!schema_initialized)
+                {
+                    for (int i = 0; i < batch->num_columns(); ++i)
+                    {
+                        col_names.push_back(batch->column_name(i));
+                    }
+                    schema_initialized = true;
+                }
+
+                // Convert each Arrow row into a Row of variant values.
+                for (int64_t i = 0; i < batch->num_rows(); ++i)
+                {
+                    spark::sql::types::Row row;
+                    row.column_names = col_names;
+
+                    for (int j = 0; j < batch->num_columns(); ++j)
+                    {
+                        row.values.push_back(
+                            spark::sql::types::arrayValueToVariant(batch->column(j), i));
+                    }
+                    results.push_back(std::move(row));
+                }
+            }
+        }
+    }
+
+    auto status = stream->Finish();
+    if (!status.ok())
+    {
+        throw std::runtime_error("gRPC Error during collect: " + status.error_message());
+    }
+
+    return results;
+}
\ No newline at end of file
diff --git a/src/dataframe.h b/src/dataframe.h
index 56a28f6..bee1546 100644
--- a/src/dataframe.h
+++ b/src/dataframe.h
@@ -1,7 +1,9 @@
 #pragma once
 
-#include
 #include
+#include
+#include
+
 #include
 #include
@@ -132,25 +134,50 @@ class DataFrame
     DataFrameWriter write();
 
     /**
-     * @brief Returns a new DataFrame with duplicate rows removed - equivalent to distinct() function
+     * @brief Returns a new DataFrame with duplicate rows removed - equivalent to the `distinct()` function
      */
    DataFrame dropDuplicates();
 
     /**
      * @brief Returns a new DataFrame with duplicate rows removed,
-     * considering only the given subset of columns - equivalent to distinct() function
+     * considering only the given subset of columns - equivalent to the `distinct()` function
      */
-    DataFrame dropDuplicates(const std::vector<std::string>& subset);
-
+    DataFrame dropDuplicates(const std::vector<std::string> &subset);
+
     /**
-     * @brief Alias for dropDuplicates().
+     * @brief Alias for `dropDuplicates()`.
      */
     DataFrame drop_duplicates();
 
     /**
-     * @brief Alias for dropDuplicates(subset).
+     * @brief Alias for `dropDuplicates(subset)`.
      */
-    DataFrame drop_duplicates(const std::vector<std::string>& subset);
+    DataFrame drop_duplicates(const std::vector<std::string> &subset);
+
+    /**
+     * @brief Returns all the records as a list of `Row`.
+     * @example
+     * SparkSession spark(...);
+     * auto df = spark.read()
+     *               .option("header", "true")
+     *               .option("inferSchema", "true")
+     *               .csv("datasets/people.csv");
+     *
+     * auto rows = df.collect();
+     *
+     * for (auto &row : rows) {
+     *     std::cout << row << std::endl;
+     * }
+     *
+     * // ------------------------------------------
+     * // Output:
+     * // Row(name='John', age=25, salary=100000)
+     * // Row(name='Alice', age=30, salary=85000)
+     * // ...
+     * // ------------------------------------------
+     * @returns A list of rows.
+     */
+    std::vector<spark::sql::types::Row> collect();
 
 private:
     std::shared_ptr<spark::connect::SparkConnectService::Stub> stub_;
diff --git a/tests/dataframe.cpp b/tests/dataframe.cpp
index 7c4ecc1..c2fa53c 100644
--- a/tests/dataframe.cpp
+++ b/tests/dataframe.cpp
@@ -273,12 +273,12 @@ TEST_F(SparkIntegrationTest, WhereFilter)
                   .csv("datasets/people.csv");
 
     auto filtered_df = df.where("age < 25");
-    filtered_df.show();
+    EXPECT_NO_THROW(filtered_df.show());
+    EXPECT_LT(filtered_df.count(), df.count());
 }
 
 TEST_F(SparkIntegrationTest, DropDuplicates)
 {
-    // R - raw string literal.
     auto df = spark->sql(R"(
         SELECT * FROM VALUES
@@ -296,6 +296,48 @@ TEST_F(SparkIntegrationTest, DropDuplicates)
 
     auto deduped = df.dropDuplicates();
     deduped.show();
+    EXPECT_EQ(deduped.count(), 5);
+
     auto subset_deduped = df.dropDuplicates({"age"});
     subset_deduped.show();
+
+    EXPECT_EQ(subset_deduped.count(), 2);
 }
+
+TEST_F(SparkIntegrationTest, DataFrameCollect)
+{
+    auto df = spark->read()
+                  .option("header", "true")
+                  .option("inferSchema", "true")
+                  .csv("datasets/people.csv");
+
+    auto rows = df.collect();
+
+    EXPECT_EQ(df.count(), rows.size());
+
+    // --------------------------------------------
+    // Check the first row: e.g., "John", 25
+    // Refer to: /datasets/people.csv
+    //
+    // Sample: Row(name='John', age=25, salary=100000)
+    // --------------------------------------------
+    EXPECT_EQ(std::get<std::string>(rows[0]["name"]), "John");
+    EXPECT_EQ(rows[0].get_long("age"), 25);
+
+    // --------------------------------------------
+    // Check the second row: e.g., "Alice", 30
+    // Refer to: /datasets/people.csv
+    //
+    // Sample: Row(name='Alice', age=30, salary=85000)
+    // --------------------------------------------
+    EXPECT_EQ(std::get<std::string>(rows[1]["name"]), "Alice");
+    EXPECT_EQ(rows[1].get_long("age"), 30);
+
+    EXPECT_EQ(rows[0].column_names[0], "name");
+    EXPECT_EQ(rows[0].column_names[1], "age");
+    EXPECT_EQ(rows[0].column_names[2], "salary");
+
+    auto cols = df.columns();
+    EXPECT_EQ(cols.size(), 3);
+    EXPECT_STREQ(cols[0].c_str(), "name");
+}
\ No newline at end of file
diff --git a/tests/serverless_databricks_cluster.cpp b/tests/serverless_databricks_cluster.cpp
index 981343c..5d4e87c 100644
--- a/tests/serverless_databricks_cluster.cpp
+++ b/tests/serverless_databricks_cluster.cpp
@@ -69,7 +69,7 @@ TEST_F(SparkIntegrationTest, DatabricksNycTaxiAnalysis)
                          "GROUP BY pickup_zip "
                          "ORDER BY total_trips DESC");
 
-    df.show(10);
+    df.show(20);
 
     ASSERT_GT(df.count(), 0) << "The taxi dataset should not be empty.";
 }
\ No newline at end of file
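
For reference, a minimal, hypothetical sketch of how a caller might consume the new `DataFrame::collect()` API introduced in this diff. It assumes the `SparkSession`/reader chain and the `Row` accessors exercised in `tests/dataframe.cpp` above (`operator[]` returning a variant, `get_long()`); the `print_people` helper and the `spark_session.h` header name are illustrative, not part of the change.

```cpp
#include <iostream>
#include <string>
#include <variant>

#include "dataframe.h"
#include "spark_session.h" // assumed header exposing SparkSession

// Hypothetical helper: pull every row to the client and print two typed columns.
void print_people(SparkSession &spark)
{
    auto df = spark.read()
                  .option("header", "true")
                  .option("inferSchema", "true")
                  .csv("datasets/people.csv");

    // collect() materializes the result as std::vector<spark::sql::types::Row>.
    for (auto &row : df.collect())
    {
        std::cout << std::get<std::string>(row["name"]) << " is "
                  << row.get_long("age") << std::endl;
    }
}
```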